Commit 7462218e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.5.0-dtk24.04.1'

parents 6ccd3f47 1cec5e62
......@@ -47,6 +47,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops.def(
"paged_attention_v1_opt("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1_opt", torch::kCUDA, &paged_attention_v1_opt);
// PagedAttention V2 (opt).
ops.def(
"paged_attention_v2_opt("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt);
// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
......@@ -68,6 +96,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
// Activation function used in SwiGLU. (opt)
ops.def("silu_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("silu_and_mul_opt", torch::kCUDA, &silu_and_mul_opt);
// Activation function used in GeGLU with `none` approximation. (opt)
ops.def("gelu_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul_opt", torch::kCUDA, &gelu_and_mul_opt);
// Activation function used in GeGLU with `tanh` approximation. (opt)
ops.def("gelu_tanh_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul_opt", torch::kCUDA, &gelu_tanh_and_mul_opt);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
......@@ -81,6 +121,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
// Apply Root Mean Square (RMS) Normalization to the input tensor. (opt)
ops.def(
"rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()");
ops.impl("rms_norm_opt", torch::kCUDA, &rms_norm_opt);
// In-place fused Add and RMS Normalization. (opt)
ops.def(
"fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm_opt", torch::kCUDA, &fused_add_rms_norm_opt);
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
......@@ -89,6 +141,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
// Rotary embedding TGI for TGI
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding_tgi(Tensor! query, Tensor! key,"
" int head_size, Tensor cos_cache,"
" Tensor sin_cache, bool is_neox) -> ()");
// ops.def("rotary_embedding_tgi",&rotary_embedding_tgi);
ops.impl("rotary_embedding_tgi", torch::kCUDA, &rotary_embedding_tgi);
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras).
ops.def(
......@@ -99,6 +160,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache_offsets) -> ()");
ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
// trans w16
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
// Quantization ops
#ifndef USE_ROCM
// Quantized GEMM for AQLM.
......
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
if __name__ == '__main__':
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
# Create an LLM.
llm = LLM(model="facebook/opt-125m",trust_remote_code=True, dtype="float16", enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Create an LLM.
llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, distributed_executor_backend="ray", dtype="float16",trust_remote_code=True, enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
from vllm.sampling_params import SamplingParams
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
import asyncio
from vllm.utils import FlexibleArgumentParser
from transformers import AutoTokenizer
import logging
import argparse
import sys
vllm_logger = logging.getLogger("vllm")
vllm_logger.setLevel(logging.WARNING)
class FlexibleArgumentParser(argparse.ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names."""
def parse_args(self, args=None, namespace=None):
if args is None:
args = sys.argv[1:]
# Convert underscores to dashes and vice versa in argument names
processed_args = []
for arg in args:
if arg.startswith('--'):
if '=' in arg:
key, value = arg.split('=', 1)
key = '--' + key[len('--'):].replace('_', '-')
processed_args.append(f'{key}={value}')
else:
processed_args.append('--' +
arg[len('--'):].replace('_', '-'))
else:
processed_args.append(arg)
return super().parse_args(processed_args, namespace)
parser = FlexibleArgumentParser()
parser.add_argument('--template', type=str, help="Path to template")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
# chat = [
# {"role": "user", "content": "Hello, how are you?"},
# {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
# {"role": "user", "content": "I'd like to show off how chat templating works!"},
# ]
tokenizer = AutoTokenizer.from_pretrained(args.model)
try:
f = open(args.template,'r')
tokenizer.chat_template = f.read()
except Exception as e:
print('except:',e)
finally:
f.close()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
model_name = args.model.split("/")[-1] if args.model.split("/")[-1] !="" else args.model.split("/")[-2]
print(f"欢迎使用{model_name}模型,输入内容即可进行对话,stop 终止程序")
def build_prompt(history):
prompt = ""
for query, response in history:
prompt += f"\n\n用户:{query}"
prompt += f"\n\n{model_name}:{response}"
return prompt
history = []
while True:
query = input("\n用户:")
if query.strip() == "stop":
break
history.append({"role": "user", "content": query})
new_query = tokenizer.apply_chat_template(history, tokenize=False)
example_input = {
"prompt": new_query,
"stream": False,
"temperature": 0.0,
"request_id": 0,
}
results_generator = engine.generate(
example_input["prompt"],
SamplingParams(temperature=example_input["temperature"], max_tokens=100),
example_input["request_id"]
)
start = 0
end = 0
response = ""
async def process_results():
async for output in results_generator:
global end
global start
global response
print(output.outputs[0].text[start:], end="", flush=True)
length = len(output.outputs[0].text)
start = length
response = output.outputs[0].text
asyncio.run(process_results())
history.append({"role": "assistant", "content": response})
print()
{% if messages[0]['role'] == 'system' %}
{% set system_message = '<<SYS>>\n' + messages[0]['content'] | trim + '\n<</SYS>>\n\n' %}
{% set messages = messages[1:] %}
{% else %}
{% set system_message = '' %}
{% endif %}
{% for message in messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}
{% if loop.index0 == 0 %}
{% set content = system_message + message['content'] %}
{% else %}
{% set content = message['content'] %}
{% endif %}
{% if message['role'] == 'user' %}
{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ ' ' + content | trim + ' ' + eos_token }}
{% endif %}
{% endfor %}
\ No newline at end of file
......@@ -2,6 +2,5 @@
-r requirements-common.txt
# Dependencies for AMD GPUs
ray == 2.9.1
# ray >= 2.10.0
ray >= 2.10.0
pytest-asyncio
......@@ -18,6 +18,9 @@ from typing import Optional, Union
import subprocess
from pathlib import Path
add_git_version = False
if int(os.environ.get('ADD_GIT_VERSION', '0')) == 1:
add_git_version = True
def load_module_from_path(module_name, path):
spec = importlib.util.spec_from_file_location(module_name, path)
......@@ -317,33 +320,23 @@ def find_version(filepath: str) -> str:
raise RuntimeError("Unable to find version string.")
def get_abi():
try:
command = "echo '#include <string>' | gcc -x c++ -E -dM - | fgrep _GLIBCXX_USE_CXX11_ABI"
result = subprocess.run(command, shell=True, capture_output=True, text=True)
output = result.stdout.strip()
abi = "abi" + output.split(" ")[-1]
return abi
except Exception:
return 'abiUnknown'
def get_sha(root: Union[str, Path]) -> str:
try:
return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=root).decode('ascii').strip()
except Exception:
return 'Unknown'
def get_version_add(sha: Optional[str] = None) -> str:
vllm_root = os.path.dirname(os.path.abspath(__file__))
add_version_path = os.path.join(os.path.join(vllm_root, "vllm"), "version.py")
if sha != 'Unknown':
if sha is None:
sha = get_sha(vllm_root)
version = 'das1.1.git' + sha[:7]
# abi version
version += "." + get_abi()
if add_git_version:
if sha != 'Unknown':
if sha is None:
sha = get_sha(vllm_root)
version = 'das.opt1' + sha[:7]
else:
version = 'das.opt1'
# dtk version
if os.getenv("ROCM_PATH"):
......@@ -351,12 +344,9 @@ def get_version_add(sha: Optional[str] = None) -> str:
rocm_version_path = os.path.join(rocm_path, '.info', "rocm_version")
with open(rocm_version_path, 'r',encoding='utf-8') as file:
lines = file.readlines()
rocm_version=lines[0][:-2].replace(".", "")
rocm_version=lines[0].replace(".", "")
version += ".dtk" + rocm_version
# torch version
version += ".torch" + torch.__version__[:5]
with open(add_version_path, encoding="utf-8",mode="w") as file:
file.write("__version__='0.5.0.post1'\n")
file.write("__dcu_version__='0.5.0.post1+{}'\n".format(version))
......
......@@ -67,7 +67,8 @@ def test_chunked_prefill_recompute(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
# @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
def test_preemption(
caplog_vllm,
......@@ -118,7 +119,8 @@ def test_preemption(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
# @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap(
......@@ -176,7 +178,8 @@ def test_swap(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
# @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap_infeasible(
......@@ -220,7 +223,8 @@ def test_swap_infeasible(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
# @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
def test_preemption_infeasible(
vllm_runner,
......
......@@ -359,6 +359,8 @@ def test_multi_query_kv_attention(
attn_bias=attn_bias,
p=0.0,
scale=scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = output.squeeze(0)
......
......@@ -9,6 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.utils import is_hip
NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64]
......@@ -158,57 +159,58 @@ def test_contexted_kv_attention(
end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
scale = float(1.0 / (head_size**0.5))
if not is_hip():
scale = float(1.0 / (head_size**0.5))
attn_op = xops.fmha.cutlass.FwOp()
attn_op = xops.fmha.cutlass.FwOp()
if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
query.shape[-1])
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
query.shape[-1])
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
query_lens, seq_lens)
if sliding_window > 0:
attn_bias = attn_bias.make_local_attention_from_bottomright(
sliding_window)
output_ref = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias,
p=0.0,
scale=scale,
op=attn_op,
)
torch.cuda.synchronize()
start_time = time.time()
output_ref = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias,
p=0.0,
scale=scale,
op=attn_op,
)
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
output_ref = output_ref.reshape(output.shape)
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
query_lens, seq_lens)
if sliding_window > 0:
attn_bias = attn_bias.make_local_attention_from_bottomright(
sliding_window)
output_ref = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias,
p=0.0,
scale=scale,
op=attn_op,
)
torch.cuda.synchronize()
start_time = time.time()
output_ref = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias,
p=0.0,
scale=scale,
op=attn_op,
)
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
output_ref = output_ref.reshape(output.shape)
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
......@@ -373,78 +375,80 @@ def test_contexted_kv_attention_alibi(
torch.cuda.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
scale = float(1.0 / (head_size**0.5))
if not is_hip():
scale = float(1.0 / (head_size**0.5))
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
# we have to pad query tensor before MQA/GQA expanding.
if query.shape[0] != key.shape[0]:
query_pad = torch.empty(sum(seq_lens),
num_heads,
head_size,
dtype=dtype)
query_pad.uniform_(-1e-3, 1e-3)
seq_start = 0
query_start = 0
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
query_pad[seq_start:seq_end, ...] = torch.cat([
torch.zeros(
seq_len - query_len, num_heads, head_size, dtype=dtype),
query[query_start:query_end, ...]
],
dim=0)
seq_start += seq_len
query_start += query_len
query = query_pad
if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
query.shape[-1])
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
# we have to pad query tensor before MQA/GQA expanding.
if query.shape[0] != key.shape[0]:
query_pad = torch.empty(sum(seq_lens),
num_heads,
head_size,
dtype=dtype)
query_pad.uniform_(-1e-3, 1e-3)
attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
output_ref = torch.empty_like(output)
seq_start = 0
query_start = 0
start_time = time.time()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
query_pad[seq_start:seq_end, ...] = torch.cat([
torch.zeros(
seq_len - query_len, num_heads, head_size, dtype=dtype),
query[query_start:query_end, ...]
],
dim=0)
out = xops.memory_efficient_attention_forward(query[:,
seq_start:seq_end],
key[:,
seq_start:seq_end],
value[:,
seq_start:seq_end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale)
out = out.view_as(query[:, seq_start:seq_end]).view(
seq_len, num_heads, head_size)
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
...])
seq_start += seq_len
query_start += query_len
query = query_pad
if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
query.shape[-1])
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
output_ref = torch.empty_like(output)
seq_start = 0
query_start = 0
start_time = time.time()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
out = xops.memory_efficient_attention_forward(query[:,
seq_start:seq_end],
key[:,
seq_start:seq_end],
value[:,
seq_start:seq_end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale)
out = out.view_as(query[:, seq_start:seq_end]).view(
seq_len, num_heads, head_size)
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
...])
seq_start += seq_len
query_start += query_len
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
import contextlib
import functools
from typing import List, Optional, Tuple, Type
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
from lmslim import quant_ops
except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq model.\n")
try:
import vllm._C
except ImportError as e:
......@@ -56,6 +60,18 @@ def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_tanh_and_mul(out, x)
def silu_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.silu_and_mul_opt(out, x)
def gelu_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_and_mul_opt(out, x)
def gelu_tanh_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_tanh_and_mul_opt(out, x)
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
......@@ -125,6 +141,65 @@ def paged_attention_v2(
blocksparse_block_size, blocksparse_head_sliding_step)
# page attention ops (opt)
def paged_attention_v1_opt(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
torch.ops._C.paged_attention_v1_opt(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_v2_opt(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
torch.ops._C.paged_attention_v2_opt(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, kv_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
# pos encoding ops
def rotary_embedding(
positions: torch.Tensor,
......@@ -157,10 +232,31 @@ def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
# layer norm ops (opt)
def rms_norm_opt(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
torch.ops._C.rms_norm_opt(out, input, weight, epsilon)
def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon)
# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
row:int, col:int) -> None :
torch.ops._C.trans_w16_gemm(dst,src,row,col)
# quantization ops
# awq
def GetAWQShareWorkspaceSize()->int:
return quant_ops.GetAWQShareWorkspaceSize()
def GetAWQShareWorkspace()->torch.Tensor:
return quant_ops.GetAWQShareWorkspace()
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
......@@ -168,24 +264,57 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
thx, thy)
def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
# def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
# scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
# return quant_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
def awq_gemm(input: torch.Tensor, weight: torch.Tensor,
zeros_and_scales:torch.Tensor,
m:int,n:int,k:int,
group_size:int,padding_group:int,splikspace:torch.Tensor,
splikspacesize:int) -> torch.Tensor:
return quant_ops.awq_gemm(input,
weight,
zeros_and_scales,
m,
n,
k,
group_size,
padding_group,
splikspace,
splikspacesize)
def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
group_size: int):
return quant_ops.convert_s4(qw,qz,s,group_size)
def sz_permute(sz:torch.Tensor)-> torch.Tensor:
return quant_ops.sz_permute(sz)
def dequant_w4_gemm_colmajor(qweight:torch.Tensor,
zeros_and_scale:torch.Tensor,
k:int,
n:int,
group_size:int
)->torch.Tensor:
return quant_ops.dequant_w4_gemm_colmajor(qweight,zeros_and_scale,k,n,group_size)
# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool,
bit: int) -> torch.Tensor:
return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
return quant_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit)
# return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
# b_g_idx, use_exllama, bit)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None:
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
quant_ops.gptq_shuffle(q_weight, q_perm, bit)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
......
......@@ -228,11 +228,25 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen >= 8000
self.use_flash_attn_auto = envs.VLLM_USE_FLASH_ATTN_AUTO
if self.use_triton_flash_attn:
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention)
self.attn_func = triton_attention
logger.debug("Using Triton FA in ROCmBackend")
if self.use_flash_attn_auto:
from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
flash_attn_varlen_func)
self.attn_func_triton = flash_attn_varlen_func
from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func_ck = flash_attn_varlen_func
logger.debug("When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA")
else:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
flash_attn_varlen_func)
self.attn_func = flash_attn_varlen_func # triton_attention
logger.debug("Using Triton FA in ROCmBackend")
else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either
......@@ -325,18 +339,56 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
if self.use_triton_flash_attn:
out, _ = self.attn_func(
query,
key,
value,
None,
prefill_meta.seq_start_loc,
prefill_meta.seq_start_loc,
prefill_meta.max_prefill_seq_len,
prefill_meta.max_prefill_seq_len,
True,
self.scale,
if self.use_flash_attn_auto:
if prefill_meta.max_prefill_seq_len >= 8000:
out = self.attn_func_triton(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
else:
out = self.attn_func_ck(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
else:
# out, _ = self.attn_func(
# query,
# key,
# value,
# None,
# prefill_meta.seq_start_loc,
# prefill_meta.seq_start_loc,
# prefill_meta.max_prefill_seq_len,
# prefill_meta.max_prefill_seq_len,
# True,
# self.scale,
# )
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround.
......@@ -439,4 +491,4 @@ def _sdpa_attention(
output[start:end, :, :] = sub_out
start = end
return output
return output
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
......@@ -684,7 +684,7 @@ if triton.__version__ >= "2.1.0":
sliding_window=None):
cap = torch.cuda.get_device_capability()
BLOCK = 128 if cap[0] >= 8 else 64
BLOCK = 32 if cap[0] >= 8 else 32
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
......@@ -701,7 +701,7 @@ if triton.__version__ >= "2.1.0":
if sliding_window is None or sliding_window <= 0:
sliding_window = 0
num_warps = 8 if Lk <= 64 else 8
num_warps = 8 if Lk <= 64 else 4
if alibi_slopes is not None:
_fwd_kernel_alibi[grid](
q,
......
This diff is collapsed.
......@@ -173,7 +173,7 @@ class ModelConfig:
def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["gptq", "squeezellm"]
rocm_supported_quantization = ["gptq", "squeezellm","awq"]
if self.quantization is not None:
self.quantization = self.quantization.lower()
......@@ -279,6 +279,12 @@ class ModelConfig:
return self.hf_text_config.hidden_size
def get_head_size(self) -> int:
# TODO remove hard code
if hasattr(self.hf_text_config, "model_type"
) and self.hf_text_config.model_type == 'deepseek_v2':
# FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256
return 256
if hasattr(self.hf_text_config, "head_dim"):
return self.hf_text_config.head_dim
# FIXME(woosuk): This may not be true for all models.
......
......@@ -315,4 +315,4 @@ class CustomAllreduce:
self._ptr = 0
def __del__(self):
self.close()
self.close()
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
......@@ -76,7 +76,8 @@ class ResultHandler(threading.Thread):
"""Handle results from all workers (in background thread)"""
def __init__(self) -> None:
super().__init__(daemon=True)
super().__init__(daemon=False)
# super().__init__(daemon=True)
self.result_queue = mp.Queue()
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
......@@ -100,7 +101,8 @@ class WorkerMonitor(threading.Thread):
def __init__(self, workers: List['ProcessWorkerWrapper'],
result_handler: ResultHandler):
super().__init__(daemon=True)
super().__init__(daemon=False)
# super().__init__(daemon=True)
self.workers = workers
self.result_handler = result_handler
self._close = False
......@@ -112,15 +114,31 @@ class WorkerMonitor(threading.Thread):
self._close = True
# Kill / cleanup all workers
for worker in self.workers:
process = worker.process
if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0:
logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid, process.exitcode)
# for worker in self.workers:
# process = worker.process
# if process.sentinel in dead_sentinels:
# process.join(JOIN_TIMEOUT_S)
# if process.exitcode is not None and process.exitcode != 0:
# logger.error("Worker %s pid %s died, exit code: %s",
# process.name, process.pid, process.exitcode)
if not sys.is_finalizing():
# Kill / cleanup all workers
died_count = 0
for worker in self.workers:
process = worker.process
if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0:
died_count += 1
logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid,
process.exitcode)
if died_count < len(self.workers):
logger.info(
"Killing remaining local vLLM worker processes")
# Cleanup any remaining workers
logger.info("Killing local vLLM worker processes")
# logger.info("Killing local vLLM worker processes")
for worker in self.workers:
worker.kill_worker()
# Must be done after worker task queues are all closed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment