Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7462218e
Commit
7462218e
authored
Sep 05, 2024
by
zhuwenwen
Browse files
Merge branch 'v0.5.0-dtk24.04.1'
parents
6ccd3f47
1cec5e62
Changes
60
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2595 additions
and
327 deletions
+2595
-327
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+65
-0
examples/offline_inference.py
examples/offline_inference.py
+20
-19
examples/offline_streaming_inference_chat_demo.py
examples/offline_streaming_inference_chat_demo.py
+110
-0
examples/template_llama_chat.jinja
examples/template_llama_chat.jinja
+24
-0
requirements-rocm.txt
requirements-rocm.txt
+1
-2
setup.py
setup.py
+12
-22
tests/basic_correctness/test_preemption.py
tests/basic_correctness/test_preemption.py
+8
-4
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+2
-0
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+119
-115
vllm/_custom_ops.py
vllm/_custom_ops.py
+137
-8
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+68
-16
vllm/attention/ops/flash_attn_triton_mqa_gqa.py
vllm/attention/ops/flash_attn_triton_mqa_gqa.py
+1308
-0
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+103
-43
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+2
-2
vllm/benchmark_throughput.py
vllm/benchmark_throughput.py
+464
-0
vllm/config.py
vllm/config.py
+7
-1
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+1
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+91
-75
vllm/envs.py
vllm/envs.py
+25
-9
vllm/executor/multiproc_worker_utils.py
vllm/executor/multiproc_worker_utils.py
+28
-10
No files found.
csrc/torch_bindings.cpp
View file @
7462218e
...
@@ -47,6 +47,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -47,6 +47,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops
.
def
(
"paged_attention_v1_opt("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v1_opt"
,
torch
::
kCUDA
,
&
paged_attention_v1_opt
);
// PagedAttention V2 (opt).
ops
.
def
(
"paged_attention_v2_opt("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2_opt"
,
torch
::
kCUDA
,
&
paged_attention_v2_opt
);
// Activation ops
// Activation ops
// Activation function used in SwiGLU.
// Activation function used in SwiGLU.
ops
.
def
(
"silu_and_mul(Tensor! out, Tensor input) -> ()"
);
ops
.
def
(
"silu_and_mul(Tensor! out, Tensor input) -> ()"
);
...
@@ -68,6 +96,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -68,6 +96,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gelu_fast(Tensor! out, Tensor input) -> ()"
);
ops
.
def
(
"gelu_fast(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_fast"
,
torch
::
kCUDA
,
&
gelu_fast
);
ops
.
impl
(
"gelu_fast"
,
torch
::
kCUDA
,
&
gelu_fast
);
// Activation function used in SwiGLU. (opt)
ops
.
def
(
"silu_and_mul_opt(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"silu_and_mul_opt"
,
torch
::
kCUDA
,
&
silu_and_mul_opt
);
// Activation function used in GeGLU with `none` approximation. (opt)
ops
.
def
(
"gelu_and_mul_opt(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_and_mul_opt"
,
torch
::
kCUDA
,
&
gelu_and_mul_opt
);
// Activation function used in GeGLU with `tanh` approximation. (opt)
ops
.
def
(
"gelu_tanh_and_mul_opt(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_tanh_and_mul_opt"
,
torch
::
kCUDA
,
&
gelu_tanh_and_mul_opt
);
// Layernorm
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops
.
def
(
ops
.
def
(
...
@@ -81,6 +121,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -81,6 +121,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()"
);
"float epsilon) -> ()"
);
ops
.
impl
(
"fused_add_rms_norm"
,
torch
::
kCUDA
,
&
fused_add_rms_norm
);
ops
.
impl
(
"fused_add_rms_norm"
,
torch
::
kCUDA
,
&
fused_add_rms_norm
);
// Apply Root Mean Square (RMS) Normalization to the input tensor. (opt)
ops
.
def
(
"rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()"
);
ops
.
impl
(
"rms_norm_opt"
,
torch
::
kCUDA
,
&
rms_norm_opt
);
// In-place fused Add and RMS Normalization. (opt)
ops
.
def
(
"fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()"
);
ops
.
impl
(
"fused_add_rms_norm_opt"
,
torch
::
kCUDA
,
&
fused_add_rms_norm_opt
);
// Rotary embedding
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops
.
def
(
ops
.
def
(
...
@@ -89,6 +141,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -89,6 +141,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache, bool is_neox) -> ()"
);
" Tensor cos_sin_cache, bool is_neox) -> ()"
);
ops
.
impl
(
"rotary_embedding"
,
torch
::
kCUDA
,
&
rotary_embedding
);
ops
.
impl
(
"rotary_embedding"
,
torch
::
kCUDA
,
&
rotary_embedding
);
// Rotary embedding TGI for TGI
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops
.
def
(
"rotary_embedding_tgi(Tensor! query, Tensor! key,"
" int head_size, Tensor cos_cache,"
" Tensor sin_cache, bool is_neox) -> ()"
);
// ops.def("rotary_embedding_tgi",&rotary_embedding_tgi);
ops
.
impl
(
"rotary_embedding_tgi"
,
torch
::
kCUDA
,
&
rotary_embedding_tgi
);
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras).
// (supports multiple loras).
ops
.
def
(
ops
.
def
(
...
@@ -99,6 +160,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -99,6 +160,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache_offsets) -> ()"
);
" Tensor cos_sin_cache_offsets) -> ()"
);
ops
.
impl
(
"batched_rotary_embedding"
,
torch
::
kCUDA
,
&
batched_rotary_embedding
);
ops
.
impl
(
"batched_rotary_embedding"
,
torch
::
kCUDA
,
&
batched_rotary_embedding
);
// trans w16
ops
.
def
(
"trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()"
);
ops
.
impl
(
"trans_w16_gemm"
,
torch
::
kCUDA
,
&
trans_w16_gemm
);
// Quantization ops
// Quantization ops
#ifndef USE_ROCM
#ifndef USE_ROCM
// Quantized GEMM for AQLM.
// Quantized GEMM for AQLM.
...
...
examples/offline_inference.py
View file @
7462218e
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
# Sample prompts.
if
__name__
==
'__main__'
:
prompts
=
[
# Sample prompts.
prompts
=
[
"Hello, my name is"
,
"Hello, my name is"
,
"The president of the United States is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The capital of France is"
,
"The future of AI is"
,
"The future of AI is"
,
]
]
# Create a sampling params object.
# Create a sampling params object.
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
max_tokens
=
16
)
# Create an LLM.
# Create an LLM.
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
trust_remote_code
=
True
,
dtype
=
"float16"
,
enforce_eager
=
True
)
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
t
ensor_parallel_size
=
1
,
distributed_executor_backend
=
"ray"
,
dtype
=
"float16"
,
t
rust_remote_code
=
True
,
enforce_eager
=
True
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
# that contain the prompt, generated text, and other information.
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
# Print the outputs.
# Print the outputs.
for
output
in
outputs
:
for
output
in
outputs
:
prompt
=
output
.
prompt
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
examples/offline_streaming_inference_chat_demo.py
0 → 100644
View file @
7462218e
from
vllm.sampling_params
import
SamplingParams
from
vllm.engine.async_llm_engine
import
AsyncEngineArgs
,
AsyncLLMEngine
import
asyncio
from
vllm.utils
import
FlexibleArgumentParser
from
transformers
import
AutoTokenizer
import
logging
import
argparse
import
sys
vllm_logger
=
logging
.
getLogger
(
"vllm"
)
vllm_logger
.
setLevel
(
logging
.
WARNING
)
class
FlexibleArgumentParser
(
argparse
.
ArgumentParser
):
"""ArgumentParser that allows both underscore and dash in names."""
def
parse_args
(
self
,
args
=
None
,
namespace
=
None
):
if
args
is
None
:
args
=
sys
.
argv
[
1
:]
# Convert underscores to dashes and vice versa in argument names
processed_args
=
[]
for
arg
in
args
:
if
arg
.
startswith
(
'--'
):
if
'='
in
arg
:
key
,
value
=
arg
.
split
(
'='
,
1
)
key
=
'--'
+
key
[
len
(
'--'
):].
replace
(
'_'
,
'-'
)
processed_args
.
append
(
f
'
{
key
}
=
{
value
}
'
)
else
:
processed_args
.
append
(
'--'
+
arg
[
len
(
'--'
):].
replace
(
'_'
,
'-'
))
else
:
processed_args
.
append
(
arg
)
return
super
().
parse_args
(
processed_args
,
namespace
)
parser
=
FlexibleArgumentParser
()
parser
.
add_argument
(
'--template'
,
type
=
str
,
help
=
"Path to template"
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
args
=
parser
.
parse_args
()
# chat = [
# {"role": "user", "content": "Hello, how are you?"},
# {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
# {"role": "user", "content": "I'd like to show off how chat templating works!"},
# ]
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
model
)
try
:
f
=
open
(
args
.
template
,
'r'
)
tokenizer
.
chat_template
=
f
.
read
()
except
Exception
as
e
:
print
(
'except:'
,
e
)
finally
:
f
.
close
()
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
)
model_name
=
args
.
model
.
split
(
"/"
)[
-
1
]
if
args
.
model
.
split
(
"/"
)[
-
1
]
!=
""
else
args
.
model
.
split
(
"/"
)[
-
2
]
print
(
f
"欢迎使用
{
model_name
}
模型,输入内容即可进行对话,stop 终止程序"
)
def
build_prompt
(
history
):
prompt
=
""
for
query
,
response
in
history
:
prompt
+=
f
"
\n\n
用户:
{
query
}
"
prompt
+=
f
"
\n\n
{
model_name
}
:
{
response
}
"
return
prompt
history
=
[]
while
True
:
query
=
input
(
"
\n
用户:"
)
if
query
.
strip
()
==
"stop"
:
break
history
.
append
({
"role"
:
"user"
,
"content"
:
query
})
new_query
=
tokenizer
.
apply_chat_template
(
history
,
tokenize
=
False
)
example_input
=
{
"prompt"
:
new_query
,
"stream"
:
False
,
"temperature"
:
0.0
,
"request_id"
:
0
,
}
results_generator
=
engine
.
generate
(
example_input
[
"prompt"
],
SamplingParams
(
temperature
=
example_input
[
"temperature"
],
max_tokens
=
100
),
example_input
[
"request_id"
]
)
start
=
0
end
=
0
response
=
""
async
def
process_results
():
async
for
output
in
results_generator
:
global
end
global
start
global
response
print
(
output
.
outputs
[
0
].
text
[
start
:],
end
=
""
,
flush
=
True
)
length
=
len
(
output
.
outputs
[
0
].
text
)
start
=
length
response
=
output
.
outputs
[
0
].
text
asyncio
.
run
(
process_results
())
history
.
append
({
"role"
:
"assistant"
,
"content"
:
response
})
print
()
examples/template_llama_chat.jinja
0 → 100644
View file @
7462218e
{% if messages[0]['role'] == 'system' %}
{% set system_message = '<<SYS>>\n' + messages[0]['content'] | trim + '\n<</SYS>>\n\n' %}
{% set messages = messages[1:] %}
{% else %}
{% set system_message = '' %}
{% endif %}
{% for message in messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}
{% if loop.index0 == 0 %}
{% set content = system_message + message['content'] %}
{% else %}
{% set content = message['content'] %}
{% endif %}
{% if message['role'] == 'user' %}
{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ ' ' + content | trim + ' ' + eos_token }}
{% endif %}
{% endfor %}
\ No newline at end of file
requirements-rocm.txt
View file @
7462218e
...
@@ -2,6 +2,5 @@
...
@@ -2,6 +2,5 @@
-r requirements-common.txt
-r requirements-common.txt
# Dependencies for AMD GPUs
# Dependencies for AMD GPUs
ray == 2.9.1
ray >= 2.10.0
# ray >= 2.10.0
pytest-asyncio
pytest-asyncio
setup.py
View file @
7462218e
...
@@ -18,6 +18,9 @@ from typing import Optional, Union
...
@@ -18,6 +18,9 @@ from typing import Optional, Union
import
subprocess
import
subprocess
from
pathlib
import
Path
from
pathlib
import
Path
add_git_version
=
False
if
int
(
os
.
environ
.
get
(
'ADD_GIT_VERSION'
,
'0'
))
==
1
:
add_git_version
=
True
def
load_module_from_path
(
module_name
,
path
):
def
load_module_from_path
(
module_name
,
path
):
spec
=
importlib
.
util
.
spec_from_file_location
(
module_name
,
path
)
spec
=
importlib
.
util
.
spec_from_file_location
(
module_name
,
path
)
...
@@ -317,33 +320,23 @@ def find_version(filepath: str) -> str:
...
@@ -317,33 +320,23 @@ def find_version(filepath: str) -> str:
raise
RuntimeError
(
"Unable to find version string."
)
raise
RuntimeError
(
"Unable to find version string."
)
def
get_abi
():
try
:
command
=
"echo '#include <string>' | gcc -x c++ -E -dM - | fgrep _GLIBCXX_USE_CXX11_ABI"
result
=
subprocess
.
run
(
command
,
shell
=
True
,
capture_output
=
True
,
text
=
True
)
output
=
result
.
stdout
.
strip
()
abi
=
"abi"
+
output
.
split
(
" "
)[
-
1
]
return
abi
except
Exception
:
return
'abiUnknown'
def
get_sha
(
root
:
Union
[
str
,
Path
])
->
str
:
def
get_sha
(
root
:
Union
[
str
,
Path
])
->
str
:
try
:
try
:
return
subprocess
.
check_output
([
'git'
,
'rev-parse'
,
'HEAD'
],
cwd
=
root
).
decode
(
'ascii'
).
strip
()
return
subprocess
.
check_output
([
'git'
,
'rev-parse'
,
'HEAD'
],
cwd
=
root
).
decode
(
'ascii'
).
strip
()
except
Exception
:
except
Exception
:
return
'Unknown'
return
'Unknown'
def
get_version_add
(
sha
:
Optional
[
str
]
=
None
)
->
str
:
def
get_version_add
(
sha
:
Optional
[
str
]
=
None
)
->
str
:
vllm_root
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
vllm_root
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
add_version_path
=
os
.
path
.
join
(
os
.
path
.
join
(
vllm_root
,
"vllm"
),
"version.py"
)
add_version_path
=
os
.
path
.
join
(
os
.
path
.
join
(
vllm_root
,
"vllm"
),
"version.py"
)
if
add_git_version
:
if
sha
!=
'Unknown'
:
if
sha
!=
'Unknown'
:
if
sha
is
None
:
if
sha
is
None
:
sha
=
get_sha
(
vllm_root
)
sha
=
get_sha
(
vllm_root
)
version
=
'das1.1.git'
+
sha
[:
7
]
version
=
'das.opt1'
+
sha
[:
7
]
else
:
# abi version
version
=
'das.opt1'
version
+=
"."
+
get_abi
()
# dtk version
# dtk version
if
os
.
getenv
(
"ROCM_PATH"
):
if
os
.
getenv
(
"ROCM_PATH"
):
...
@@ -351,12 +344,9 @@ def get_version_add(sha: Optional[str] = None) -> str:
...
@@ -351,12 +344,9 @@ def get_version_add(sha: Optional[str] = None) -> str:
rocm_version_path
=
os
.
path
.
join
(
rocm_path
,
'.info'
,
"rocm_version"
)
rocm_version_path
=
os
.
path
.
join
(
rocm_path
,
'.info'
,
"rocm_version"
)
with
open
(
rocm_version_path
,
'r'
,
encoding
=
'utf-8'
)
as
file
:
with
open
(
rocm_version_path
,
'r'
,
encoding
=
'utf-8'
)
as
file
:
lines
=
file
.
readlines
()
lines
=
file
.
readlines
()
rocm_version
=
lines
[
0
]
[:
-
2
]
.
replace
(
"."
,
""
)
rocm_version
=
lines
[
0
].
replace
(
"."
,
""
)
version
+=
".dtk"
+
rocm_version
version
+=
".dtk"
+
rocm_version
# torch version
version
+=
".torch"
+
torch
.
__version__
[:
5
]
with
open
(
add_version_path
,
encoding
=
"utf-8"
,
mode
=
"w"
)
as
file
:
with
open
(
add_version_path
,
encoding
=
"utf-8"
,
mode
=
"w"
)
as
file
:
file
.
write
(
"__version__='0.5.0.post1'
\n
"
)
file
.
write
(
"__version__='0.5.0.post1'
\n
"
)
file
.
write
(
"__dcu_version__='0.5.0.post1+{}'
\n
"
.
format
(
version
))
file
.
write
(
"__dcu_version__='0.5.0.post1+{}'
\n
"
.
format
(
version
))
...
...
tests/basic_correctness/test_preemption.py
View file @
7462218e
...
@@ -67,7 +67,8 @@ def test_chunked_prefill_recompute(
...
@@ -67,7 +67,8 @@ def test_chunked_prefill_recompute(
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
# @pytest.mark.parametrize("dtype", ["float"])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
def
test_preemption
(
def
test_preemption
(
caplog_vllm
,
caplog_vllm
,
...
@@ -118,7 +119,8 @@ def test_preemption(
...
@@ -118,7 +119,8 @@ def test_preemption(
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
# @pytest.mark.parametrize("dtype", ["float"])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"beam_width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"beam_width"
,
[
4
])
def
test_swap
(
def
test_swap
(
...
@@ -176,7 +178,8 @@ def test_swap(
...
@@ -176,7 +178,8 @@ def test_swap(
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
# @pytest.mark.parametrize("dtype", ["float"])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"beam_width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"beam_width"
,
[
4
])
def
test_swap_infeasible
(
def
test_swap_infeasible
(
...
@@ -220,7 +223,8 @@ def test_swap_infeasible(
...
@@ -220,7 +223,8 @@ def test_swap_infeasible(
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
# @pytest.mark.parametrize("dtype", ["float"])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
def
test_preemption_infeasible
(
def
test_preemption_infeasible
(
vllm_runner
,
vllm_runner
,
...
...
tests/kernels/test_attention.py
View file @
7462218e
...
@@ -359,6 +359,8 @@ def test_multi_query_kv_attention(
...
@@ -359,6 +359,8 @@ def test_multi_query_kv_attention(
attn_bias
=
attn_bias
,
attn_bias
=
attn_bias
,
p
=
0.0
,
p
=
0.0
,
scale
=
scale
,
scale
=
scale
,
op
=
xops
.
fmha
.
MemoryEfficientAttentionFlashAttentionOp
[
0
]
if
(
is_hip
())
else
None
,
)
)
output
=
output
.
squeeze
(
0
)
output
=
output
.
squeeze
(
0
)
...
...
tests/kernels/test_prefix_prefill.py
View file @
7462218e
...
@@ -9,6 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
...
@@ -9,6 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from
vllm.attention.backends.xformers
import
_make_alibi_bias
from
vllm.attention.backends.xformers
import
_make_alibi_bias
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.utils
import
is_hip
NUM_HEADS
=
[
64
]
NUM_HEADS
=
[
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
...
@@ -158,6 +159,7 @@ def test_contexted_kv_attention(
...
@@ -158,6 +159,7 @@ def test_contexted_kv_attention(
end_time
=
time
.
time
()
end_time
=
time
.
time
()
print
(
f
"triton Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
print
(
f
"triton Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
if
not
is_hip
():
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
attn_op
=
xops
.
fmha
.
cutlass
.
FwOp
()
attn_op
=
xops
.
fmha
.
cutlass
.
FwOp
()
...
@@ -373,6 +375,8 @@ def test_contexted_kv_attention_alibi(
...
@@ -373,6 +375,8 @@ def test_contexted_kv_attention_alibi(
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
end_time
=
time
.
time
()
print
(
f
"triton Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
print
(
f
"triton Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
if
not
is_hip
():
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
...
...
vllm/_custom_ops.py
View file @
7462218e
import
contextlib
import
contextlib
import
functools
import
functools
from
typing
import
List
,
Optional
,
Tuple
,
Type
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
try
:
from
lmslim
import
quant_ops
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer gptq or awq model.
\n
"
)
try
:
try
:
import
vllm._C
import
vllm._C
except
ImportError
as
e
:
except
ImportError
as
e
:
...
@@ -58,6 +62,18 @@ def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
...
@@ -58,6 +62,18 @@ def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch
.
ops
.
_C
.
gelu_tanh_and_mul
(
out
,
x
)
torch
.
ops
.
_C
.
gelu_tanh_and_mul
(
out
,
x
)
def
silu_and_mul_opt
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
silu_and_mul_opt
(
out
,
x
)
def
gelu_and_mul_opt
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
gelu_and_mul_opt
(
out
,
x
)
def
gelu_tanh_and_mul_opt
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
gelu_tanh_and_mul_opt
(
out
,
x
)
def
gelu_fast
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_fast
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
gelu_fast
(
out
,
x
)
torch
.
ops
.
_C
.
gelu_fast
(
out
,
x
)
...
@@ -125,6 +141,65 @@ def paged_attention_v2(
...
@@ -125,6 +141,65 @@ def paged_attention_v2(
blocksparse_block_size
,
blocksparse_head_sliding_step
)
blocksparse_block_size
,
blocksparse_head_sliding_step
)
# page attention ops (opt)
def
paged_attention_v1_opt
(
out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v1_opt
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
def
paged_attention_v2_opt
(
out
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
max_logits
:
torch
.
Tensor
,
tmp_out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v2_opt
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
# pos encoding ops
# pos encoding ops
def
rotary_embedding
(
def
rotary_embedding
(
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -158,9 +233,30 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
...
@@ -158,9 +233,30 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
torch
.
ops
.
_C
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
torch
.
ops
.
_C
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
# layer norm ops (opt)
def
rms_norm_opt
(
out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
torch
.
ops
.
_C
.
rms_norm_opt
(
out
,
input
,
weight
,
epsilon
)
def
fused_add_rms_norm_opt
(
input
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
torch
.
ops
.
_C
.
fused_add_rms_norm_opt
(
input
,
residual
,
weight
,
epsilon
)
# trans_w16
def
trans_w16_gemm
(
dst
:
torch
.
Tensor
,
src
:
torch
.
Tensor
,
row
:
int
,
col
:
int
)
->
None
:
torch
.
ops
.
_C
.
trans_w16_gemm
(
dst
,
src
,
row
,
col
)
# quantization ops
# quantization ops
# awq
# awq
def
GetAWQShareWorkspaceSize
()
->
int
:
return
quant_ops
.
GetAWQShareWorkspaceSize
()
def
GetAWQShareWorkspace
()
->
torch
.
Tensor
:
return
quant_ops
.
GetAWQShareWorkspace
()
def
awq_dequantize
(
qweight
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
def
awq_dequantize
(
qweight
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
zeros
:
torch
.
Tensor
,
split_k_iters
:
int
,
thx
:
int
,
zeros
:
torch
.
Tensor
,
split_k_iters
:
int
,
thx
:
int
,
thy
:
int
)
->
torch
.
Tensor
:
thy
:
int
)
->
torch
.
Tensor
:
...
@@ -168,23 +264,56 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
...
@@ -168,23 +264,56 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
thx
,
thy
)
thx
,
thy
)
def
awq_gemm
(
input
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
qzeros
:
torch
.
Tensor
,
# def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales
:
torch
.
Tensor
,
split_k_iters
:
int
)
->
torch
.
Tensor
:
# scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
return
torch
.
ops
.
_C
.
awq_gemm
(
input
,
qweight
,
qzeros
,
scales
,
split_k_iters
)
# return quant_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
def
awq_gemm
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
zeros_and_scales
:
torch
.
Tensor
,
m
:
int
,
n
:
int
,
k
:
int
,
group_size
:
int
,
padding_group
:
int
,
splikspace
:
torch
.
Tensor
,
splikspacesize
:
int
)
->
torch
.
Tensor
:
return
quant_ops
.
awq_gemm
(
input
,
weight
,
zeros_and_scales
,
m
,
n
,
k
,
group_size
,
padding_group
,
splikspace
,
splikspacesize
)
def
convert_s4
(
qw
:
torch
.
Tensor
,
qz
:
torch
.
Tensor
,
s
:
torch
.
Tensor
,
group_size
:
int
):
return
quant_ops
.
convert_s4
(
qw
,
qz
,
s
,
group_size
)
def
sz_permute
(
sz
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
quant_ops
.
sz_permute
(
sz
)
def
dequant_w4_gemm_colmajor
(
qweight
:
torch
.
Tensor
,
zeros_and_scale
:
torch
.
Tensor
,
k
:
int
,
n
:
int
,
group_size
:
int
)
->
torch
.
Tensor
:
return
quant_ops
.
dequant_w4_gemm_colmajor
(
qweight
,
zeros_and_scale
,
k
,
n
,
group_size
)
# gptq
# gptq
def
gptq_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
def
gptq_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_gptq_qzeros
:
torch
.
Tensor
,
b_gptq_scales
:
torch
.
Tensor
,
b_gptq_qzeros
:
torch
.
Tensor
,
b_gptq_scales
:
torch
.
Tensor
,
b_g_idx
:
torch
.
Tensor
,
use_exllama
:
bool
,
b_g_idx
:
torch
.
Tensor
,
use_exllama
:
bool
,
bit
:
int
)
->
torch
.
Tensor
:
bit
:
int
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
gptq_gemm
(
a
,
b_q_weight
,
b_gptq_qzeros
,
b_gptq_scales
,
return
quant_ops
.
gptq_gemm
(
a
,
b_q_weight
,
b_gptq_qzeros
,
b_gptq_scales
,
b_g_idx
,
use_exllama
,
bit
)
b_g_idx
,
use_exllama
,
bit
)
# return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
# b_g_idx, use_exllama, bit)
def
gptq_shuffle
(
q_weight
:
torch
.
Tensor
,
q_perm
:
torch
.
Tensor
,
def
gptq_shuffle
(
q_weight
:
torch
.
Tensor
,
q_perm
:
torch
.
Tensor
,
bit
:
int
)
->
None
:
bit
:
int
)
->
None
:
torch
.
ops
.
_C
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
quant_ops
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# squeezellm
# squeezellm
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
7462218e
...
@@ -228,11 +228,25 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -228,11 +228,25 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
use_naive_attn
=
False
self
.
use_naive_attn
=
False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen >= 8000
self
.
use_flash_attn_auto
=
envs
.
VLLM_USE_FLASH_ATTN_AUTO
if
self
.
use_triton_flash_attn
:
if
self
.
use_triton_flash_attn
:
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
if
self
.
use_flash_attn_auto
:
triton_attention
)
from
vllm.attention.ops.flash_attn_triton_mqa_gqa
import
(
self
.
attn_func
=
triton_attention
flash_attn_varlen_func
)
self
.
attn_func_triton
=
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
self
.
attn_func_ck
=
flash_attn_varlen_func
logger
.
debug
(
"When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA"
)
else
:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
from
vllm.attention.ops.flash_attn_triton_mqa_gqa
import
(
flash_attn_varlen_func
)
self
.
attn_func
=
flash_attn_varlen_func
# triton_attention
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
else
:
else
:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either
# either
...
@@ -325,18 +339,56 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -325,18 +339,56 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# When block_tables are not filled, it means q and k are the
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
# prompt, and they have the same length.
if
self
.
use_triton_flash_attn
:
if
self
.
use_triton_flash_attn
:
out
,
_
=
self
.
attn_func
(
if
self
.
use_flash_attn_auto
:
query
,
if
prefill_meta
.
max_prefill_seq_len
>=
8000
:
key
,
out
=
self
.
attn_func_triton
(
value
,
q
=
query
,
None
,
k
=
key
,
prefill_meta
.
seq_start_loc
,
v
=
value
,
prefill_meta
.
seq_start_loc
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
prefill_meta
.
max_prefill_seq_len
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
True
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
)
else
:
out
=
self
.
attn_func_ck
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
else
:
# out, _ = self.attn_func(
# query,
# key,
# value,
# None,
# prefill_meta.seq_start_loc,
# prefill_meta.seq_start_loc,
# prefill_meta.max_prefill_seq_len,
# prefill_meta.max_prefill_seq_len,
# True,
# self.scale,
# )
out
=
self
.
attn_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
elif
self
.
use_naive_attn
:
elif
self
.
use_naive_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# Interleave for MQA workaround.
# Interleave for MQA workaround.
...
...
vllm/attention/ops/flash_attn_triton_mqa_gqa.py
0 → 100644
View file @
7462218e
This diff is collapsed.
Click to expand it.
vllm/attention/ops/paged_attn.py
View file @
7462218e
...
@@ -5,6 +5,7 @@ import torch
...
@@ -5,6 +5,7 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
import
vllm.envs
as
envs
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE
=
512
_PARTITION_SIZE
=
512
...
@@ -122,6 +123,33 @@ class PagedAttention:
...
@@ -122,6 +123,33 @@ class PagedAttention:
if
use_v1
:
if
use_v1
:
# Run PagedAttention V1.
# Run PagedAttention V1.
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
print
(
"PA V1 SIZE:"
)
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
envs
.
VLLM_USE_OPT_OP
:
ops
.
paged_attention_v1_opt
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
else
:
ops
.
paged_attention_v1
(
ops
.
paged_attention_v1
(
output
,
output
,
query
,
query
,
...
@@ -156,6 +184,38 @@ class PagedAttention:
...
@@ -156,6 +184,38 @@ class PagedAttention:
device
=
output
.
device
,
device
=
output
.
device
,
)
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
print
(
"PA V2 SIZE:"
)
print
(
f
"exp_sums.shape =
{
exp_sums
.
shape
}
, max_logits.shape =
{
max_logits
.
shape
}
, tmp_output.shape =
{
tmp_output
.
shape
}
"
)
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
envs
.
VLLM_USE_OPT_OP
:
ops
.
paged_attention_v2_opt
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
else
:
ops
.
paged_attention_v2
(
ops
.
paged_attention_v2
(
output
,
output
,
exp_sums
,
exp_sums
,
...
...
vllm/attention/ops/prefix_prefill.py
View file @
7462218e
...
@@ -684,7 +684,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -684,7 +684,7 @@ if triton.__version__ >= "2.1.0":
sliding_window
=
None
):
sliding_window
=
None
):
cap
=
torch
.
cuda
.
get_device_capability
()
cap
=
torch
.
cuda
.
get_device_capability
()
BLOCK
=
128
if
cap
[
0
]
>=
8
else
64
BLOCK
=
32
if
cap
[
0
]
>=
8
else
32
# shape constraints
# shape constraints
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
assert
Lq
==
Lk
and
Lk
==
Lv
assert
Lq
==
Lk
and
Lk
==
Lv
...
@@ -701,7 +701,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -701,7 +701,7 @@ if triton.__version__ >= "2.1.0":
if
sliding_window
is
None
or
sliding_window
<=
0
:
if
sliding_window
is
None
or
sliding_window
<=
0
:
sliding_window
=
0
sliding_window
=
0
num_warps
=
8
if
Lk
<=
64
else
8
num_warps
=
8
if
Lk
<=
64
else
4
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
_fwd_kernel_alibi
[
grid
](
_fwd_kernel_alibi
[
grid
](
q
,
q
,
...
...
vllm/benchmark_throughput.py
0 → 100644
View file @
7462218e
This diff is collapsed.
Click to expand it.
vllm/config.py
View file @
7462218e
...
@@ -173,7 +173,7 @@ class ModelConfig:
...
@@ -173,7 +173,7 @@ class ModelConfig:
def
_verify_quantization
(
self
)
->
None
:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
rocm_supported_quantization
=
[
"gptq"
,
"squeezellm"
]
rocm_supported_quantization
=
[
"gptq"
,
"squeezellm"
,
"awq"
]
if
self
.
quantization
is
not
None
:
if
self
.
quantization
is
not
None
:
self
.
quantization
=
self
.
quantization
.
lower
()
self
.
quantization
=
self
.
quantization
.
lower
()
...
@@ -279,6 +279,12 @@ class ModelConfig:
...
@@ -279,6 +279,12 @@ class ModelConfig:
return
self
.
hf_text_config
.
hidden_size
return
self
.
hf_text_config
.
hidden_size
def
get_head_size
(
self
)
->
int
:
def
get_head_size
(
self
)
->
int
:
# TODO remove hard code
if
hasattr
(
self
.
hf_text_config
,
"model_type"
)
and
self
.
hf_text_config
.
model_type
==
'deepseek_v2'
:
# FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256
return
256
if
hasattr
(
self
.
hf_text_config
,
"head_dim"
):
if
hasattr
(
self
.
hf_text_config
,
"head_dim"
):
return
self
.
hf_text_config
.
head_dim
return
self
.
hf_text_config
.
head_dim
# FIXME(woosuk): This may not be true for all models.
# FIXME(woosuk): This may not be true for all models.
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
7462218e
vllm/engine/llm_engine.py
View file @
7462218e
...
@@ -232,6 +232,8 @@ class LLMEngine:
...
@@ -232,6 +232,8 @@ class LLMEngine:
load_config
=
load_config
,
load_config
=
load_config
,
)
)
init_success
=
False
try
:
if
not
self
.
model_config
.
embedding_mode
:
if
not
self
.
model_config
.
embedding_mode
:
self
.
_initialize_kv_caches
()
self
.
_initialize_kv_caches
()
...
@@ -288,6 +290,13 @@ class LLMEngine:
...
@@ -288,6 +290,13 @@ class LLMEngine:
max_model_len
=
self
.
model_config
.
max_model_len
)
max_model_len
=
self
.
model_config
.
max_model_len
)
self
.
stat_logger
.
info
(
"cache_config"
,
self
.
cache_config
)
self
.
stat_logger
.
info
(
"cache_config"
,
self
.
cache_config
)
tokenizer_group
=
self
.
get_tokenizer_group
()
def
get_tokenizer_for_seq
(
self
,
sequence
:
Sequence
)
->
"PreTrainedTokenizer"
:
return
tokenizer_group
.
get_lora_tokenizer
(
sequence
.
lora_request
)
# Create sequence output processor, e.g. for beam search or
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
# speculative decoding.
self
.
output_processor
=
(
self
.
output_processor
=
(
...
@@ -296,12 +305,18 @@ class LLMEngine:
...
@@ -296,12 +305,18 @@ class LLMEngine:
self
.
detokenizer
,
self
.
detokenizer
,
self
.
scheduler
,
self
.
scheduler
,
self
.
seq_counter
,
self
.
seq_counter
,
self
.
get_tokenizer_for_seq
,
get_tokenizer_for_seq
,
stop_checker
=
StopChecker
(
stop_checker
=
StopChecker
(
self
.
scheduler_config
.
max_model_len
,
self
.
scheduler_config
.
max_model_len
,
self
.
get_tokenizer_for_seq
,
get_tokenizer_for_seq
,
),
),
))
))
init_success
=
True
finally
:
if
not
init_success
:
# Ensure that model_executor is shut down if LLMEngine init
# failed
self
.
model_executor
.
shutdown
()
def
_initialize_kv_caches
(
self
)
->
None
:
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
"""Initialize the KV cache in the worker(s).
...
@@ -393,10 +408,10 @@ class LLMEngine:
...
@@ -393,10 +408,10 @@ class LLMEngine:
def
get_tokenizer
(
self
)
->
"PreTrainedTokenizer"
:
def
get_tokenizer
(
self
)
->
"PreTrainedTokenizer"
:
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
None
)
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
None
)
def
get_tokenizer_for_seq
(
self
,
#
def get_tokenizer_for_seq(self,
sequence
:
Sequence
)
->
"PreTrainedTokenizer"
:
#
sequence: Sequence) -> "PreTrainedTokenizer":
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
#
return self.get_tokenizer_group().get_lora_tokenizer(
sequence
.
lora_request
)
#
sequence.lora_request)
def
_init_tokenizer
(
self
,
**
tokenizer_init_kwargs
)
->
BaseTokenizerGroup
:
def
_init_tokenizer
(
self
,
**
tokenizer_init_kwargs
)
->
BaseTokenizerGroup
:
init_kwargs
=
dict
(
init_kwargs
=
dict
(
...
@@ -785,7 +800,8 @@ class LLMEngine:
...
@@ -785,7 +800,8 @@ class LLMEngine:
# Log stats.
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
do_log_stats
(
scheduler_outputs
,
output
)
if
not
request_outputs
:
# if not request_outputs:
if
not
self
.
has_unfinished_requests
():
# Stop the execute model loop in parallel workers until there are
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# torch.distributed ops which may otherwise timeout, and unblocks
...
...
vllm/envs.py
View file @
7462218e
...
@@ -9,6 +9,9 @@ if TYPE_CHECKING:
...
@@ -9,6 +9,9 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_FLASH_ATTN_AUTO
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
LOCAL_RANK
:
int
=
0
LOCAL_RANK
:
int
=
0
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
VLLM_ENGINE_ITERATION_TIMEOUT_S
:
int
=
60
VLLM_ENGINE_ITERATION_TIMEOUT_S
:
int
=
60
...
@@ -27,7 +30,7 @@ if TYPE_CHECKING:
...
@@ -27,7 +30,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_
XLA_CACHE_PATH
:
str
=
"~/.vllm/xla_cache/"
VLLM_
FUSED_MOE_CHUNK_SIZE
:
int
=
64
*
1024
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
VLLM_WORKER_MULTIPROC_METHOD
:
str
=
"spawn"
VLLM_WORKER_MULTIPROC_METHOD
:
str
=
"spawn"
VLLM_IMAGE_FETCH_TIMEOUT
:
int
=
5
VLLM_IMAGE_FETCH_TIMEOUT
:
int
=
5
...
@@ -131,7 +134,22 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -131,7 +134,22 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention
# flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN"
:
"VLLM_USE_TRITON_FLASH_ATTN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_OP"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control if vllm print pa parameters
"VLLM_USE_PA_PRINT_PARAM"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_PA_PRINT_PARAM"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# local rank of the process in the distributed setting, used to determine
# local rank of the process in the distributed setting, used to determine
...
@@ -145,7 +163,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -145,7 +163,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# timeout for each iteration in the engine
# timeout for each iteration in the engine
"VLLM_ENGINE_ITERATION_TIMEOUT_S"
:
"VLLM_ENGINE_ITERATION_TIMEOUT_S"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_ENGINE_ITERATION_TIMEOUT_S"
,
"
6
0"
)),
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_ENGINE_ITERATION_TIMEOUT_S"
,
"
12
0"
)),
# API key for VLLM API server
# API key for VLLM API server
"VLLM_API_KEY"
:
"VLLM_API_KEY"
:
...
@@ -214,15 +232,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -214,15 +232,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_WORKER_MULTIPROC_METHOD"
:
"VLLM_WORKER_MULTIPROC_METHOD"
:
lambda
:
os
.
getenv
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"spawn"
),
lambda
:
os
.
getenv
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"spawn"
),
"VLLM_FUSED_MOE_CHUNK_SIZE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"65536"
)),
# Timeout for fetching images when serving multimodal models
# Timeout for fetching images when serving multimodal models
# Default is 5 seconds
# Default is 5 seconds
"VLLM_IMAGE_FETCH_TIMEOUT"
:
"VLLM_IMAGE_FETCH_TIMEOUT"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_IMAGE_FETCH_TIMEOUT"
,
"5"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_IMAGE_FETCH_TIMEOUT"
,
"5"
)),
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH"
:
lambda
:
os
.
getenv
(
"VLLM_XLA_CACHE_PATH"
,
"~/.vllm/xla_cache/"
),
}
}
# end-env-vars-definition
# end-env-vars-definition
...
...
vllm/executor/multiproc_worker_utils.py
View file @
7462218e
...
@@ -76,7 +76,8 @@ class ResultHandler(threading.Thread):
...
@@ -76,7 +76,8 @@ class ResultHandler(threading.Thread):
"""Handle results from all workers (in background thread)"""
"""Handle results from all workers (in background thread)"""
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
(
daemon
=
True
)
super
().
__init__
(
daemon
=
False
)
# super().__init__(daemon=True)
self
.
result_queue
=
mp
.
Queue
()
self
.
result_queue
=
mp
.
Queue
()
self
.
tasks
:
Dict
[
uuid
.
UUID
,
Union
[
ResultFuture
,
asyncio
.
Future
]]
=
{}
self
.
tasks
:
Dict
[
uuid
.
UUID
,
Union
[
ResultFuture
,
asyncio
.
Future
]]
=
{}
...
@@ -100,7 +101,8 @@ class WorkerMonitor(threading.Thread):
...
@@ -100,7 +101,8 @@ class WorkerMonitor(threading.Thread):
def
__init__
(
self
,
workers
:
List
[
'ProcessWorkerWrapper'
],
def
__init__
(
self
,
workers
:
List
[
'ProcessWorkerWrapper'
],
result_handler
:
ResultHandler
):
result_handler
:
ResultHandler
):
super
().
__init__
(
daemon
=
True
)
super
().
__init__
(
daemon
=
False
)
# super().__init__(daemon=True)
self
.
workers
=
workers
self
.
workers
=
workers
self
.
result_handler
=
result_handler
self
.
result_handler
=
result_handler
self
.
_close
=
False
self
.
_close
=
False
...
@@ -112,15 +114,31 @@ class WorkerMonitor(threading.Thread):
...
@@ -112,15 +114,31 @@ class WorkerMonitor(threading.Thread):
self
.
_close
=
True
self
.
_close
=
True
# Kill / cleanup all workers
# Kill / cleanup all workers
# for worker in self.workers:
# process = worker.process
# if process.sentinel in dead_sentinels:
# process.join(JOIN_TIMEOUT_S)
# if process.exitcode is not None and process.exitcode != 0:
# logger.error("Worker %s pid %s died, exit code: %s",
# process.name, process.pid, process.exitcode)
if
not
sys
.
is_finalizing
():
# Kill / cleanup all workers
died_count
=
0
for
worker
in
self
.
workers
:
for
worker
in
self
.
workers
:
process
=
worker
.
process
process
=
worker
.
process
if
process
.
sentinel
in
dead_sentinels
:
if
process
.
sentinel
in
dead_sentinels
:
process
.
join
(
JOIN_TIMEOUT_S
)
process
.
join
(
JOIN_TIMEOUT_S
)
if
process
.
exitcode
is
not
None
and
process
.
exitcode
!=
0
:
if
process
.
exitcode
is
not
None
and
process
.
exitcode
!=
0
:
died_count
+=
1
logger
.
error
(
"Worker %s pid %s died, exit code: %s"
,
logger
.
error
(
"Worker %s pid %s died, exit code: %s"
,
process
.
name
,
process
.
pid
,
process
.
exitcode
)
process
.
name
,
process
.
pid
,
process
.
exitcode
)
if
died_count
<
len
(
self
.
workers
):
logger
.
info
(
"Killing remaining local vLLM worker processes"
)
# Cleanup any remaining workers
# Cleanup any remaining workers
logger
.
info
(
"Killing local vLLM worker processes"
)
#
logger.info("Killing local vLLM worker processes")
for
worker
in
self
.
workers
:
for
worker
in
self
.
workers
:
worker
.
kill_worker
()
worker
.
kill_worker
()
# Must be done after worker task queues are all closed
# Must be done after worker task queues are all closed
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment