Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7462218e
Commit
7462218e
authored
Sep 05, 2024
by
zhuwenwen
Browse files
Merge branch 'v0.5.0-dtk24.04.1'
parents
6ccd3f47
1cec5e62
Changes
60
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2595 additions
and
327 deletions
+2595
-327
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+65
-0
examples/offline_inference.py
examples/offline_inference.py
+20
-19
examples/offline_streaming_inference_chat_demo.py
examples/offline_streaming_inference_chat_demo.py
+110
-0
examples/template_llama_chat.jinja
examples/template_llama_chat.jinja
+24
-0
requirements-rocm.txt
requirements-rocm.txt
+1
-2
setup.py
setup.py
+12
-22
tests/basic_correctness/test_preemption.py
tests/basic_correctness/test_preemption.py
+8
-4
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+2
-0
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+119
-115
vllm/_custom_ops.py
vllm/_custom_ops.py
+137
-8
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+68
-16
vllm/attention/ops/flash_attn_triton_mqa_gqa.py
vllm/attention/ops/flash_attn_triton_mqa_gqa.py
+1308
-0
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+103
-43
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+2
-2
vllm/benchmark_throughput.py
vllm/benchmark_throughput.py
+464
-0
vllm/config.py
vllm/config.py
+7
-1
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+1
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+91
-75
vllm/envs.py
vllm/envs.py
+25
-9
vllm/executor/multiproc_worker_utils.py
vllm/executor/multiproc_worker_utils.py
+28
-10
No files found.
csrc/torch_bindings.cpp
View file @
7462218e
...
@@ -47,6 +47,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -47,6 +47,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops
.
def
(
"paged_attention_v1_opt("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v1_opt"
,
torch
::
kCUDA
,
&
paged_attention_v1_opt
);
// PagedAttention V2 (opt).
ops
.
def
(
"paged_attention_v2_opt("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2_opt"
,
torch
::
kCUDA
,
&
paged_attention_v2_opt
);
// Activation ops
// Activation ops
// Activation function used in SwiGLU.
// Activation function used in SwiGLU.
ops
.
def
(
"silu_and_mul(Tensor! out, Tensor input) -> ()"
);
ops
.
def
(
"silu_and_mul(Tensor! out, Tensor input) -> ()"
);
...
@@ -68,6 +96,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -68,6 +96,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gelu_fast(Tensor! out, Tensor input) -> ()"
);
ops
.
def
(
"gelu_fast(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_fast"
,
torch
::
kCUDA
,
&
gelu_fast
);
ops
.
impl
(
"gelu_fast"
,
torch
::
kCUDA
,
&
gelu_fast
);
// Activation function used in SwiGLU. (opt)
ops
.
def
(
"silu_and_mul_opt(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"silu_and_mul_opt"
,
torch
::
kCUDA
,
&
silu_and_mul_opt
);
// Activation function used in GeGLU with `none` approximation. (opt)
ops
.
def
(
"gelu_and_mul_opt(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_and_mul_opt"
,
torch
::
kCUDA
,
&
gelu_and_mul_opt
);
// Activation function used in GeGLU with `tanh` approximation. (opt)
ops
.
def
(
"gelu_tanh_and_mul_opt(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_tanh_and_mul_opt"
,
torch
::
kCUDA
,
&
gelu_tanh_and_mul_opt
);
// Layernorm
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops
.
def
(
ops
.
def
(
...
@@ -81,6 +121,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -81,6 +121,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()"
);
"float epsilon) -> ()"
);
ops
.
impl
(
"fused_add_rms_norm"
,
torch
::
kCUDA
,
&
fused_add_rms_norm
);
ops
.
impl
(
"fused_add_rms_norm"
,
torch
::
kCUDA
,
&
fused_add_rms_norm
);
// Apply Root Mean Square (RMS) Normalization to the input tensor. (opt)
ops
.
def
(
"rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()"
);
ops
.
impl
(
"rms_norm_opt"
,
torch
::
kCUDA
,
&
rms_norm_opt
);
// In-place fused Add and RMS Normalization. (opt)
ops
.
def
(
"fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()"
);
ops
.
impl
(
"fused_add_rms_norm_opt"
,
torch
::
kCUDA
,
&
fused_add_rms_norm_opt
);
// Rotary embedding
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops
.
def
(
ops
.
def
(
...
@@ -89,6 +141,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -89,6 +141,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache, bool is_neox) -> ()"
);
" Tensor cos_sin_cache, bool is_neox) -> ()"
);
ops
.
impl
(
"rotary_embedding"
,
torch
::
kCUDA
,
&
rotary_embedding
);
ops
.
impl
(
"rotary_embedding"
,
torch
::
kCUDA
,
&
rotary_embedding
);
// Rotary embedding TGI for TGI
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops
.
def
(
"rotary_embedding_tgi(Tensor! query, Tensor! key,"
" int head_size, Tensor cos_cache,"
" Tensor sin_cache, bool is_neox) -> ()"
);
// ops.def("rotary_embedding_tgi",&rotary_embedding_tgi);
ops
.
impl
(
"rotary_embedding_tgi"
,
torch
::
kCUDA
,
&
rotary_embedding_tgi
);
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras).
// (supports multiple loras).
ops
.
def
(
ops
.
def
(
...
@@ -99,6 +160,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -99,6 +160,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache_offsets) -> ()"
);
" Tensor cos_sin_cache_offsets) -> ()"
);
ops
.
impl
(
"batched_rotary_embedding"
,
torch
::
kCUDA
,
&
batched_rotary_embedding
);
ops
.
impl
(
"batched_rotary_embedding"
,
torch
::
kCUDA
,
&
batched_rotary_embedding
);
// trans w16
ops
.
def
(
"trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()"
);
ops
.
impl
(
"trans_w16_gemm"
,
torch
::
kCUDA
,
&
trans_w16_gemm
);
// Quantization ops
// Quantization ops
#ifndef USE_ROCM
#ifndef USE_ROCM
// Quantized GEMM for AQLM.
// Quantized GEMM for AQLM.
...
...
examples/offline_inference.py
View file @
7462218e
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
# Sample prompts.
if
__name__
==
'__main__'
:
prompts
=
[
# Sample prompts.
"Hello, my name is"
,
prompts
=
[
"The president of the United States is"
,
"Hello, my name is"
,
"The capital of France is"
,
"The president of the United States is"
,
"The future of AI is"
,
"The capital of France is"
,
]
"The future of AI is"
,
# Create a sampling params object.
]
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
# Create a sampling params object.
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
max_tokens
=
16
)
# Create an LLM.
# Create an LLM.
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
trust_remote_code
=
True
,
dtype
=
"float16"
,
enforce_eager
=
True
)
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
t
ensor_parallel_size
=
1
,
distributed_executor_backend
=
"ray"
,
dtype
=
"float16"
,
t
rust_remote_code
=
True
,
enforce_eager
=
True
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
# that contain the prompt, generated text, and other information.
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
# Print the outputs.
# Print the outputs.
for
output
in
outputs
:
for
output
in
outputs
:
prompt
=
output
.
prompt
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
examples/offline_streaming_inference_chat_demo.py
0 → 100644
View file @
7462218e
from
vllm.sampling_params
import
SamplingParams
from
vllm.engine.async_llm_engine
import
AsyncEngineArgs
,
AsyncLLMEngine
import
asyncio
from
vllm.utils
import
FlexibleArgumentParser
from
transformers
import
AutoTokenizer
import
logging
import
argparse
import
sys
vllm_logger
=
logging
.
getLogger
(
"vllm"
)
vllm_logger
.
setLevel
(
logging
.
WARNING
)
class
FlexibleArgumentParser
(
argparse
.
ArgumentParser
):
"""ArgumentParser that allows both underscore and dash in names."""
def
parse_args
(
self
,
args
=
None
,
namespace
=
None
):
if
args
is
None
:
args
=
sys
.
argv
[
1
:]
# Convert underscores to dashes and vice versa in argument names
processed_args
=
[]
for
arg
in
args
:
if
arg
.
startswith
(
'--'
):
if
'='
in
arg
:
key
,
value
=
arg
.
split
(
'='
,
1
)
key
=
'--'
+
key
[
len
(
'--'
):].
replace
(
'_'
,
'-'
)
processed_args
.
append
(
f
'
{
key
}
=
{
value
}
'
)
else
:
processed_args
.
append
(
'--'
+
arg
[
len
(
'--'
):].
replace
(
'_'
,
'-'
))
else
:
processed_args
.
append
(
arg
)
return
super
().
parse_args
(
processed_args
,
namespace
)
parser
=
FlexibleArgumentParser
()
parser
.
add_argument
(
'--template'
,
type
=
str
,
help
=
"Path to template"
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
args
=
parser
.
parse_args
()
# chat = [
# {"role": "user", "content": "Hello, how are you?"},
# {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
# {"role": "user", "content": "I'd like to show off how chat templating works!"},
# ]
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
model
)
try
:
f
=
open
(
args
.
template
,
'r'
)
tokenizer
.
chat_template
=
f
.
read
()
except
Exception
as
e
:
print
(
'except:'
,
e
)
finally
:
f
.
close
()
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
)
model_name
=
args
.
model
.
split
(
"/"
)[
-
1
]
if
args
.
model
.
split
(
"/"
)[
-
1
]
!=
""
else
args
.
model
.
split
(
"/"
)[
-
2
]
print
(
f
"欢迎使用
{
model_name
}
模型,输入内容即可进行对话,stop 终止程序"
)
def
build_prompt
(
history
):
prompt
=
""
for
query
,
response
in
history
:
prompt
+=
f
"
\n\n
用户:
{
query
}
"
prompt
+=
f
"
\n\n
{
model_name
}
:
{
response
}
"
return
prompt
history
=
[]
while
True
:
query
=
input
(
"
\n
用户:"
)
if
query
.
strip
()
==
"stop"
:
break
history
.
append
({
"role"
:
"user"
,
"content"
:
query
})
new_query
=
tokenizer
.
apply_chat_template
(
history
,
tokenize
=
False
)
example_input
=
{
"prompt"
:
new_query
,
"stream"
:
False
,
"temperature"
:
0.0
,
"request_id"
:
0
,
}
results_generator
=
engine
.
generate
(
example_input
[
"prompt"
],
SamplingParams
(
temperature
=
example_input
[
"temperature"
],
max_tokens
=
100
),
example_input
[
"request_id"
]
)
start
=
0
end
=
0
response
=
""
async
def
process_results
():
async
for
output
in
results_generator
:
global
end
global
start
global
response
print
(
output
.
outputs
[
0
].
text
[
start
:],
end
=
""
,
flush
=
True
)
length
=
len
(
output
.
outputs
[
0
].
text
)
start
=
length
response
=
output
.
outputs
[
0
].
text
asyncio
.
run
(
process_results
())
history
.
append
({
"role"
:
"assistant"
,
"content"
:
response
})
print
()
examples/template_llama_chat.jinja
0 → 100644
View file @
7462218e
{% if messages[0]['role'] == 'system' %}
{% set system_message = '<<SYS>>\n' + messages[0]['content'] | trim + '\n<</SYS>>\n\n' %}
{% set messages = messages[1:] %}
{% else %}
{% set system_message = '' %}
{% endif %}
{% for message in messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}
{% if loop.index0 == 0 %}
{% set content = system_message + message['content'] %}
{% else %}
{% set content = message['content'] %}
{% endif %}
{% if message['role'] == 'user' %}
{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ ' ' + content | trim + ' ' + eos_token }}
{% endif %}
{% endfor %}
\ No newline at end of file
requirements-rocm.txt
View file @
7462218e
...
@@ -2,6 +2,5 @@
...
@@ -2,6 +2,5 @@
-r requirements-common.txt
-r requirements-common.txt
# Dependencies for AMD GPUs
# Dependencies for AMD GPUs
ray == 2.9.1
ray >= 2.10.0
# ray >= 2.10.0
pytest-asyncio
pytest-asyncio
setup.py
View file @
7462218e
...
@@ -18,6 +18,9 @@ from typing import Optional, Union
...
@@ -18,6 +18,9 @@ from typing import Optional, Union
import
subprocess
import
subprocess
from
pathlib
import
Path
from
pathlib
import
Path
add_git_version
=
False
if
int
(
os
.
environ
.
get
(
'ADD_GIT_VERSION'
,
'0'
))
==
1
:
add_git_version
=
True
def
load_module_from_path
(
module_name
,
path
):
def
load_module_from_path
(
module_name
,
path
):
spec
=
importlib
.
util
.
spec_from_file_location
(
module_name
,
path
)
spec
=
importlib
.
util
.
spec_from_file_location
(
module_name
,
path
)
...
@@ -317,33 +320,23 @@ def find_version(filepath: str) -> str:
...
@@ -317,33 +320,23 @@ def find_version(filepath: str) -> str:
raise
RuntimeError
(
"Unable to find version string."
)
raise
RuntimeError
(
"Unable to find version string."
)
def
get_abi
():
try
:
command
=
"echo '#include <string>' | gcc -x c++ -E -dM - | fgrep _GLIBCXX_USE_CXX11_ABI"
result
=
subprocess
.
run
(
command
,
shell
=
True
,
capture_output
=
True
,
text
=
True
)
output
=
result
.
stdout
.
strip
()
abi
=
"abi"
+
output
.
split
(
" "
)[
-
1
]
return
abi
except
Exception
:
return
'abiUnknown'
def
get_sha
(
root
:
Union
[
str
,
Path
])
->
str
:
def
get_sha
(
root
:
Union
[
str
,
Path
])
->
str
:
try
:
try
:
return
subprocess
.
check_output
([
'git'
,
'rev-parse'
,
'HEAD'
],
cwd
=
root
).
decode
(
'ascii'
).
strip
()
return
subprocess
.
check_output
([
'git'
,
'rev-parse'
,
'HEAD'
],
cwd
=
root
).
decode
(
'ascii'
).
strip
()
except
Exception
:
except
Exception
:
return
'Unknown'
return
'Unknown'
def
get_version_add
(
sha
:
Optional
[
str
]
=
None
)
->
str
:
def
get_version_add
(
sha
:
Optional
[
str
]
=
None
)
->
str
:
vllm_root
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
vllm_root
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
add_version_path
=
os
.
path
.
join
(
os
.
path
.
join
(
vllm_root
,
"vllm"
),
"version.py"
)
add_version_path
=
os
.
path
.
join
(
os
.
path
.
join
(
vllm_root
,
"vllm"
),
"version.py"
)
if
sha
!=
'Unknown'
:
if
add_git_version
:
if
sha
is
None
:
if
sha
!=
'Unknown'
:
sha
=
get_sha
(
vllm_root
)
if
sha
is
None
:
version
=
'das1.1.git'
+
sha
[:
7
]
sha
=
get_sha
(
vllm_root
)
version
=
'das.opt1'
+
sha
[:
7
]
# abi version
else
:
version
+
=
"."
+
get_abi
()
version
=
'das.opt1'
# dtk version
# dtk version
if
os
.
getenv
(
"ROCM_PATH"
):
if
os
.
getenv
(
"ROCM_PATH"
):
...
@@ -351,12 +344,9 @@ def get_version_add(sha: Optional[str] = None) -> str:
...
@@ -351,12 +344,9 @@ def get_version_add(sha: Optional[str] = None) -> str:
rocm_version_path
=
os
.
path
.
join
(
rocm_path
,
'.info'
,
"rocm_version"
)
rocm_version_path
=
os
.
path
.
join
(
rocm_path
,
'.info'
,
"rocm_version"
)
with
open
(
rocm_version_path
,
'r'
,
encoding
=
'utf-8'
)
as
file
:
with
open
(
rocm_version_path
,
'r'
,
encoding
=
'utf-8'
)
as
file
:
lines
=
file
.
readlines
()
lines
=
file
.
readlines
()
rocm_version
=
lines
[
0
]
[:
-
2
]
.
replace
(
"."
,
""
)
rocm_version
=
lines
[
0
].
replace
(
"."
,
""
)
version
+=
".dtk"
+
rocm_version
version
+=
".dtk"
+
rocm_version
# torch version
version
+=
".torch"
+
torch
.
__version__
[:
5
]
with
open
(
add_version_path
,
encoding
=
"utf-8"
,
mode
=
"w"
)
as
file
:
with
open
(
add_version_path
,
encoding
=
"utf-8"
,
mode
=
"w"
)
as
file
:
file
.
write
(
"__version__='0.5.0.post1'
\n
"
)
file
.
write
(
"__version__='0.5.0.post1'
\n
"
)
file
.
write
(
"__dcu_version__='0.5.0.post1+{}'
\n
"
.
format
(
version
))
file
.
write
(
"__dcu_version__='0.5.0.post1+{}'
\n
"
.
format
(
version
))
...
...
tests/basic_correctness/test_preemption.py
View file @
7462218e
...
@@ -67,7 +67,8 @@ def test_chunked_prefill_recompute(
...
@@ -67,7 +67,8 @@ def test_chunked_prefill_recompute(
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
# @pytest.mark.parametrize("dtype", ["float"])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
def
test_preemption
(
def
test_preemption
(
caplog_vllm
,
caplog_vllm
,
...
@@ -118,7 +119,8 @@ def test_preemption(
...
@@ -118,7 +119,8 @@ def test_preemption(
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
# @pytest.mark.parametrize("dtype", ["float"])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"beam_width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"beam_width"
,
[
4
])
def
test_swap
(
def
test_swap
(
...
@@ -176,7 +178,8 @@ def test_swap(
...
@@ -176,7 +178,8 @@ def test_swap(
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
# @pytest.mark.parametrize("dtype", ["float"])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"beam_width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"beam_width"
,
[
4
])
def
test_swap_infeasible
(
def
test_swap_infeasible
(
...
@@ -220,7 +223,8 @@ def test_swap_infeasible(
...
@@ -220,7 +223,8 @@ def test_swap_infeasible(
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
# @pytest.mark.parametrize("dtype", ["float"])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
def
test_preemption_infeasible
(
def
test_preemption_infeasible
(
vllm_runner
,
vllm_runner
,
...
...
tests/kernels/test_attention.py
View file @
7462218e
...
@@ -359,6 +359,8 @@ def test_multi_query_kv_attention(
...
@@ -359,6 +359,8 @@ def test_multi_query_kv_attention(
attn_bias
=
attn_bias
,
attn_bias
=
attn_bias
,
p
=
0.0
,
p
=
0.0
,
scale
=
scale
,
scale
=
scale
,
op
=
xops
.
fmha
.
MemoryEfficientAttentionFlashAttentionOp
[
0
]
if
(
is_hip
())
else
None
,
)
)
output
=
output
.
squeeze
(
0
)
output
=
output
.
squeeze
(
0
)
...
...
tests/kernels/test_prefix_prefill.py
View file @
7462218e
...
@@ -9,6 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
...
@@ -9,6 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from
vllm.attention.backends.xformers
import
_make_alibi_bias
from
vllm.attention.backends.xformers
import
_make_alibi_bias
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.utils
import
is_hip
NUM_HEADS
=
[
64
]
NUM_HEADS
=
[
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
NUM_QUERIES_PER_KV
=
[
1
,
8
,
64
]
...
@@ -158,57 +159,58 @@ def test_contexted_kv_attention(
...
@@ -158,57 +159,58 @@ def test_contexted_kv_attention(
end_time
=
time
.
time
()
end_time
=
time
.
time
()
print
(
f
"triton Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
print
(
f
"triton Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
if
not
is_hip
():
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
attn_op
=
xops
.
fmha
.
cutlass
.
FwOp
()
attn_op
=
xops
.
fmha
.
cutlass
.
FwOp
()
if
num_kv_heads
!=
num_heads
:
if
num_kv_heads
!=
num_heads
:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# project the key and value tensors to the desired number of
# heads.
# heads.
#
#
# see also: vllm/model_executor/layers/attention.py
# see also: vllm/model_executor/layers/attention.py
query
=
query
.
view
(
query
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
query
=
query
.
view
(
query
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
query
.
shape
[
-
1
])
query
.
shape
[
-
1
])
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
num_kv_heads
,
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
key
.
shape
[
-
1
])
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
num_kv_heads
,
None
,
:].
expand
(
value
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
value
.
shape
[
-
1
])
num_queries_per_kv
,
value
.
shape
[
-
1
])
query
=
query
.
unsqueeze
(
0
)
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
attn_bias
=
BlockDiagonalCausalFromBottomRightMask
.
from_seqlens
(
attn_bias
=
BlockDiagonalCausalFromBottomRightMask
.
from_seqlens
(
query_lens
,
seq_lens
)
query_lens
,
seq_lens
)
if
sliding_window
>
0
:
if
sliding_window
>
0
:
attn_bias
=
attn_bias
.
make_local_attention_from_bottomright
(
attn_bias
=
attn_bias
.
make_local_attention_from_bottomright
(
sliding_window
)
sliding_window
)
output_ref
=
xops
.
memory_efficient_attention_forward
(
output_ref
=
xops
.
memory_efficient_attention_forward
(
query
,
query
,
key
,
key
,
value
,
value
,
attn_bias
=
attn_bias
,
attn_bias
=
attn_bias
,
p
=
0.0
,
p
=
0.0
,
scale
=
scale
,
scale
=
scale
,
op
=
attn_op
,
op
=
attn_op
,
)
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
start_time
=
time
.
time
()
start_time
=
time
.
time
()
output_ref
=
xops
.
memory_efficient_attention_forward
(
output_ref
=
xops
.
memory_efficient_attention_forward
(
query
,
query
,
key
,
key
,
value
,
value
,
attn_bias
=
attn_bias
,
attn_bias
=
attn_bias
,
p
=
0.0
,
p
=
0.0
,
scale
=
scale
,
scale
=
scale
,
op
=
attn_op
,
op
=
attn_op
,
)
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
end_time
=
time
.
time
()
print
(
f
"xformers Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
print
(
f
"xformers Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
output_ref
=
output_ref
.
reshape
(
output
.
shape
)
output_ref
=
output_ref
.
reshape
(
output
.
shape
)
assert
torch
.
allclose
(
output_ref
,
output
,
atol
=
1e-6
,
rtol
=
0
)
assert
torch
.
allclose
(
output_ref
,
output
,
atol
=
1e-6
,
rtol
=
0
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
...
@@ -373,78 +375,80 @@ def test_contexted_kv_attention_alibi(
...
@@ -373,78 +375,80 @@ def test_contexted_kv_attention_alibi(
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
end_time
=
time
.
time
()
print
(
f
"triton Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
print
(
f
"triton Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
if
not
is_hip
():
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
# we have to pad query tensor before MQA/GQA expanding.
if
query
.
shape
[
0
]
!=
key
.
shape
[
0
]:
query_pad
=
torch
.
empty
(
sum
(
seq_lens
),
num_heads
,
head_size
,
dtype
=
dtype
)
query_pad
.
uniform_
(
-
1e-3
,
1e-3
)
seq_start
=
0
query_start
=
0
for
i
,
(
query_len
,
seq_len
)
in
enumerate
(
zip
(
query_lens
,
seq_lens
)):
seq_end
=
seq_start
+
seq_len
query_end
=
query_start
+
query_len
query_pad
[
seq_start
:
seq_end
,
...]
=
torch
.
cat
([
torch
.
zeros
(
seq_len
-
query_len
,
num_heads
,
head_size
,
dtype
=
dtype
),
query
[
query_start
:
query_end
,
...]
],
dim
=
0
)
seq_start
+=
seq_len
query_start
+=
query_len
query
=
query_pad
if
num_kv_heads
!=
num_heads
:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query
=
query
.
view
(
query
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
query
.
shape
[
-
1
])
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
value
.
shape
[
-
1
])
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
attn_bias
=
_make_alibi_bias
(
alibi_slopes
,
num_kv_heads
,
dtype
,
seq_lens
)
# we have to pad query tensor before MQA/GQA expanding.
output_ref
=
torch
.
empty_like
(
output
)
if
query
.
shape
[
0
]
!=
key
.
shape
[
0
]:
query_pad
=
torch
.
empty
(
sum
(
seq_lens
),
num_heads
,
head_size
,
dtype
=
dtype
)
query_pad
.
uniform_
(
-
1e-3
,
1e-3
)
seq_start
=
0
seq_start
=
0
query_start
=
0
query_start
=
0
start_time
=
time
.
time
()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343
for
i
,
(
query_len
,
seq_len
)
in
enumerate
(
zip
(
query_lens
,
seq_lens
)):
for
i
,
(
query_len
,
seq_len
)
in
enumerate
(
zip
(
query_lens
,
seq_lens
)):
seq_end
=
seq_start
+
seq_len
seq_end
=
seq_start
+
seq_len
query_end
=
query_start
+
query_len
query_end
=
query_start
+
query_len
query_pad
[
seq_start
:
seq_end
,
...]
=
torch
.
cat
([
out
=
xops
.
memory_efficient_attention_forward
(
query
[:,
torch
.
zeros
(
seq_start
:
seq_end
],
seq_len
-
query_len
,
num_heads
,
head_size
,
dtype
=
dtype
),
key
[:,
query
[
query_start
:
query_end
,
...]
seq_start
:
seq_end
],
],
value
[:,
dim
=
0
)
seq_start
:
seq_end
],
attn_bias
=
attn_bias
[
i
],
p
=
0.0
,
scale
=
scale
)
out
=
out
.
view_as
(
query
[:,
seq_start
:
seq_end
]).
view
(
seq_len
,
num_heads
,
head_size
)
output_ref
[
query_start
:
query_end
,
...].
copy_
(
out
[
seq_len
-
query_len
:,
...])
seq_start
+=
seq_len
seq_start
+=
seq_len
query_start
+=
query_len
query_start
+=
query_len
query
=
query_pad
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
if
num_kv_heads
!=
num_heads
:
print
(
f
"xformers Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
assert
torch
.
allclose
(
output_ref
,
output
,
atol
=
1e-6
,
rtol
=
0
)
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query
=
query
.
view
(
query
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
query
.
shape
[
-
1
])
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
num_kv_heads
,
num_queries_per_kv
,
value
.
shape
[
-
1
])
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
attn_bias
=
_make_alibi_bias
(
alibi_slopes
,
num_kv_heads
,
dtype
,
seq_lens
)
output_ref
=
torch
.
empty_like
(
output
)
seq_start
=
0
query_start
=
0
start_time
=
time
.
time
()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343
for
i
,
(
query_len
,
seq_len
)
in
enumerate
(
zip
(
query_lens
,
seq_lens
)):
seq_end
=
seq_start
+
seq_len
query_end
=
query_start
+
query_len
out
=
xops
.
memory_efficient_attention_forward
(
query
[:,
seq_start
:
seq_end
],
key
[:,
seq_start
:
seq_end
],
value
[:,
seq_start
:
seq_end
],
attn_bias
=
attn_bias
[
i
],
p
=
0.0
,
scale
=
scale
)
out
=
out
.
view_as
(
query
[:,
seq_start
:
seq_end
]).
view
(
seq_len
,
num_heads
,
head_size
)
output_ref
[
query_start
:
query_end
,
...].
copy_
(
out
[
seq_len
-
query_len
:,
...])
seq_start
+=
seq_len
query_start
+=
query_len
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
print
(
f
"xformers Time:
{
(
end_time
-
start_time
)
*
1000
:.
2
f
}
ms"
)
assert
torch
.
allclose
(
output_ref
,
output
,
atol
=
1e-6
,
rtol
=
0
)
vllm/_custom_ops.py
View file @
7462218e
import
contextlib
import
contextlib
import
functools
import
functools
from
typing
import
List
,
Optional
,
Tuple
,
Type
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
try
:
from
lmslim
import
quant_ops
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer gptq or awq model.
\n
"
)
try
:
try
:
import
vllm._C
import
vllm._C
except
ImportError
as
e
:
except
ImportError
as
e
:
...
@@ -56,6 +60,18 @@ def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
...
@@ -56,6 +60,18 @@ def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
def
gelu_tanh_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_tanh_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
gelu_tanh_and_mul
(
out
,
x
)
torch
.
ops
.
_C
.
gelu_tanh_and_mul
(
out
,
x
)
def
silu_and_mul_opt
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
silu_and_mul_opt
(
out
,
x
)
def
gelu_and_mul_opt
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
gelu_and_mul_opt
(
out
,
x
)
def
gelu_tanh_and_mul_opt
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
gelu_tanh_and_mul_opt
(
out
,
x
)
def
gelu_fast
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
def
gelu_fast
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
...
@@ -125,6 +141,65 @@ def paged_attention_v2(
...
@@ -125,6 +141,65 @@ def paged_attention_v2(
blocksparse_block_size
,
blocksparse_head_sliding_step
)
blocksparse_block_size
,
blocksparse_head_sliding_step
)
# page attention ops (opt)
def
paged_attention_v1_opt
(
out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v1_opt
(
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
def
paged_attention_v2_opt
(
out
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
max_logits
:
torch
.
Tensor
,
tmp_out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
None
:
torch
.
ops
.
_C
.
paged_attention_v2_opt
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
# pos encoding ops
# pos encoding ops
def
rotary_embedding
(
def
rotary_embedding
(
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -157,10 +232,31 @@ def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
...
@@ -157,10 +232,31 @@ def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
def
fused_add_rms_norm
(
input
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
def
fused_add_rms_norm
(
input
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
torch
.
ops
.
_C
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
torch
.
ops
.
_C
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
# layer norm ops (opt)
def
rms_norm_opt
(
out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
torch
.
ops
.
_C
.
rms_norm_opt
(
out
,
input
,
weight
,
epsilon
)
def
fused_add_rms_norm_opt
(
input
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
torch
.
ops
.
_C
.
fused_add_rms_norm_opt
(
input
,
residual
,
weight
,
epsilon
)
# trans_w16
def
trans_w16_gemm
(
dst
:
torch
.
Tensor
,
src
:
torch
.
Tensor
,
row
:
int
,
col
:
int
)
->
None
:
torch
.
ops
.
_C
.
trans_w16_gemm
(
dst
,
src
,
row
,
col
)
# quantization ops
# quantization ops
# awq
# awq
def
GetAWQShareWorkspaceSize
()
->
int
:
return
quant_ops
.
GetAWQShareWorkspaceSize
()
def
GetAWQShareWorkspace
()
->
torch
.
Tensor
:
return
quant_ops
.
GetAWQShareWorkspace
()
def
awq_dequantize
(
qweight
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
def
awq_dequantize
(
qweight
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
zeros
:
torch
.
Tensor
,
split_k_iters
:
int
,
thx
:
int
,
zeros
:
torch
.
Tensor
,
split_k_iters
:
int
,
thx
:
int
,
thy
:
int
)
->
torch
.
Tensor
:
thy
:
int
)
->
torch
.
Tensor
:
...
@@ -168,24 +264,57 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
...
@@ -168,24 +264,57 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
thx
,
thy
)
thx
,
thy
)
def
awq_gemm
(
input
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
qzeros
:
torch
.
Tensor
,
# def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales
:
torch
.
Tensor
,
split_k_iters
:
int
)
->
torch
.
Tensor
:
# scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
return
torch
.
ops
.
_C
.
awq_gemm
(
input
,
qweight
,
qzeros
,
scales
,
split_k_iters
)
# return quant_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
def
awq_gemm
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
zeros_and_scales
:
torch
.
Tensor
,
m
:
int
,
n
:
int
,
k
:
int
,
group_size
:
int
,
padding_group
:
int
,
splikspace
:
torch
.
Tensor
,
splikspacesize
:
int
)
->
torch
.
Tensor
:
return
quant_ops
.
awq_gemm
(
input
,
weight
,
zeros_and_scales
,
m
,
n
,
k
,
group_size
,
padding_group
,
splikspace
,
splikspacesize
)
def
convert_s4
(
qw
:
torch
.
Tensor
,
qz
:
torch
.
Tensor
,
s
:
torch
.
Tensor
,
group_size
:
int
):
return
quant_ops
.
convert_s4
(
qw
,
qz
,
s
,
group_size
)
def
sz_permute
(
sz
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
quant_ops
.
sz_permute
(
sz
)
def
dequant_w4_gemm_colmajor
(
qweight
:
torch
.
Tensor
,
zeros_and_scale
:
torch
.
Tensor
,
k
:
int
,
n
:
int
,
group_size
:
int
)
->
torch
.
Tensor
:
return
quant_ops
.
dequant_w4_gemm_colmajor
(
qweight
,
zeros_and_scale
,
k
,
n
,
group_size
)
# gptq
# gptq
def
gptq_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
def
gptq_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_gptq_qzeros
:
torch
.
Tensor
,
b_gptq_scales
:
torch
.
Tensor
,
b_gptq_qzeros
:
torch
.
Tensor
,
b_gptq_scales
:
torch
.
Tensor
,
b_g_idx
:
torch
.
Tensor
,
use_exllama
:
bool
,
b_g_idx
:
torch
.
Tensor
,
use_exllama
:
bool
,
bit
:
int
)
->
torch
.
Tensor
:
bit
:
int
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
gptq_gemm
(
a
,
b_q_weight
,
b_gptq_qzeros
,
b_gptq_scales
,
return
quant_ops
.
gptq_gemm
(
a
,
b_q_weight
,
b_gptq_qzeros
,
b_gptq_scales
,
b_g_idx
,
use_exllama
,
bit
)
b_g_idx
,
use_exllama
,
bit
)
# return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
# b_g_idx, use_exllama, bit)
def
gptq_shuffle
(
q_weight
:
torch
.
Tensor
,
q_perm
:
torch
.
Tensor
,
def
gptq_shuffle
(
q_weight
:
torch
.
Tensor
,
q_perm
:
torch
.
Tensor
,
bit
:
int
)
->
None
:
bit
:
int
)
->
None
:
torch
.
ops
.
_C
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
quant_ops
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# squeezellm
# squeezellm
def
squeezellm_gemm
(
vec
:
torch
.
Tensor
,
mat
:
torch
.
Tensor
,
mul
:
torch
.
Tensor
,
def
squeezellm_gemm
(
vec
:
torch
.
Tensor
,
mat
:
torch
.
Tensor
,
mul
:
torch
.
Tensor
,
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
7462218e
...
@@ -228,11 +228,25 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -228,11 +228,25 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
use_naive_attn
=
False
self
.
use_naive_attn
=
False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen >= 8000
self
.
use_flash_attn_auto
=
envs
.
VLLM_USE_FLASH_ATTN_AUTO
if
self
.
use_triton_flash_attn
:
if
self
.
use_triton_flash_attn
:
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
if
self
.
use_flash_attn_auto
:
triton_attention
)
from
vllm.attention.ops.flash_attn_triton_mqa_gqa
import
(
self
.
attn_func
=
triton_attention
flash_attn_varlen_func
)
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
self
.
attn_func_triton
=
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
self
.
attn_func_ck
=
flash_attn_varlen_func
logger
.
debug
(
"When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA"
)
else
:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
from
vllm.attention.ops.flash_attn_triton_mqa_gqa
import
(
flash_attn_varlen_func
)
self
.
attn_func
=
flash_attn_varlen_func
# triton_attention
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
else
:
else
:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either
# either
...
@@ -325,18 +339,56 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -325,18 +339,56 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# When block_tables are not filled, it means q and k are the
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
# prompt, and they have the same length.
if
self
.
use_triton_flash_attn
:
if
self
.
use_triton_flash_attn
:
out
,
_
=
self
.
attn_func
(
if
self
.
use_flash_attn_auto
:
query
,
if
prefill_meta
.
max_prefill_seq_len
>=
8000
:
key
,
out
=
self
.
attn_func_triton
(
value
,
q
=
query
,
None
,
k
=
key
,
prefill_meta
.
seq_start_loc
,
v
=
value
,
prefill_meta
.
seq_start_loc
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
prefill_meta
.
max_prefill_seq_len
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
True
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
else
:
out
=
self
.
attn_func_ck
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
else
:
# out, _ = self.attn_func(
# query,
# key,
# value,
# None,
# prefill_meta.seq_start_loc,
# prefill_meta.seq_start_loc,
# prefill_meta.max_prefill_seq_len,
# prefill_meta.max_prefill_seq_len,
# True,
# self.scale,
# )
out
=
self
.
attn_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
)
elif
self
.
use_naive_attn
:
elif
self
.
use_naive_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# Interleave for MQA workaround.
# Interleave for MQA workaround.
...
@@ -439,4 +491,4 @@ def _sdpa_attention(
...
@@ -439,4 +491,4 @@ def _sdpa_attention(
output
[
start
:
end
,
:,
:]
=
sub_out
output
[
start
:
end
,
:,
:]
=
sub_out
start
=
end
start
=
end
return
output
return
output
\ No newline at end of file
vllm/attention/ops/flash_attn_triton_mqa_gqa.py
0 → 100644
View file @
7462218e
#!/usr/bin/env python
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported:
1) Fwd with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.
Not currently supported:
1) Non power of two head dims
"""
import
argparse
import
pytest
import
random
import
sys
import
torch
import
triton
import
triton.language
as
tl
torch_dtype
:
tl
.
constexpr
=
torch
.
float16
TORCH_HAS_FP8E5
=
hasattr
(
torch
,
'float8_e5m2fnuz'
)
if
TORCH_HAS_FP8E5
:
torch_dtype
:
tl
.
constexpr
=
torch
.
float8_e5m2fnuz
class
MetaData
():
cu_seqlens_q
=
None
cu_seqlens_k
=
None
max_seqlens_q
=
0
max_seqlens_k
=
0
bias
=
None
alibi_slopes
=
None
causal
=
False
num_contexts
=
0
varlen
=
False
dropout_p
,
return_encoded_softmax
=
0.0
,
False
def
__init__
(
self
,
sm_scale
=
1.0
,
causal
=
False
,
dropout_p
=
0.0
,
return_encoded_softmax
=
False
):
self
.
sm_scale
=
sm_scale
self
.
causal
=
causal
self
.
dropout_p
=
dropout_p
self
.
return_encoded_softmax
=
return_encoded_softmax
def
set_varlen_params
(
self
,
cu_seqlens_q
,
cu_seqlens_k
):
self
.
varlen
=
True
self
.
cu_seqlens_q
=
cu_seqlens_q
self
.
cu_seqlens_k
=
cu_seqlens_k
# Without "varlen", there should still be one sequence.
assert
len
(
cu_seqlens_q
)
>=
2
assert
len
(
cu_seqlens_q
)
==
len
(
cu_seqlens_k
)
self
.
num_contexts
=
len
(
cu_seqlens_q
)
-
1
for
i
in
range
(
0
,
self
.
num_contexts
):
self
.
max_seqlens_q
=
max
(
cu_seqlens_q
[
i
+
1
].
item
()
-
cu_seqlens_q
[
i
].
item
(),
self
.
max_seqlens_q
)
self
.
max_seqlens_k
=
max
(
cu_seqlens_k
[
i
+
1
].
item
()
-
cu_seqlens_k
[
i
].
item
(),
self
.
max_seqlens_k
)
def
need_bias
(
self
,
bias
,
batch
,
nheads
,
seqlen_q
,
seqlen_k
):
assert
bias
.
is_cuda
assert
bias
.
dim
()
==
4
assert
bias
.
shape
[
0
]
==
1
assert
bias
.
shape
[
2
:]
==
(
seqlen_q
,
seqlen_k
)
self
.
bias
=
bias
def
need_alibi
(
self
,
alibi_slopes
,
batch
,
nheads
):
assert
alibi_slopes
.
is_cuda
assert
alibi_slopes
.
dim
()
==
2
assert
alibi_slopes
.
shape
[
0
]
==
batch
assert
alibi_slopes
.
shape
[
1
]
==
nheads
self
.
alibi_slopes
=
alibi_slopes
def
need_causal
(
self
):
self
.
causal
=
True
def
need_dropout
(
dropout_p
,
return_encoded_softmax
):
self
.
dropout_p
=
dropout_p
self
.
return_encoded_softmax
=
return_encoded_softmax
def
check_args
(
self
,
q
,
k
,
v
,
o
):
assert
q
.
dim
()
==
k
.
dim
()
and
q
.
dim
()
==
v
.
dim
()
if
self
.
varlen
:
assert
q
.
dim
()
==
3
total_q
,
nheads_q
,
head_size
=
q
.
shape
total_k
,
nheads_k
,
_
=
k
.
shape
assert
self
.
cu_seqlens_q
is
not
None
assert
self
.
cu_seqlens_k
is
not
None
assert
len
(
self
.
cu_seqlens_q
)
==
len
(
self
.
cu_seqlens_k
)
# TODO: Remove once bias is supported with varlen
assert
self
.
bias
==
None
# TODO:Remove once dropout is supported with varlen
assert
self
.
dropout_p
==
0.0
assert
not
self
.
return_encoded_softmax
else
:
assert
q
.
dim
()
==
4
batch
,
nheads_q
,
seqlen_q
,
head_size
=
q
.
shape
_
,
nheads_k
,
seqlen_k
,
_
=
k
.
shape
assert
self
.
max_seqlens_q
>
0
and
self
.
max_seqlens_k
>
0
assert
self
.
cu_seqlens_q
is
None
and
self
.
cu_seqlens_k
is
None
assert
k
.
shape
==
v
.
shape
assert
q
.
shape
[
-
1
]
==
k
.
shape
[
-
1
]
and
q
.
shape
[
-
1
]
==
v
.
shape
[
-
1
]
# TODO: Change assert if we support qkl f8 and v f16
assert
q
.
dtype
==
k
.
dtype
and
q
.
dtype
==
v
.
dtype
assert
head_size
<=
256
assert
o
.
shape
==
q
.
shape
assert
(
nheads_q
%
nheads_k
)
==
0
@
triton
.
jit
def
cdiv_fn
(
x
,
y
):
return
(
x
+
y
-
1
)
//
y
@
triton
.
jit
def
max_fn
(
x
,
y
):
return
tl
.
math
.
max
(
x
,
y
)
@
triton
.
jit
def
dropout_offsets
(
philox_seed
,
philox_offset
,
dropout_p
,
m
,
n
,
stride
):
ms
=
tl
.
arange
(
0
,
m
)
ns
=
tl
.
arange
(
0
,
n
)
return
philox_offset
+
ms
[:,
None
]
*
stride
+
ns
[
None
,
:]
@
triton
.
jit
def
dropout_rng
(
philox_seed
,
philox_offset
,
dropout_p
,
m
,
n
,
stride
):
rng_offsets
=
dropout_offsets
(
philox_seed
,
philox_offset
,
dropout_p
,
m
,
n
,
stride
).
to
(
tl
.
uint32
)
# TODO: use tl.randint for better performance
return
tl
.
rand
(
philox_seed
,
rng_offsets
)
@
triton
.
jit
def
dropout_mask
(
philox_seed
,
philox_offset
,
dropout_p
,
m
,
n
,
stride
):
rng_output
=
dropout_rng
(
philox_seed
,
philox_offset
,
dropout_p
,
m
,
n
,
stride
)
rng_keep
=
rng_output
>
dropout_p
return
rng_keep
@
triton
.
jit
def
load_fn
(
block_ptr
,
first
,
second
,
pad
):
if
first
and
second
:
tensor
=
tl
.
load
(
block_ptr
,
boundary_check
=
(
0
,
1
),
padding_option
=
pad
)
elif
first
:
tensor
=
tl
.
load
(
block_ptr
,
boundary_check
=
(
0
,),
padding_option
=
pad
)
elif
second
:
tensor
=
tl
.
load
(
block_ptr
,
boundary_check
=
(
1
,),
padding_option
=
pad
)
else
:
tensor
=
tl
.
load
(
block_ptr
)
return
tensor
@
triton
.
jit
def
print_gpu
(
prefix
,
val
=
None
):
if
(
tl
.
program_id
(
0
)
==
0
)
and
((
tl
.
program_id
(
1
)
==
0
)
and
(
tl
.
program_id
(
2
)
==
0
)):
if
val
is
not
None
:
tl
.
device_print
(
prefix
,
val
)
else
:
tl
.
device_print
(
prefix
)
@
triton
.
jit
def
compute_alibi_block
(
alibi_slope
,
seqlen_q
,
seqlen_k
,
offs_m
,
offs_n
,
transpose
=
False
):
# when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix
# for casual mask we want something like this where (1 is kept and 0 is masked)
# seqlen_q = 2 and seqlen_k = 5
# 1 1 1 1 0
# 1 1 1 1 1
# seqlen_q = 5 and seqlen_k = 2
# 0 0
# 0 0
# 0 0
# 1 0
# 1 1
# for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal
# e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False
# 1. offs_m[:,None] = [[0],
# [1],
# 2. offs_m[:,None] + seqlen_k = [[5],
# [6],
# 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3],
# [4],
# 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1],
# [4], [ 4, 3, 2, 1, 0]]
# 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1],
# [ -4, -3, -2, -1, 0]],
relative_pos_block
=
offs_m
[:,
None
]
+
seqlen_k
-
seqlen_q
-
offs_n
[
None
,:]
alibi_block
=
-
1
*
alibi_slope
*
tl
.
abs
(
relative_pos_block
)
if
transpose
:
return
alibi_block
.
T
else
:
return
alibi_block
@
triton
.
jit
def
_attn_fwd_inner
(
acc
,
l_i
,
m_i
,
q
,
K_block_ptr
,
V_block_ptr
,
start_m
,
actual_seqlen_k
,
actual_seqlen_q
,
dropout_p
,
philox_seed
,
batch_philox_offset
,
encoded_softmax_block_ptr
,
block_min
,
block_max
,
offs_n_causal
,
masked_blocks
,
n_extra_tokens
,
bias_ptr
,
alibi_slope
,
IS_CAUSAL
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
OFFS_M
:
tl
.
constexpr
,
OFFS_N
:
tl
.
constexpr
,
PRE_LOAD_V
:
tl
.
constexpr
,
MASK_STEPS
:
tl
.
constexpr
,
ENABLE_DROPOUT
:
tl
.
constexpr
,
RETURN_ENCODED_SOFTMAX
:
tl
.
constexpr
,
PADDED_HEAD
:
tl
.
constexpr
):
# loop over k, v, and update accumulator
for
start_n
in
range
(
block_min
,
block_max
,
BLOCK_N
):
# For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range.
k
=
load_fn
(
K_block_ptr
,
PADDED_HEAD
,
MASK_STEPS
and
(
n_extra_tokens
!=
0
),
"zero"
)
if
PRE_LOAD_V
:
v
=
load_fn
(
V_block_ptr
,
MASK_STEPS
and
(
n_extra_tokens
!=
0
),
PADDED_HEAD
,
"zero"
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
# We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block.
if
MASK_STEPS
:
# If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn.
# last step might get wasted but that is okay. check if this masking works For
# that case.
if
(
start_n
+
BLOCK_N
==
block_max
)
and
(
n_extra_tokens
!=
0
):
boundary_m
=
tl
.
full
([
BLOCK_M
],
actual_seqlen_k
,
dtype
=
tl
.
int32
)
size_n
=
start_n
+
OFFS_N
[
None
,:]
mask
=
size_n
<
boundary_m
[:,
None
]
qk
=
tl
.
where
(
mask
,
qk
,
float
(
"-inf"
))
if
IS_CAUSAL
:
causal_boundary
=
start_n
+
offs_n_causal
causal_mask
=
OFFS_M
[:,
None
]
>=
causal_boundary
[
None
,
:]
qk
=
tl
.
where
(
causal_mask
,
qk
,
float
(
"-inf"
))
# -- compute qk ----
qk
+=
tl
.
dot
(
q
,
k
)
if
bias_ptr
is
not
None
:
bias
=
load_fn
(
bias_ptr
,
False
,
MASK_STEPS
and
(
n_extra_tokens
!=
0
),
"zero"
)
# While bias is added after multiplying qk with sm_scale,
# our optimization to use 2^x instead of e^x results in an additional
# scale factor of log2(e) which we must also multiply the bias with.
qk
+=
(
bias
*
1.44269504089
)
if
alibi_slope
is
not
None
:
# Compute the global position of each token within the sequence
global_m_positions
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
global_n_positions
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
alibi_block
=
compute_alibi_block
(
alibi_slope
,
actual_seqlen_q
,
actual_seqlen_k
,
global_m_positions
,
global_n_positions
)
qk
+=
(
alibi_block
*
1.44269504089
)
# scale factor of log2(e)
# softmax
m_ij
=
tl
.
maximum
(
m_i
,
tl
.
max
(
qk
,
1
))
qk
=
qk
-
m_ij
[:,
None
]
p
=
tl
.
math
.
exp2
(
qk
)
# CAVEAT: Must update l_ij before applying dropout
l_ij
=
tl
.
sum
(
p
,
1
)
if
ENABLE_DROPOUT
:
philox_offset
=
batch_philox_offset
+
start_m
*
BLOCK_M
*
actual_seqlen_k
+
start_n
-
BLOCK_N
keep
=
dropout_mask
(
philox_seed
,
philox_offset
,
dropout_p
,
BLOCK_M
,
BLOCK_N
,
actual_seqlen_k
)
if
RETURN_ENCODED_SOFTMAX
:
tl
.
store
(
encoded_softmax_block_ptr
,
tl
.
where
(
keep
,
p
,
-
p
).
to
(
encoded_softmax_block_ptr
.
type
.
element_ty
))
p
=
tl
.
where
(
keep
,
p
,
0.0
)
elif
RETURN_ENCODED_SOFTMAX
:
tl
.
store
(
encoded_softmax_block_ptr
,
p
.
to
(
encoded_softmax_block_ptr
.
type
.
element_ty
))
# -- update output accumulator --
alpha
=
tl
.
math
.
exp2
(
m_i
-
m_ij
)
acc
=
acc
*
alpha
[:,
None
]
if
not
PRE_LOAD_V
:
v
=
load_fn
(
V_block_ptr
,
MASK_STEPS
and
(
n_extra_tokens
!=
0
),
PADDED_HEAD
,
"zero"
)
# -- update m_i and l_i
l_i
=
l_i
*
alpha
+
l_ij
# update m_i and l_i
m_i
=
m_ij
acc
+=
tl
.
dot
(
p
.
to
(
V_block_ptr
.
type
.
element_ty
),
v
)
V_block_ptr
=
tl
.
advance
(
V_block_ptr
,
(
BLOCK_N
,
0
))
K_block_ptr
=
tl
.
advance
(
K_block_ptr
,
(
0
,
BLOCK_N
))
if
bias_ptr
is
not
None
:
bias_ptr
=
tl
.
advance
(
bias_ptr
,
(
0
,
BLOCK_N
))
if
RETURN_ENCODED_SOFTMAX
:
encoded_softmax_block_ptr
=
tl
.
advance
(
encoded_softmax_block_ptr
,
(
0
,
BLOCK_N
))
return
acc
,
l_i
,
m_i
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
'BLOCK_M'
:
256
,
'BLOCK_N'
:
64
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
False
},
num_stages
=
1
,
num_warps
=
8
),
triton
.
Config
({
'BLOCK_M'
:
128
,
'BLOCK_N'
:
128
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
False
},
num_stages
=
1
,
num_warps
=
4
),
triton
.
Config
({
'BLOCK_M'
:
256
,
'BLOCK_N'
:
128
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
False
},
num_stages
=
1
,
num_warps
=
8
),
triton
.
Config
({
'BLOCK_M'
:
128
,
'BLOCK_N'
:
64
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
True
},
num_stages
=
1
,
num_warps
=
4
),
triton
.
Config
({
'BLOCK_M'
:
128
,
'BLOCK_N'
:
64
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
False
},
num_stages
=
1
,
num_warps
=
4
),
triton
.
Config
({
'BLOCK_M'
:
128
,
'BLOCK_N'
:
32
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
True
},
num_stages
=
2
,
num_warps
=
4
),
triton
.
Config
({
'BLOCK_M'
:
128
,
'BLOCK_N'
:
32
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
False
},
num_stages
=
2
,
num_warps
=
4
),
triton
.
Config
({
'BLOCK_M'
:
128
,
'BLOCK_N'
:
16
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
False
},
num_stages
=
1
,
num_warps
=
8
),
triton
.
Config
({
'BLOCK_M'
:
64
,
'BLOCK_N'
:
64
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
False
},
num_stages
=
1
,
num_warps
=
8
),
triton
.
Config
({
'BLOCK_M'
:
64
,
'BLOCK_N'
:
64
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
True
},
num_stages
=
1
,
num_warps
=
4
),
triton
.
Config
({
'BLOCK_M'
:
32
,
'BLOCK_N'
:
32
,
'waves_per_eu'
:
0
,
'PRE_LOAD_V'
:
False
},
num_stages
=
1
,
num_warps
=
8
),
# TODO: This config fails with head_size not pow2 with data mismatches. Check why.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
# triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
],
key
=
[
'IS_CAUSAL'
,
'dropout_p'
,
'BLOCK_DMODEL'
],
# use_cuda_graph=True,
)
@
triton
.
jit
def
attn_fwd
(
Q
,
K
,
V
,
bias
,
sm_scale
,
L
,
Out
,
stride_qz
,
stride_qh
,
stride_qm
,
stride_qk
,
stride_kz
,
stride_kh
,
stride_kn
,
stride_kk
,
stride_vz
,
stride_vh
,
stride_vk
,
stride_vn
,
stride_oz
,
stride_oh
,
stride_om
,
stride_on
,
stride_bz
,
stride_bh
,
stride_bm
,
stride_bn
,
stride_az
,
stride_ah
,
cu_seqlens_q
,
cu_seqlens_k
,
dropout_p
,
philox_seed
,
philox_offset_base
,
encoded_softmax
,
alibi_slopes
,
HQ
:
tl
.
constexpr
,
HK
:
tl
.
constexpr
,
ACTUAL_BLOCK_DMODEL
:
tl
.
constexpr
,
MAX_SEQLENS_Q
:
tl
.
constexpr
,
MAX_SEQLENS_K
:
tl
.
constexpr
,
VARLEN
:
tl
.
constexpr
,
IS_CAUSAL
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
PRE_LOAD_V
:
tl
.
constexpr
,
BIAS_TYPE
:
tl
.
constexpr
,
ENABLE_DROPOUT
:
tl
.
constexpr
,
RETURN_ENCODED_SOFTMAX
:
tl
.
constexpr
,
USE_ALIBI
:
tl
.
constexpr
,
BATCH_SIZE
:
tl
.
constexpr
,
):
start_m
=
tl
.
program_id
(
0
)
off_h_q
=
tl
.
program_id
(
1
)
off_z
=
tl
.
program_id
(
2
)
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
if
VARLEN
:
cu_seqlens_q_start
=
tl
.
load
(
cu_seqlens_q
+
off_z
)
cu_seqlens_q_end
=
tl
.
load
(
cu_seqlens_q
+
off_z
+
1
)
seqlen_q
=
cu_seqlens_q_end
-
cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
# small for all start_m so for those we return early.
if
start_m
*
BLOCK_M
>
seqlen_q
:
return
cu_seqlens_k_start
=
tl
.
load
(
cu_seqlens_k
+
off_z
)
cu_seqlens_k_end
=
tl
.
load
(
cu_seqlens_k
+
off_z
+
1
)
seqlen_k
=
cu_seqlens_k_end
-
cu_seqlens_k_start
else
:
cu_seqlens_q_start
=
0
cu_seqlens_k_start
=
0
seqlen_q
=
MAX_SEQLENS_Q
seqlen_k
=
MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking.
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
# are completely masked, resulting in 0s written to the output, and
# inf written to LSE. We don't need to do any GEMMs in this case.
# This block of code determines what N is, and if this WG is operating
# on those M rows.
n_blocks
=
cdiv_fn
(
seqlen_k
,
BLOCK_N
)
if
(
IS_CAUSAL
):
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which means
# the causal mask boundary is bottom right aligned, and ends at either
# the top edge (seqlen_q < seqlen_k) or left edge.
# This captures the decrease in n_blocks if we have a rectangular attn matrix
n_blocks_seqlen
=
cdiv_fn
(
(
start_m
+
1
)
*
BLOCK_M
+
seqlen_k
-
seqlen_q
,
BLOCK_N
)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
n_blocks
=
min
(
n_blocks
,
n_blocks_seqlen
)
# If we have no blocks after adjusting for seqlen deltas, this WG is part of
# the blocks that are all 0. We exit early.
if
n_blocks
<=
0
:
o_offset
=
off_z
*
stride_oz
+
cu_seqlens_q_start
*
stride_om
+
off_h_q
*
stride_oh
O_block_ptr
=
tl
.
make_block_ptr
(
base
=
Out
+
o_offset
,
shape
=
(
seqlen_q
,
BLOCK_DMODEL
),
strides
=
(
stride_om
,
stride_on
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
Out
.
type
.
element_ty
)
# We still need to write 0s to the result
tl
.
store
(
O_block_ptr
,
acc
.
to
(
Out
.
type
.
element_ty
),
boundary_check
=
(
0
,
1
))
l_ptrs
=
L
+
off_z
*
HQ
*
MAX_SEQLENS_Q
+
off_h_q
*
MAX_SEQLENS_Q
+
offs_m
# We store inf to LSE, not -inf because in the bwd pass, we subtract this
# from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks.
l
=
tl
.
full
([
BLOCK_M
],
value
=
float
(
"inf"
),
dtype
=
tl
.
float32
)
tl
.
store
(
l_ptrs
,
l
)
# TODO: Should dropout and return encoded softmax be handled here too?
return
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE
:
tl
.
constexpr
=
HQ
//
HK
if
GROUP_SIZE
!=
1
:
off_h_k
=
off_h_q
//
GROUP_SIZE
else
:
off_h_k
=
off_h_q
need_padding
=
False
n_extra_tokens
=
0
if
seqlen_k
<
BLOCK_N
:
need_padding
=
True
n_extra_tokens
=
BLOCK_N
-
seqlen_k
elif
seqlen_k
%
BLOCK_N
:
need_padding
=
True
n_extra_tokens
=
seqlen_k
%
BLOCK_N
PADDED_HEAD
:
tl
.
constexpr
=
(
ACTUAL_BLOCK_DMODEL
!=
BLOCK_DMODEL
)
# Compute pointers for all the tensors used in this kernel.
q_offset
=
off_z
*
stride_qz
+
off_h_q
*
stride_qh
+
cu_seqlens_q_start
*
stride_qm
Q_block_ptr
=
tl
.
make_block_ptr
(
base
=
Q
+
q_offset
,
shape
=
(
seqlen_q
,
ACTUAL_BLOCK_DMODEL
),
strides
=
(
stride_qm
,
stride_qk
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
k_offset
=
off_z
*
stride_kz
+
off_h_k
*
stride_kh
+
cu_seqlens_k_start
*
stride_kn
K_block_ptr
=
tl
.
make_block_ptr
(
base
=
K
+
k_offset
,
shape
=
(
ACTUAL_BLOCK_DMODEL
,
seqlen_k
),
strides
=
(
stride_kk
,
stride_kn
),
offsets
=
(
0
,
0
),
block_shape
=
(
BLOCK_DMODEL
,
BLOCK_N
),
order
=
(
0
,
1
)
)
v_offset
=
off_z
*
stride_vz
+
off_h_k
*
stride_vh
+
cu_seqlens_k_start
*
stride_vk
V_block_ptr
=
tl
.
make_block_ptr
(
base
=
V
+
v_offset
,
shape
=
(
seqlen_k
,
ACTUAL_BLOCK_DMODEL
),
strides
=
(
stride_vk
,
stride_vn
),
offsets
=
(
0
,
0
),
block_shape
=
(
BLOCK_N
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
if
BIAS_TYPE
!=
0
:
b_offset
=
off_h_q
*
stride_bh
# Note: this might get large enough to overflow on some configs
bias_ptr
=
tl
.
make_block_ptr
(
base
=
bias
+
b_offset
,
shape
=
(
seqlen_q
,
seqlen_k
),
strides
=
(
stride_bm
,
stride_bn
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_N
),
order
=
(
1
,
0
),
)
else
:
bias_ptr
=
None
if
USE_ALIBI
:
a_offset
=
off_z
*
stride_az
+
off_h_q
*
stride_ah
alibi_slope
=
tl
.
load
(
alibi_slopes
+
a_offset
)
else
:
alibi_slope
=
None
if
ENABLE_DROPOUT
:
batch_philox_offset
=
philox_offset_base
+
off_hz
*
seqlen_q
*
seqlen_k
else
:
batch_philox_offset
=
0
# We can ask to return the dropout mask without actually doing any dropout. In
# this case, we return an invalid pointer so indicate the mask is not valid.
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
if
RETURN_ENCODED_SOFTMAX
:
encoded_softmax_block_ptr
=
tl
.
make_block_ptr
(
base
=
encoded_softmax
+
off_h_q
*
seqlen_q
*
seqlen_k
,
shape
=
(
seqlen_q
,
seqlen_k
),
strides
=
(
seqlen_k
,
1
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_N
),
order
=
(
1
,
0
)
)
else
:
encoded_softmax_block_ptr
=
0
# initialize pointer to m and l
m_i
=
tl
.
full
([
BLOCK_M
],
float
(
"-inf"
),
dtype
=
tl
.
float32
)
l_i
=
tl
.
full
([
BLOCK_M
],
1.0
,
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
# have native e^x support in HW.
qk_scale
=
sm_scale
*
1.44269504089
# Q is loaded once at the beginning and shared by all N blocks.
q
=
load_fn
(
Q_block_ptr
,
True
,
PADDED_HEAD
,
"zero"
)
q
=
(
q
*
qk_scale
).
to
(
Q_block_ptr
.
type
.
element_ty
)
# Here we compute how many full and masked blocks we have.
padded_block_k
=
n_extra_tokens
!=
0
is_modulo_mn
=
not
padded_block_k
and
(
seqlen_q
%
BLOCK_M
==
0
)
if
IS_CAUSAL
:
# There are always at least BLOCK_M // BLOCK_N masked blocks.
# Additionally there might be one more due to dissimilar seqlens.
masked_blocks
=
BLOCK_M
//
BLOCK_N
+
(
not
is_modulo_mn
)
else
:
# Padding on Q does not need to be masked in the FA loop.
masked_blocks
=
padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional block.
# In this case we might exceed n_blocks so pick the min.
masked_blocks
=
min
(
masked_blocks
,
n_blocks
)
n_full_blocks
=
n_blocks
-
masked_blocks
block_min
=
0
block_max
=
n_blocks
*
BLOCK_N
# Compute for full blocks. Here we set causal to false regardless of its actual
# value because there is no masking. Similarly we do not need padding.
if
n_full_blocks
>
0
:
block_max
=
(
n_blocks
-
masked_blocks
)
*
BLOCK_N
acc
,
l_i
,
m_i
=
_attn_fwd_inner
(
acc
,
l_i
,
m_i
,
q
,
K_block_ptr
,
V_block_ptr
,
start_m
,
seqlen_k
,
seqlen_q
,
dropout_p
,
philox_seed
,
batch_philox_offset
,
encoded_softmax_block_ptr
,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min
,
block_max
,
0
,
0
,
0
,
bias_ptr
,
alibi_slope
,
# IS_CAUSAL, ....
False
,
BLOCK_M
,
BLOCK_DMODEL
,
BLOCK_N
,
offs_m
,
offs_n
,
# _, MASK_STEPS, ...
PRE_LOAD_V
,
False
,
ENABLE_DROPOUT
,
RETURN_ENCODED_SOFTMAX
,
PADDED_HEAD
)
block_min
=
block_max
block_max
=
n_blocks
*
BLOCK_N
tl
.
debug_barrier
()
# Remaining blocks, if any, are full / not masked.
if
(
masked_blocks
>
0
):
if
IS_CAUSAL
:
offs_n_causal
=
offs_n
+
(
seqlen_q
-
seqlen_k
)
else
:
offs_n_causal
=
0
K_block_ptr
=
tl
.
advance
(
K_block_ptr
,
(
0
,
n_full_blocks
*
BLOCK_N
))
V_block_ptr
=
tl
.
advance
(
V_block_ptr
,
(
n_full_blocks
*
BLOCK_N
,
0
))
if
bias_ptr
is
not
None
:
bias_ptr
=
tl
.
advance
(
bias_ptr
,
(
0
,
n_full_blocks
*
BLOCK_N
))
if
RETURN_ENCODED_SOFTMAX
:
encoded_softmax_block_ptr
=
tl
.
advance
(
encoded_softmax_block_ptr
,
(
0
,
n_full_blocks
))
acc
,
l_i
,
m_i
=
_attn_fwd_inner
(
acc
,
l_i
,
m_i
,
q
,
K_block_ptr
,
V_block_ptr
,
start_m
,
seqlen_k
,
seqlen_q
,
dropout_p
,
philox_seed
,
batch_philox_offset
,
encoded_softmax_block_ptr
,
block_min
,
block_max
,
offs_n_causal
,
masked_blocks
,
n_extra_tokens
,
bias_ptr
,
alibi_slope
,
IS_CAUSAL
,
BLOCK_M
,
BLOCK_DMODEL
,
BLOCK_N
,
offs_m
,
offs_n
,
# _, MASK_STEPS, ...
PRE_LOAD_V
,
True
,
ENABLE_DROPOUT
,
RETURN_ENCODED_SOFTMAX
,
PADDED_HEAD
)
# epilogue
acc
=
acc
/
l_i
[:,
None
]
if
ENABLE_DROPOUT
:
acc
=
acc
/
(
1
-
dropout_p
)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
# then we have one block with a row of all NaNs which come from computing
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
# and store 0s where there are NaNs as these rows should've been zeroed out.
end_m_idx
=
(
start_m
+
1
)
*
BLOCK_M
start_m_idx
=
start_m
*
BLOCK_M
causal_start_idx
=
seqlen_q
-
seqlen_k
acc
=
acc
.
to
(
Out
.
type
.
element_ty
)
if
IS_CAUSAL
:
if
causal_start_idx
>
start_m_idx
and
causal_start_idx
<
end_m_idx
:
out_mask_boundary
=
tl
.
full
((
BLOCK_DMODEL
,),
causal_start_idx
,
dtype
=
tl
.
int32
)
mask_m_offsets
=
start_m_idx
+
tl
.
arange
(
0
,
BLOCK_M
)
out_ptrs_mask
=
mask_m_offsets
[:,
None
]
>=
out_mask_boundary
[
None
,
:]
z
=
0.0
acc
=
tl
.
where
(
out_ptrs_mask
,
acc
,
z
.
to
(
acc
.
type
.
element_ty
))
# write back LSE
l_ptrs
=
L
+
off_z
*
HQ
*
MAX_SEQLENS_Q
+
off_h_q
*
MAX_SEQLENS_Q
+
offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows.
# This is only true for the last M block. For others, overflow_size will be -ve
overflow_size
=
end_m_idx
-
seqlen_q
if
overflow_size
>
0
:
boundary
=
tl
.
full
((
BLOCK_M
,),
BLOCK_M
-
overflow_size
,
dtype
=
tl
.
int32
)
# This is a > check because mask being 0 blocks the store.
l_ptrs_mask
=
boundary
>
tl
.
arange
(
0
,
BLOCK_M
)
tl
.
store
(
l_ptrs
,
m_i
+
tl
.
math
.
log2
(
l_i
),
mask
=
l_ptrs_mask
)
else
:
tl
.
store
(
l_ptrs
,
m_i
+
tl
.
math
.
log2
(
l_i
))
# write back O
o_offset
=
off_z
*
stride_oz
+
cu_seqlens_q_start
*
stride_om
+
off_h_q
*
stride_oh
O_block_ptr
=
tl
.
make_block_ptr
(
base
=
Out
+
o_offset
,
shape
=
(
seqlen_q
,
ACTUAL_BLOCK_DMODEL
),
strides
=
(
stride_om
,
stride_on
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
# Need boundary check on this to make sure the padding from the
# Q and KV tensors in both dims are not part of what we store back.
# TODO: Do the boundary check optionally.
tl
.
store
(
O_block_ptr
,
acc
,
boundary_check
=
(
0
,
1
))
@
triton
.
jit
def
_attn_bwd_preprocess
(
Out
,
DO
,
Delta
,
stride_oz
,
stride_oh
,
stride_om
,
stride_on
,
stride_doz
,
stride_doh
,
stride_dom
,
stride_don
,
seqlen_q
,
head_dim
,
BLOCK_M
:
tl
.
constexpr
,
D_HEAD
:
tl
.
constexpr
,
):
# off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
# off_n = tl.arange(0, D_HEAD)
off_m
=
tl
.
program_id
(
0
)
*
BLOCK_M
off_h
=
tl
.
program_id
(
1
)
# head index
off_z
=
tl
.
program_id
(
2
)
# batch index
num_h
=
tl
.
num_programs
(
1
)
o_offset
=
off_h
*
stride_oh
+
off_z
*
stride_oz
O_block_ptr
=
tl
.
make_block_ptr
(
base
=
Out
+
o_offset
,
shape
=
(
seqlen_q
,
head_dim
),
strides
=
(
stride_om
,
stride_on
),
offsets
=
(
off_m
,
0
),
block_shape
=
(
BLOCK_M
,
D_HEAD
),
order
=
(
1
,
0
)
)
do_offset
=
off_h
*
stride_doh
+
off_z
*
stride_doz
DO_block_ptr
=
tl
.
make_block_ptr
(
base
=
DO
+
do_offset
,
shape
=
(
seqlen_q
,
head_dim
),
strides
=
(
stride_dom
,
stride_don
),
offsets
=
(
off_m
,
0
),
block_shape
=
(
BLOCK_M
,
D_HEAD
),
order
=
(
1
,
0
)
)
# load
# o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
# do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
o
=
tl
.
load
(
O_block_ptr
,
boundary_check
=
(
0
,
1
),
padding_option
=
"zero"
).
to
(
tl
.
float32
)
do
=
tl
.
load
(
DO_block_ptr
,
boundary_check
=
(
0
,
1
),
padding_option
=
"zero"
).
to
(
tl
.
float32
)
# compute
delta
=
tl
.
sum
(
o
*
do
,
axis
=
1
)
# write-back, shape (q.shape[0] * q.shape[1], q.shape[2])
off_zh
=
off_z
*
num_h
+
off_h
*
1
# Check for OOB accesses
delta_ptrs
=
Delta
+
off_zh
*
seqlen_q
+
off_m
+
tl
.
arange
(
0
,
BLOCK_M
)
overflow
=
off_m
+
BLOCK_M
-
seqlen_q
if
overflow
>
0
:
boundary
=
tl
.
full
((
BLOCK_M
,
),
BLOCK_M
-
overflow
,
dtype
=
tl
.
int32
)
mask
=
boundary
>
tl
.
arange
(
0
,
BLOCK_M
)
tl
.
store
(
delta_ptrs
,
delta
,
mask
=
mask
)
else
:
tl
.
store
(
delta_ptrs
,
delta
)
@
triton
.
jit
def
_bwd_kernel_dk_dv
(
dk
,
dv
,
Q
,
k
,
v
,
sm_scale
,
alibi_slope
,
DO
,
M
,
D
,
# shared by Q/K/V/DO.
stride_tok
,
stride_d
,
H
,
N_CTX
,
BLOCK_M1
:
tl
.
constexpr
,
BLOCK_N1
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
# Filled in by the wrapper.
start_n
,
start_m
,
num_steps
,
MASK
:
tl
.
constexpr
):
offs_m
=
start_m
+
tl
.
arange
(
0
,
BLOCK_M1
)
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N1
)
offs_k
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
QT_block_ptr
=
tl
.
make_block_ptr
(
base
=
Q
,
shape
=
(
BLOCK_DMODEL
,
N_CTX
),
strides
=
(
stride_d
,
stride_tok
),
offsets
=
(
0
,
start_m
),
block_shape
=
(
BLOCK_DMODEL
,
BLOCK_M1
),
order
=
(
0
,
1
)
)
DO_block_ptr
=
tl
.
make_block_ptr
(
base
=
DO
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_tok
,
stride_d
),
offsets
=
(
start_m
,
0
),
block_shape
=
(
BLOCK_M1
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
tl
.
static_assert
(
BLOCK_N1
%
BLOCK_M1
==
0
)
curr_m
=
start_m
step_m
=
BLOCK_M1
for
blk_idx
in
range
(
num_steps
):
qT
=
tl
.
load
(
QT_block_ptr
)
# Load m before computing qk to reduce pipeline stall.
offs_m
=
curr_m
+
tl
.
arange
(
0
,
BLOCK_M1
)
m
=
tl
.
load
(
M
+
offs_m
)
kqT
=
tl
.
dot
(
k
,
qT
)
if
alibi_slope
is
not
None
:
alibi_block
=
compute_alibi_block
(
alibi_slope
,
N_CTX
,
N_CTX
,
offs_m
,
offs_n
,
True
)
kqT
+=
alibi_block
*
1.44269504089
pT
=
tl
.
math
.
exp2
(
kqT
-
m
[
None
,
:])
# Autoregressive masking.
if
MASK
:
mask
=
(
offs_m
[
None
,
:]
>=
offs_n
[:,
None
])
pT
=
tl
.
where
(
mask
,
pT
,
0.0
)
do
=
tl
.
load
(
DO_block_ptr
)
# Compute dV.
ppT
=
pT
ppT
=
ppT
.
to
(
tl
.
float16
)
dv
+=
tl
.
dot
(
ppT
,
do
)
# D (= delta) is pre-divided by ds_scale.
Di
=
tl
.
load
(
D
+
offs_m
)
# Compute dP and dS.
dpT
=
tl
.
dot
(
v
,
tl
.
trans
(
do
))
dsT
=
pT
*
(
dpT
-
Di
[
None
,
:])
dsT
=
dsT
.
to
(
tl
.
float16
)
dk
+=
tl
.
dot
(
dsT
,
tl
.
trans
(
qT
))
# Increment pointers.
curr_m
+=
step_m
QT_block_ptr
=
tl
.
advance
(
QT_block_ptr
,
(
0
,
step_m
))
DO_block_ptr
=
tl
.
advance
(
DO_block_ptr
,
(
step_m
,
0
))
return
dk
,
dv
@
triton
.
jit
def
_bwd_kernel_dq
(
dq
,
q
,
K
,
V
,
do
,
m
,
D
,
alibi_slope
,
# shared by Q/K/V/DO.
stride_tok
,
stride_d
,
H
,
N_CTX
,
BLOCK_M2
:
tl
.
constexpr
,
BLOCK_N2
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
# Filled in by the wrapper.
start_m
,
start_n
,
num_steps
,
MASK
:
tl
.
constexpr
):
offs_m
=
start_m
+
tl
.
arange
(
0
,
BLOCK_M2
)
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N2
)
offs_k
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
KT_block_ptr
=
tl
.
make_block_ptr
(
base
=
K
,
shape
=
(
BLOCK_DMODEL
,
N_CTX
),
strides
=
(
stride_d
,
stride_tok
),
offsets
=
(
0
,
start_n
),
block_shape
=
(
BLOCK_DMODEL
,
BLOCK_N2
),
order
=
(
0
,
1
)
)
VT_block_ptr
=
tl
.
make_block_ptr
(
base
=
V
,
shape
=
(
BLOCK_DMODEL
,
N_CTX
),
strides
=
(
stride_d
,
stride_tok
),
offsets
=
(
0
,
start_n
),
block_shape
=
(
BLOCK_DMODEL
,
BLOCK_N2
),
order
=
(
0
,
1
)
)
# D (= delta) is pre-divided by ds_scale.
Di
=
tl
.
load
(
D
+
offs_m
)
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
tl
.
static_assert
(
BLOCK_M2
%
BLOCK_N2
==
0
)
curr_n
=
start_n
step_n
=
BLOCK_N2
for
blk_idx
in
range
(
num_steps
):
kT
=
tl
.
load
(
KT_block_ptr
)
qk
=
tl
.
dot
(
q
,
kT
)
if
alibi_slope
is
not
None
:
alibi_block
=
compute_alibi_block
(
alibi_slope
,
N_CTX
,
N_CTX
,
offs_m
,
offs_n
)
qk
+=
alibi_block
*
1.44269504089
p
=
tl
.
math
.
exp2
(
qk
-
m
)
# Autoregressive masking.
if
MASK
:
offs_n
=
curr_n
+
tl
.
arange
(
0
,
BLOCK_N2
)
mask
=
(
offs_m
[:,
None
]
>=
offs_n
[
None
,
:])
p
=
tl
.
where
(
mask
,
p
,
0.0
)
# Compute dP and dS.
vT
=
tl
.
load
(
VT_block_ptr
)
dp
=
tl
.
dot
(
do
,
vT
).
to
(
tl
.
float32
)
ds
=
p
*
(
dp
-
Di
[:,
None
])
ds
=
ds
.
to
(
tl
.
float16
)
# Compute dQ.0.
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
dq
+=
tl
.
dot
(
ds
,
tl
.
trans
(
kT
))
# Increment pointers.
curr_n
+=
step_n
KT_block_ptr
=
tl
.
advance
(
KT_block_ptr
,
(
0
,
step_n
))
VT_block_ptr
=
tl
.
advance
(
VT_block_ptr
,
(
0
,
step_n
))
return
dq
@
triton
.
jit
def
_attn_bwd
(
Q
,
K
,
V
,
sm_scale
,
alibi_slopes
,
DO
,
DQ
,
DK
,
DV
,
M
,
D
,
# shared by Q/K/V/DO.
stride_z
,
stride_h
,
stride_tok
,
stride_d
,
# H = 16, N_CTX = 1024
H
,
N_CTX
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_M1
:
tl
.
constexpr
,
BLOCK_N1
:
tl
.
constexpr
,
BLOCK_M2
:
tl
.
constexpr
,
BLOCK_N2
:
tl
.
constexpr
,
BLK_SLICE_FACTOR
:
tl
.
constexpr
,
USE_ALIBI
:
tl
.
constexpr
):
LN2
:
tl
.
constexpr
=
0.6931471824645996
# = ln(2)
bhid
=
tl
.
program_id
(
2
)
off_chz
=
(
bhid
*
N_CTX
).
to
(
tl
.
int64
)
adj
=
(
stride_h
*
(
bhid
%
H
)
+
stride_z
*
(
bhid
//
H
)).
to
(
tl
.
int64
)
pid
=
tl
.
program_id
(
0
)
# offset pointers for batch/head
Q
+=
adj
K
+=
adj
V
+=
adj
DO
+=
adj
DQ
+=
adj
DK
+=
adj
DV
+=
adj
M
+=
off_chz
D
+=
off_chz
offs_k
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
start_n
=
pid
*
BLOCK_N1
# This assignment is important. It is what allows us to pick the diagonal
# blocks. Later, when we want to do the lower triangular, we update start_m
# after the first dkdv call.
start_m
=
start_n
MASK_BLOCK_M1
:
tl
.
constexpr
=
BLOCK_M1
//
BLK_SLICE_FACTOR
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N1
)
dv
=
tl
.
zeros
([
BLOCK_N1
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
dk
=
tl
.
zeros
([
BLOCK_N1
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
K_block_ptr
=
tl
.
make_block_ptr
(
base
=
K
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_tok
,
stride_d
),
offsets
=
(
start_n
,
0
),
block_shape
=
(
BLOCK_N1
,
BLOCK_DMODEL
),
order
=
(
1
,
0
),
)
V_block_ptr
=
tl
.
make_block_ptr
(
base
=
V
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_tok
,
stride_d
),
offsets
=
(
start_n
,
0
),
block_shape
=
(
BLOCK_N1
,
BLOCK_DMODEL
),
order
=
(
1
,
0
),
)
# load K and V: they stay in SRAM throughout the inner loop for dkdv.
k
=
tl
.
load
(
K_block_ptr
)
v
=
tl
.
load
(
V_block_ptr
)
if
USE_ALIBI
:
a_offset
=
bhid
alibi_slope
=
tl
.
load
(
alibi_slopes
+
a_offset
)
else
:
alibi_slope
=
None
# compute dK and dV for blocks close to the diagonal that need to be masked
num_steps
=
BLOCK_N1
//
MASK_BLOCK_M1
dk
,
dv
=
_bwd_kernel_dk_dv
(
dk
,
dv
,
Q
,
k
,
v
,
sm_scale
,
alibi_slope
,
DO
,
M
,
D
,
stride_tok
,
stride_d
,
H
,
N_CTX
,
MASK_BLOCK_M1
,
BLOCK_N1
,
BLOCK_DMODEL
,
start_n
,
start_m
,
num_steps
,
MASK
=
True
)
# compute dK and dV for blocks that don't need masking further from the diagonal
start_m
+=
num_steps
*
MASK_BLOCK_M1
num_steps
=
(
N_CTX
-
start_m
)
//
BLOCK_M1
dk
,
dv
=
_bwd_kernel_dk_dv
(
dk
,
dv
,
Q
,
k
,
v
,
sm_scale
,
alibi_slope
,
DO
,
M
,
D
,
stride_tok
,
stride_d
,
H
,
N_CTX
,
BLOCK_M1
,
BLOCK_N1
,
BLOCK_DMODEL
,
start_n
,
start_m
,
num_steps
,
MASK
=
False
)
DV_block_ptrs
=
tl
.
make_block_ptr
(
base
=
DV
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_tok
,
stride_d
),
offsets
=
(
start_n
,
0
),
block_shape
=
(
BLOCK_N1
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
tl
.
store
(
DV_block_ptrs
,
dv
.
to
(
v
.
dtype
))
# Write back dK.
dk
*=
sm_scale
DK_block_ptrs
=
tl
.
make_block_ptr
(
base
=
DK
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_tok
,
stride_d
),
offsets
=
(
start_n
,
0
),
block_shape
=
(
BLOCK_N1
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
tl
.
store
(
DK_block_ptrs
,
dk
.
to
(
k
.
dtype
))
# THIS BLOCK DOES DQ:
start_m
=
pid
*
BLOCK_M2
end_n
=
start_m
+
BLOCK_M2
MASK_BLOCK_N2
:
tl
.
constexpr
=
BLOCK_N2
//
BLK_SLICE_FACTOR
offs_m
=
start_m
+
tl
.
arange
(
0
,
BLOCK_M2
)
Q_block_ptr
=
tl
.
make_block_ptr
(
base
=
Q
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_tok
,
stride_d
),
offsets
=
(
start_m
,
0
),
block_shape
=
(
BLOCK_M2
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
DO_block_ptr
=
tl
.
make_block_ptr
(
base
=
DO
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_tok
,
stride_d
),
offsets
=
(
start_m
,
0
),
block_shape
=
(
BLOCK_M2
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
q
=
tl
.
load
(
Q_block_ptr
)
do
=
tl
.
load
(
DO_block_ptr
)
dq
=
tl
.
zeros
([
BLOCK_M2
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
m
=
tl
.
load
(
M
+
offs_m
)
m
=
m
[:,
None
]
# Compute dQ for masked (diagonal) blocks.
# NOTE: This code scans each row of QK^T backward (from right to left,
# but inside each call to _attn_bwd_dq, from left to right), but that's
# not due to anything important. I just wanted to reuse the loop
# structure for dK & dV above as much as possible.
num_steps
=
BLOCK_M2
//
MASK_BLOCK_N2
dq
=
_bwd_kernel_dq
(
dq
,
q
,
K
,
V
,
do
,
m
,
D
,
alibi_slope
,
stride_tok
,
stride_d
,
H
,
N_CTX
,
BLOCK_M2
,
MASK_BLOCK_N2
,
BLOCK_DMODEL
,
start_m
,
end_n
-
num_steps
*
MASK_BLOCK_N2
,
num_steps
,
MASK
=
True
)
end_n
-=
num_steps
*
MASK_BLOCK_N2
# stage 2
num_steps
=
end_n
//
BLOCK_N2
dq
=
_bwd_kernel_dq
(
dq
,
q
,
K
,
V
,
do
,
m
,
D
,
alibi_slope
,
stride_tok
,
stride_d
,
H
,
N_CTX
,
BLOCK_M2
,
BLOCK_N2
,
BLOCK_DMODEL
,
start_m
,
end_n
-
num_steps
*
BLOCK_N2
,
num_steps
,
MASK
=
False
)
# Write back dQ.
DQ_block_ptr
=
tl
.
make_block_ptr
(
base
=
DQ
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_tok
,
stride_d
),
offsets
=
(
start_m
,
0
),
block_shape
=
(
BLOCK_M2
,
BLOCK_DMODEL
),
order
=
(
1
,
0
)
)
dq
*=
LN2
tl
.
store
(
DQ_block_ptr
,
dq
.
to
(
q
.
dtype
))
class
_attention
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
q
,
k
,
v
,
o
,
metadata
):
# NOTE: a large bias tensor leads to overflow during pointer arithmetic
if
(
metadata
.
bias
is
not
None
):
assert
(
metadata
.
bias
.
numel
()
<
2
**
31
)
if
o
is
None
:
o
=
torch
.
empty_like
(
q
,
dtype
=
v
.
dtype
)
import
os
if
os
.
environ
.
get
(
"FLASH_ATTENTION_PRINT_PARAM"
,
"0"
)
==
"1"
:
print
(
f
"triton flash attention:
{
q
.
shape
=
}
,
{
k
.
shape
=
}
,
{
v
.
shape
}
,
{
o
.
shape
=
}
"
)
print
(
f
"triton flash attention:
{
q
.
stride
()
=
}
,
{
k
.
stride
()
=
}
,
{
v
.
stride
()
=
}
,
{
o
.
stride
()
=
}
"
)
print
(
f
"triton flash attention:
{
metadata
=
}
"
)
metadata
.
check_args
(
q
,
k
,
v
,
o
)
if
metadata
.
varlen
:
total_q
,
nheads_q
,
head_size
=
q
.
shape
total_k
,
nheads_k
,
_
=
k
.
shape
batch
=
metadata
.
num_contexts
q_strides
=
(
0
,
q
.
stride
(
1
),
q
.
stride
(
0
),
q
.
stride
(
2
))
k_strides
=
(
0
,
k
.
stride
(
1
),
k
.
stride
(
0
),
k
.
stride
(
2
))
v_strides
=
(
0
,
v
.
stride
(
1
),
v
.
stride
(
0
),
v
.
stride
(
2
))
o_strides
=
(
0
,
o
.
stride
(
1
),
o
.
stride
(
0
),
o
.
stride
(
2
))
else
:
batch
,
nheads_q
,
seqlen_q
,
head_size
=
q
.
shape
_
,
nheads_k
,
seqlen_k
,
_
=
k
.
shape
q_strides
=
(
q
.
stride
(
0
),
q
.
stride
(
1
),
q
.
stride
(
2
),
q
.
stride
(
3
))
k_strides
=
(
k
.
stride
(
0
),
k
.
stride
(
1
),
k
.
stride
(
2
),
k
.
stride
(
3
))
v_strides
=
(
v
.
stride
(
0
),
v
.
stride
(
1
),
v
.
stride
(
2
),
v
.
stride
(
3
))
o_strides
=
(
o
.
stride
(
0
),
o
.
stride
(
1
),
o
.
stride
(
2
),
o
.
stride
(
3
))
# Get closest power of 2 over or equal to 32.
padded_d_model
=
1
<<
(
head_size
-
1
).
bit_length
()
padded_d_model
=
max
(
padded_d_model
,
16
)
grid
=
lambda
META
:
(
triton
.
cdiv
(
metadata
.
max_seqlens_q
,
META
[
'BLOCK_M'
]),
nheads_q
,
batch
)
# encoded_softmax is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out
# to give a consistent starting point and then populate it with the output of softmax with the sign bit set according
# to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing
# only. This return holds no useful output aside from debugging.
if
metadata
.
return_encoded_softmax
:
encoded_softmax
=
torch
.
zeros
((
q
.
shape
[
0
],
q
.
shape
[
1
],
q
.
shape
[
2
],
k
.
shape
[
2
]),
device
=
q
.
device
,
dtype
=
torch
.
float32
)
else
:
encoded_softmax
=
None
M
=
torch
.
empty
((
batch
,
nheads_q
,
metadata
.
max_seqlens_q
),
device
=
q
.
device
,
dtype
=
torch
.
float32
)
# Seed the RNG so we get reproducible results for testing.
philox_seed
=
0x1BF52
philox_offset
=
0x1D4B42
if
metadata
.
bias
is
not
None
:
bias_strides
=
(
metadata
.
bias
.
stride
(
0
),
metadata
.
bias
.
stride
(
1
),
metadata
.
bias
.
stride
(
2
),
metadata
.
bias
.
stride
(
3
))
else
:
bias_strides
=
(
0
,
0
,
0
,
0
)
if
metadata
.
alibi_slopes
is
not
None
:
alibi_strides
=
(
metadata
.
alibi_slopes
.
stride
(
0
),
metadata
.
alibi_slopes
.
stride
(
1
))
else
:
alibi_strides
=
(
0
,
0
)
attn_fwd
[
grid
](
q
,
k
,
v
,
metadata
.
bias
,
metadata
.
sm_scale
,
M
,
o
,
*
q_strides
,
*
k_strides
,
*
v_strides
,
*
o_strides
,
*
bias_strides
,
*
alibi_strides
,
metadata
.
cu_seqlens_q
,
metadata
.
cu_seqlens_k
,
dropout_p
=
metadata
.
dropout_p
,
philox_seed
=
philox_seed
,
philox_offset_base
=
philox_offset
,
encoded_softmax
=
encoded_softmax
,
alibi_slopes
=
metadata
.
alibi_slopes
,
HQ
=
nheads_q
,
HK
=
nheads_k
,
ACTUAL_BLOCK_DMODEL
=
head_size
,
MAX_SEQLENS_Q
=
metadata
.
max_seqlens_q
,
MAX_SEQLENS_K
=
metadata
.
max_seqlens_k
,
IS_CAUSAL
=
metadata
.
causal
,
VARLEN
=
metadata
.
varlen
,
BLOCK_DMODEL
=
padded_d_model
,
BIAS_TYPE
=
0
if
metadata
.
bias
is
None
else
1
,
USE_ALIBI
=
False
if
metadata
.
alibi_slopes
is
None
else
True
,
ENABLE_DROPOUT
=
metadata
.
dropout_p
>
0.0
,
RETURN_ENCODED_SOFTMAX
=
metadata
.
return_encoded_softmax
,
BATCH_SIZE
=
q
.
shape
[
0
]
)
if
os
.
environ
.
get
(
"FLASH_ATTENTION_PRINT_PARAM"
,
"0"
)
==
"1"
:
best_config
=
attn_fwd
.
get_best_config
()
print
(
f
"
{
best_config
.
kwargs
=
}
,
{
best_config
.
num_stages
=
}
,
{
best_config
.
num_warps
=
}
"
)
ctx
.
save_for_backward
(
q
,
k
,
v
,
o
,
M
)
ctx
.
grid
=
grid
ctx
.
sm_scale
=
metadata
.
sm_scale
ctx
.
BLOCK_DMODEL
=
head_size
ctx
.
causal
=
metadata
.
causal
ctx
.
alibi_slopes
=
metadata
.
alibi_slopes
ctx
.
dropout_p
=
metadata
.
dropout_p
ctx
.
philox_seed
=
philox_seed
ctx
.
philox_offset
=
philox_offset
ctx
.
encoded_softmax
=
encoded_softmax
ctx
.
return_encoded_softmax
=
metadata
.
return_encoded_softmax
return
o
if
not
metadata
.
return_encoded_softmax
else
(
o
,
encoded_softmax
,
)
# S_dmask
@
staticmethod
def
backward
(
ctx
,
do
,
*
args
):
if
torch
.
version
.
hip
is
not
None
:
BLOCK
=
64
else
:
BLOCK
=
128
q
,
k
,
v
,
o
,
M
=
ctx
.
saved_tensors
import
os
if
os
.
environ
.
get
(
"TRITON_FLASHATTN_DEBUG"
,
"0"
)
==
"1"
:
print
(
f
"triton flash attention:
{
q
.
shape
=
}
,
{
k
.
shape
=
}
,
{
v
.
shape
}
,
{
o
.
shape
=
}
,
{
do
.
shape
=
}
"
)
print
(
f
"triton flash attention:
{
q
.
stride
()
=
}
,
{
k
.
stride
()
=
}
,
{
v
.
stride
()
=
}
,
{
o
.
stride
()
=
}
,
{
do
.
stride
()
}
"
)
# assert do.is_contiguous()
assert
q
.
stride
()
==
k
.
stride
()
==
v
.
stride
()
==
o
.
stride
()
==
do
.
stride
()
seqlen_q
=
q
.
shape
[
2
]
dq
=
torch
.
empty_like
(
q
)
dk
=
torch
.
empty_like
(
k
)
dv
=
torch
.
empty_like
(
v
)
BATCH
,
N_HEAD
,
N_CTX
=
q
.
shape
[:
3
]
PRE_BLOCK
=
128
NUM_WARPS
,
NUM_STAGES
=
4
,
1
BLOCK_M1
,
BLOCK_N1
,
BLOCK_M2
,
BLOCK_N2
=
32
,
64
,
64
,
32
BLK_SLICE_FACTOR
=
2
RCP_LN2
=
1.4426950408889634
# = 1.0 / ln(2)
arg_k
=
k
arg_k
=
arg_k
*
(
ctx
.
sm_scale
*
RCP_LN2
)
assert
N_CTX
%
PRE_BLOCK
==
0
delta
=
torch
.
empty_like
(
M
)
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
padded_head
=
(
Lk
!=
ctx
.
BLOCK_DMODEL
)
grid_preprocess
=
(
triton
.
cdiv
(
do
.
shape
[
2
],
BLOCK
),
do
.
shape
[
1
],
do
.
shape
[
0
])
_attn_bwd_preprocess
[
grid_preprocess
](
o
,
do
,
delta
,
o
.
stride
(
0
),
o
.
stride
(
1
),
o
.
stride
(
2
),
o
.
stride
(
3
),
do
.
stride
(
0
),
do
.
stride
(
1
),
do
.
stride
(
2
),
do
.
stride
(
3
),
seqlen_q
,
head_dim
=
Lk
,
BLOCK_M
=
BLOCK
,
D_HEAD
=
ctx
.
BLOCK_DMODEL
,
)
grid
=
lambda
META
:
(
triton
.
cdiv
(
N_CTX
,
META
[
'BLOCK_N1'
]),
1
,
BATCH
*
N_HEAD
)
_attn_bwd
[
grid
](
q
,
arg_k
,
v
,
ctx
.
sm_scale
,
ctx
.
alibi_slopes
,
do
,
dq
,
dk
,
dv
,
M
,
delta
,
q
.
stride
(
0
),
q
.
stride
(
1
),
q
.
stride
(
2
),
q
.
stride
(
3
),
N_HEAD
,
N_CTX
,
BLOCK_DMODEL
=
ctx
.
BLOCK_DMODEL
,
BLOCK_M1
=
BLOCK_M1
,
BLOCK_N1
=
BLOCK_N1
,
BLOCK_M2
=
BLOCK_M2
,
BLOCK_N2
=
BLOCK_N2
,
BLK_SLICE_FACTOR
=
BLK_SLICE_FACTOR
,
USE_ALIBI
=
False
if
ctx
.
alibi_slopes
is
None
else
True
,
)
return
dq
,
dk
,
dv
,
None
,
None
attention
=
_attention
.
apply
# flash_attn wrapper
def
input_helper
(
Z
,
HQ
,
HK
,
N_CTX_Q
,
N_CTX_K
,
D_HEAD
,
dtype
):
torch
.
manual_seed
(
20
)
# Initialize q, k, v
q
=
torch
.
randn
((
Z
,
HQ
,
N_CTX_Q
,
D_HEAD
),
dtype
=
dtype
,
device
=
"cuda"
,
requires_grad
=
True
)
k
=
torch
.
randn
((
Z
,
HK
,
N_CTX_K
,
D_HEAD
),
dtype
=
dtype
,
device
=
"cuda"
,
requires_grad
=
True
)
v
=
torch
.
randn
((
Z
,
HK
,
N_CTX_K
,
D_HEAD
),
dtype
=
dtype
,
device
=
"cuda"
,
requires_grad
=
True
)
sm_scale
=
D_HEAD
**-
0.5
input_metadata
=
MetaData
(
sm_scale
=
sm_scale
)
input_metadata
.
max_seqlens_q
=
N_CTX_Q
input_metadata
.
max_seqlens_k
=
N_CTX_K
return
q
,
k
,
v
,
input_metadata
def
padding_bshd
(
t
):
# BSHD
batch
,
seqlen
,
nheads
,
dim
=
t
.
shape
t
=
torch
.
nn
.
functional
.
pad
(
t
.
reshape
(
batch
,
seqlen
,
nheads
*
dim
),
(
0
,
32
),
'constant'
,
0
)[:,:,:
-
32
].
reshape
(
batch
,
seqlen
,
nheads
,
dim
)
# pad: nheads*dim+32
# t = torch.nn.functional.pad(t.reshape(batch, seqlen, nheads, dim), (0, 32), 'constant', 0)[:,:,:,:-32] # pad: dim+32
return
t
def
flash_attn_func
(
q
,
k
,
v
,
dropout_p
=
0.0
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
if
padding_input
:
k
,
v
=
(
padding_bshd
(
t
)
for
t
in
(
k
,
v
))
q
,
k
,
v
=
(
t
.
transpose
(
1
,
2
)
for
t
in
(
q
,
k
,
v
))
softmax_scale
=
softmax_scale
if
softmax_scale
else
q
.
shape
[
-
1
]
**-
0.5
input_metadata
=
MetaData
(
sm_scale
=
softmax_scale
,
causal
=
causal
,
dropout_p
=
dropout_p
,
return_encoded_softmax
=
return_attn_probs
)
input_metadata
.
max_seqlens_q
=
q
.
shape
[
2
]
input_metadata
.
max_seqlens_k
=
k
.
shape
[
2
]
return
_attention
.
apply
(
q
,
k
,
v
,
None
,
input_metadata
).
transpose
(
1
,
2
)
def
flash_attn_kvpacked_func
(
q
,
kv
,
dropout_p
=
0.0
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
k
,
v
=
kv
[:,
:,
0
],
kv
[:,
:,
1
]
# batch_size, seqlen, 2, nheads_k, headdim
if
padding_input
:
k
,
v
=
(
padding_bshd
(
t
)
for
t
in
(
k
,
v
))
# pad
q
,
k
,
v
=
(
t
.
transpose
(
1
,
2
)
for
t
in
(
q
,
k
,
v
))
# trans bshd to bhsd
softmax_scale
=
softmax_scale
if
softmax_scale
else
q
.
shape
[
-
1
]
**-
0.5
input_metadata
=
MetaData
(
sm_scale
=
softmax_scale
,
causal
=
causal
,
dropout_p
=
dropout_p
,
return_encoded_softmax
=
return_attn_probs
)
input_metadata
.
max_seqlens_q
=
q
.
shape
[
2
]
input_metadata
.
max_seqlens_k
=
k
.
shape
[
2
]
return
_attention
.
apply
(
q
,
k
,
v
,
None
,
input_metadata
).
transpose
(
1
,
2
)
# trans bhsd to bshd
def
flash_attn_qkvpacked_func
(
qkv
,
dropout_p
=
0.0
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
q
,
k
,
v
=
qkv
[:,
:,
0
],
qkv
[:,
:,
1
],
qkv
[:,
:,
2
]
if
padding_input
:
k
,
v
=
(
padding_bshd
(
t
)
for
t
in
(
k
,
v
))
q
,
k
,
v
=
(
t
.
transpose
(
1
,
2
)
for
t
in
(
q
,
k
,
v
))
softmax_scale
=
softmax_scale
if
softmax_scale
else
q
.
shape
[
-
1
]
**-
0.5
input_metadata
=
MetaData
(
sm_scale
=
softmax_scale
,
causal
=
causal
,
dropout_p
=
dropout_p
,
return_encoded_softmax
=
return_attn_probs
)
input_metadata
.
max_seqlens_q
=
q
.
shape
[
2
]
input_metadata
.
max_seqlens_k
=
k
.
shape
[
2
]
return
_attention
.
apply
(
q
,
k
,
v
,
None
,
input_metadata
).
transpose
(
1
,
2
)
# varlen flash_attn
def
varlen_input_helper
(
Z
,
HQ
,
HK
,
N_CTX_Q
,
N_CTX_K
,
D_HEAD
,
dtype
,
causal
):
torch
.
manual_seed
(
20
)
# Random sequence lengths. Using N_CTX as kind of max of sum of individual seqs
max_seqlens_q
=
N_CTX_Q
//
Z
max_seqlens_k
=
N_CTX_K
//
Z
seqlens_q
=
torch
.
randint
(
1
,
max_seqlens_q
+
1
,
(
Z
,),
dtype
=
torch
.
int32
)
seqlens_k
=
torch
.
randint
(
1
,
max_seqlens_k
+
1
,
(
Z
,),
dtype
=
torch
.
int32
)
max_seqlens_q
=
torch
.
max
(
seqlens_q
).
item
()
max_seqlens_k
=
torch
.
max
(
seqlens_k
).
item
()
# Calculate cumulative sequence lengths
cu_seqlens_q
=
torch
.
cat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
seqlens_q
.
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)])
cu_seqlens_k
=
torch
.
cat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
seqlens_k
.
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)])
cu_seqlens_q
=
cu_seqlens_q
.
to
(
device
=
"cuda"
)
cu_seqlens_k
=
cu_seqlens_k
.
to
(
device
=
"cuda"
)
# -1 because the last entry of cu_seqlens_q specifies the end of the last seq
num_ctxs
=
len
(
cu_seqlens_q
)
-
1
# Initialize q, k, v with variable lengths
total_q
=
cu_seqlens_q
[
-
1
].
item
()
total_k
=
cu_seqlens_k
[
-
1
].
item
()
q
=
torch
.
randn
((
total_q
,
HQ
,
D_HEAD
),
dtype
=
dtype
,
device
=
"cuda"
).
normal_
(
mean
=
0.
,
std
=
0.5
).
requires_grad_
()
k
=
torch
.
randn
((
total_k
,
HK
,
D_HEAD
),
dtype
=
dtype
,
device
=
"cuda"
).
normal_
(
mean
=
0.
,
std
=
0.5
).
requires_grad_
()
v
=
torch
.
randn
((
total_k
,
HK
,
D_HEAD
),
dtype
=
dtype
,
device
=
"cuda"
).
normal_
(
mean
=
0.
,
std
=
0.5
).
requires_grad_
()
sm_scale
=
D_HEAD
**-
0.5
input_metadata
=
MetaData
(
sm_scale
=
sm_scale
)
input_metadata
.
set_varlen_params
(
cu_seqlens_q
,
cu_seqlens_k
)
input_metadata
.
max_seqlens_q
=
max_seqlens_q
input_metadata
.
max_seqlens_k
=
max_seqlens_k
if
causal
:
input_metadata
.
need_causal
()
return
q
,
k
,
v
,
input_metadata
def
padding_thd
(
t
):
# THD
total_seqlen
,
nheads
,
dim
=
t
.
shape
t
=
torch
.
nn
.
functional
.
pad
(
t
.
reshape
(
total_seqlen
,
nheads
*
dim
),
(
0
,
32
),
'constant'
,
0
)[:,:
-
32
].
reshape
(
total_seqlen
,
nheads
,
dim
)
# pad: nheads*dim+32
# t = torch.nn.functional.pad(t.reshape(total_seqlen, nheads, dim), (0, 32), 'constant', 0)[:,:-32] # pad: dim+32
return
t
def
flash_attn_varlen_qkvpacked_func
(
qkv
,
cu_seqlens
,
max_seqlens
,
dropout_p
=
0.0
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
q
,
k
,
v
=
qkv
[:,
0
],
qkv
[:,
1
],
qkv
[:,
2
]
# total_seqlen, 3, nheads, dim
if
padding_input
:
k
,
v
=
(
padding_thd
(
t
)
for
t
in
(
k
,
v
))
# pad
softmax_scale
=
softmax_scale
if
softmax_scale
else
q
.
shape
[
-
1
]
**-
0.5
input_metadata
=
MetaData
(
sm_scale
=
softmax_scale
,
causal
=
causal
,
dropout_p
=
dropout_p
,
return_encoded_softmax
=
return_attn_probs
)
input_metadata
.
set_varlen_params
(
cu_seqlens
,
cu_seqlens
)
input_metadata
.
max_seqlens_q
=
max_seqlens
input_metadata
.
max_seqlens_k
=
max_seqlens
return
_attention
.
apply
(
q
,
k
,
v
,
None
,
input_metadata
)
def
flash_attn_varlen_kvpacked_func
(
q
,
kv
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlens_q
,
max_seqlens_k
,
dropout_p
=
0.0
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
k
,
v
=
kv
[:,
0
],
kv
[:,
1
]
# total_seqlen, 2, nheads, dim
if
padding_input
:
k
,
v
=
(
padding_thd
(
t
)
for
t
in
(
k
,
v
))
softmax_scale
=
softmax_scale
if
softmax_scale
else
q
.
shape
[
-
1
]
**-
0.5
input_metadata
=
MetaData
(
sm_scale
=
softmax_scale
,
causal
=
causal
,
dropout_p
=
dropout_p
,
return_encoded_softmax
=
return_attn_probs
)
input_metadata
.
set_varlen_params
(
cu_seqlens_q
,
cu_seqlens_k
)
input_metadata
.
max_seqlens_q
=
max_seqlens_q
input_metadata
.
max_seqlens_k
=
max_seqlens_k
return
_attention
.
apply
(
q
,
k
,
v
,
None
,
input_metadata
)
def
flash_attn_varlen_func
(
q
,
k
,
v
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlens_q
,
max_seqlens_k
,
dropout_p
=
0.0
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
if
padding_input
:
k
,
v
=
(
padding_thd
(
t
)
for
t
in
(
k
,
v
))
softmax_scale
=
softmax_scale
if
softmax_scale
else
q
.
shape
[
-
1
]
**-
0.5
input_metadata
=
MetaData
(
sm_scale
=
softmax_scale
,
causal
=
causal
,
dropout_p
=
dropout_p
,
return_encoded_softmax
=
return_attn_probs
)
input_metadata
.
set_varlen_params
(
cu_seqlens_q
,
cu_seqlens_k
)
input_metadata
.
max_seqlens_q
=
max_seqlens_q
input_metadata
.
max_seqlens_k
=
max_seqlens_k
return
_attention
.
apply
(
q
,
k
,
v
,
None
,
input_metadata
)
# legacy interface
def
flash_attn_unpadded_qkvpacked_func
(
qkv
,
cu_seqlens
,
max_seqlens
,
dropout_p
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
return
flash_attn_varlen_qkvpacked_func
(
qkv
,
cu_seqlens
,
max_seqlens
,
dropout_p
,
softmax_scale
,
causal
,
return_attn_probs
,
padding_input
)
def
flash_attn_unpadded_kvpacked_func
(
q
,
kv
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlens_q
,
max_seqlens_k
,
dropout_p
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
return
flash_attn_varlen_kvpacked_func
(
q
,
kv
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlens_q
,
max_seqlens_k
,
dropout_p
,
softmax_scale
,
causal
,
return_attn_probs
,
padding_input
)
def
flash_attn_unpadded_func
(
q
,
k
,
v
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlens_q
,
max_seqlens_k
,
dropout_p
,
softmax_scale
=
None
,
causal
=
False
,
return_attn_probs
=
False
,
padding_input
=
False
):
return
flash_attn_varlen_func
(
q
,
k
,
v
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlens_q
,
max_seqlens_k
,
dropout_p
,
softmax_scale
,
causal
,
return_attn_probs
,
padding_input
)
vllm/attention/ops/paged_attn.py
View file @
7462218e
...
@@ -5,6 +5,7 @@ import torch
...
@@ -5,6 +5,7 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
import
vllm.envs
as
envs
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE
=
512
_PARTITION_SIZE
=
512
...
@@ -122,26 +123,53 @@ class PagedAttention:
...
@@ -122,26 +123,53 @@ class PagedAttention:
if
use_v1
:
if
use_v1
:
# Run PagedAttention V1.
# Run PagedAttention V1.
ops
.
paged_attention_v1
(
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
output
,
print
(
"PA V1 SIZE:"
)
query
,
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
key_cache
,
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
value_cache
,
num_kv_heads
,
if
envs
.
VLLM_USE_OPT_OP
:
scale
,
ops
.
paged_attention_v1_opt
(
block_tables
,
output
,
seq_lens
,
query
,
block_size
,
key_cache
,
max_seq_len
,
value_cache
,
alibi_slopes
,
num_kv_heads
,
kv_cache_dtype
,
scale
,
kv_scale
,
block_tables
,
tp_rank
,
seq_lens
,
blocksparse_local_blocks
,
block_size
,
blocksparse_vert_stride
,
max_seq_len
,
blocksparse_block_size
,
alibi_slopes
,
blocksparse_head_sliding_step
,
kv_cache_dtype
,
)
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
else
:
ops
.
paged_attention_v1
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
else
:
else
:
# Run PagedAttention V2.
# Run PagedAttention V2.
assert
_PARTITION_SIZE
%
block_size
==
0
assert
_PARTITION_SIZE
%
block_size
==
0
...
@@ -156,29 +184,61 @@ class PagedAttention:
...
@@ -156,29 +184,61 @@ class PagedAttention:
device
=
output
.
device
,
device
=
output
.
device
,
)
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
ops
.
paged_attention_v2
(
output
,
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
exp_sums
,
print
(
"PA V2 SIZE:"
)
max_logits
,
print
(
f
"exp_sums.shape =
{
exp_sums
.
shape
}
, max_logits.shape =
{
max_logits
.
shape
}
, tmp_output.shape =
{
tmp_output
.
shape
}
"
)
tmp_output
,
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
query
,
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
key_cache
,
value_cache
,
if
envs
.
VLLM_USE_OPT_OP
:
num_kv_heads
,
ops
.
paged_attention_v2_opt
(
scale
,
output
,
block_tables
,
exp_sums
,
seq_lens
,
max_logits
,
block_size
,
tmp_output
,
max_seq_len
,
query
,
alibi_slopes
,
key_cache
,
kv_cache_dtype
,
value_cache
,
kv_scale
,
num_kv_heads
,
tp_rank
,
scale
,
blocksparse_local_blocks
,
block_tables
,
blocksparse_vert_stride
,
seq_lens
,
blocksparse_block_size
,
block_size
,
blocksparse_head_sliding_step
,
max_seq_len
,
)
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
else
:
ops
.
paged_attention_v2
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
return
output
return
output
@
staticmethod
@
staticmethod
...
...
vllm/attention/ops/prefix_prefill.py
View file @
7462218e
...
@@ -684,7 +684,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -684,7 +684,7 @@ if triton.__version__ >= "2.1.0":
sliding_window
=
None
):
sliding_window
=
None
):
cap
=
torch
.
cuda
.
get_device_capability
()
cap
=
torch
.
cuda
.
get_device_capability
()
BLOCK
=
128
if
cap
[
0
]
>=
8
else
64
BLOCK
=
32
if
cap
[
0
]
>=
8
else
32
# shape constraints
# shape constraints
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
assert
Lq
==
Lk
and
Lk
==
Lv
assert
Lq
==
Lk
and
Lk
==
Lv
...
@@ -701,7 +701,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -701,7 +701,7 @@ if triton.__version__ >= "2.1.0":
if
sliding_window
is
None
or
sliding_window
<=
0
:
if
sliding_window
is
None
or
sliding_window
<=
0
:
sliding_window
=
0
sliding_window
=
0
num_warps
=
8
if
Lk
<=
64
else
8
num_warps
=
8
if
Lk
<=
64
else
4
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
_fwd_kernel_alibi
[
grid
](
_fwd_kernel_alibi
[
grid
](
q
,
q
,
...
...
vllm/benchmark_throughput.py
0 → 100644
View file @
7462218e
"""Benchmark offline inference throughput."""
import
argparse
import
json
import
random
import
time
from
typing
import
List
,
Optional
,
Tuple
import
numpy
as
np
import
torch
from
tqdm
import
tqdm
from
transformers
import
(
AutoModelForCausalLM
,
AutoTokenizer
,
PreTrainedTokenizerBase
)
from
vllm.inputs
import
PromptStrictInputs
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
def
sample_requests
(
dataset_path
:
str
,
num_requests
:
int
,
tokenizer
:
PreTrainedTokenizerBase
,
fixed_output_len
:
Optional
[
int
],
)
->
List
[
Tuple
[
str
,
int
,
int
]]:
if
fixed_output_len
is
not
None
and
fixed_output_len
<
4
:
raise
ValueError
(
"output_len too small"
)
# Load the dataset.
with
open
(
dataset_path
)
as
f
:
dataset
=
json
.
load
(
f
)
# Filter out the conversations with less than 2 turns.
dataset
=
[
data
for
data
in
dataset
if
len
(
data
[
"conversations"
])
>=
2
]
# Only keep the first two turns of each conversation.
dataset
=
[(
data
[
"conversations"
][
0
][
"value"
],
data
[
"conversations"
][
1
][
"value"
])
for
data
in
dataset
]
# Shuffle the dataset.
random
.
shuffle
(
dataset
)
# Filter out sequences that are too long or too short
filtered_dataset
:
List
[
Tuple
[
str
,
int
,
int
]]
=
[]
for
i
in
range
(
len
(
dataset
)):
if
len
(
filtered_dataset
)
==
num_requests
:
break
# Tokenize the prompts and completions.
prompt
=
dataset
[
i
][
0
]
prompt_token_ids
=
tokenizer
(
prompt
).
input_ids
completion
=
dataset
[
i
][
1
]
completion_token_ids
=
tokenizer
(
completion
).
input_ids
prompt_len
=
len
(
prompt_token_ids
)
output_len
=
len
(
completion_token_ids
)
if
fixed_output_len
is
None
else
fixed_output_len
if
prompt_len
<
4
or
output_len
<
4
:
# Prune too short sequences.
continue
if
prompt_len
>
1024
or
prompt_len
+
output_len
>
2048
:
# Prune too long sequences.
continue
filtered_dataset
.
append
((
prompt
,
prompt_len
,
output_len
))
return
filtered_dataset
def
run_vllm
(
warmup_requests
:
List
[
Tuple
[
str
,
int
,
int
]],
requests
:
List
[
Tuple
[
str
,
int
,
int
]],
model
:
str
,
tokenizer
:
str
,
quantization
:
Optional
[
str
],
tensor_parallel_size
:
int
,
seed
:
int
,
n
:
int
,
use_beam_search
:
bool
,
trust_remote_code
:
bool
,
dtype
:
str
,
max_model_len
:
Optional
[
int
],
enforce_eager
:
bool
,
kv_cache_dtype
:
str
,
quantization_param_path
:
Optional
[
str
],
device
:
str
,
enable_prefix_caching
:
bool
,
enable_chunked_prefill
:
bool
,
max_num_batched_tokens
:
int
,
distributed_executor_backend
:
Optional
[
str
],
gpu_memory_utilization
:
float
=
0.9
,
download_dir
:
Optional
[
str
]
=
None
,
)
->
float
:
from
vllm
import
LLM
,
SamplingParams
llm
=
LLM
(
model
=
model
,
tokenizer
=
tokenizer
,
quantization
=
quantization
,
tensor_parallel_size
=
tensor_parallel_size
,
seed
=
seed
,
trust_remote_code
=
trust_remote_code
,
dtype
=
dtype
,
max_model_len
=
max_model_len
,
gpu_memory_utilization
=
gpu_memory_utilization
,
enforce_eager
=
enforce_eager
,
kv_cache_dtype
=
kv_cache_dtype
,
quantization_param_path
=
quantization_param_path
,
device
=
device
,
enable_prefix_caching
=
enable_prefix_caching
,
download_dir
=
download_dir
,
enable_chunked_prefill
=
enable_chunked_prefill
,
max_num_batched_tokens
=
max_num_batched_tokens
,
distributed_executor_backend
=
distributed_executor_backend
,
)
# Add the requests to the engine.
prompts
=
[]
sampling_params
=
[]
for
prompt
,
_
,
output_len
in
requests
:
prompts
.
append
(
prompt
)
sampling_params
.
append
(
SamplingParams
(
n
=
n
,
temperature
=
0.0
if
use_beam_search
else
1.0
,
top_p
=
1.0
,
use_beam_search
=
use_beam_search
,
ignore_eos
=
True
,
max_tokens
=
output_len
,
))
# warmup
warmup_prompts
=
[]
warmup_sampling_params
=
[]
for
prompt
,
_
,
output_len
in
warmup_requests
:
warmup_prompts
.
append
(
prompt
)
warmup_sampling_params
.
append
(
SamplingParams
(
n
=
n
,
temperature
=
0.0
if
use_beam_search
else
1.0
,
top_p
=
1.0
,
use_beam_search
=
use_beam_search
,
ignore_eos
=
True
,
max_tokens
=
output_len
,
))
print
(
"Warming up..."
)
for
_
in
tqdm
(
range
(
args
.
num_iters_warmup
),
desc
=
"Warmup iterations"
):
llm
.
generate
(
warmup_prompts
,
warmup_sampling_params
,
use_tqdm
=
True
)
# dummy_prompt_token_ids = np.random.randint(10000,
# size=(args.num_prompts,
# args.input_len))
# dummy_inputs: List[PromptStrictInputs] = [{
# "prompt_token_ids": batch
# } for batch in dummy_prompt_token_ids.tolist()]
# def run_to_completion():
# llm.generate(dummy_inputs,
# sampling_params=sampling_params,
# use_tqdm=False)
# print("Warming up...")
# for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
# run_to_completion()
start
=
time
.
perf_counter
()
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
end
=
time
.
perf_counter
()
return
end
-
start
def
run_hf
(
requests
:
List
[
Tuple
[
str
,
int
,
int
]],
model
:
str
,
tokenizer
:
PreTrainedTokenizerBase
,
n
:
int
,
use_beam_search
:
bool
,
max_batch_size
:
int
,
trust_remote_code
:
bool
,
)
->
float
:
assert
not
use_beam_search
llm
=
AutoModelForCausalLM
.
from_pretrained
(
model
,
torch_dtype
=
torch
.
float16
,
trust_remote_code
=
trust_remote_code
)
if
llm
.
config
.
model_type
==
"llama"
:
# To enable padding in the HF backend.
tokenizer
.
pad_token
=
tokenizer
.
eos_token
llm
=
llm
.
cuda
()
pbar
=
tqdm
(
total
=
len
(
requests
))
start
=
time
.
perf_counter
()
batch
:
List
[
str
]
=
[]
max_prompt_len
=
0
max_output_len
=
0
for
i
in
range
(
len
(
requests
)):
prompt
,
prompt_len
,
output_len
=
requests
[
i
]
# Add the prompt to the batch.
batch
.
append
(
prompt
)
max_prompt_len
=
max
(
max_prompt_len
,
prompt_len
)
max_output_len
=
max
(
max_output_len
,
output_len
)
if
len
(
batch
)
<
max_batch_size
and
i
!=
len
(
requests
)
-
1
:
# Check if we can add more requests to the batch.
_
,
next_prompt_len
,
next_output_len
=
requests
[
i
+
1
]
if
(
max
(
max_prompt_len
,
next_prompt_len
)
+
max
(
max_output_len
,
next_output_len
))
<=
2048
:
# We can add more requests to the batch.
continue
# Generate the sequences.
input_ids
=
tokenizer
(
batch
,
return_tensors
=
"pt"
,
padding
=
True
).
input_ids
llm_outputs
=
llm
.
generate
(
input_ids
=
input_ids
.
cuda
(),
do_sample
=
not
use_beam_search
,
num_return_sequences
=
n
,
temperature
=
1.0
,
top_p
=
1.0
,
use_cache
=
True
,
max_new_tokens
=
max_output_len
,
)
# Include the decoding time.
tokenizer
.
batch_decode
(
llm_outputs
,
skip_special_tokens
=
True
)
pbar
.
update
(
len
(
batch
))
# Clear the batch.
batch
=
[]
max_prompt_len
=
0
max_output_len
=
0
end
=
time
.
perf_counter
()
return
end
-
start
def
run_mii
(
requests
:
List
[
Tuple
[
str
,
int
,
int
]],
model
:
str
,
tensor_parallel_size
:
int
,
output_len
:
int
,
)
->
float
:
from
mii
import
client
,
serve
llm
=
serve
(
model
,
tensor_parallel
=
tensor_parallel_size
)
prompts
=
[
prompt
for
prompt
,
_
,
_
in
requests
]
start
=
time
.
perf_counter
()
llm
.
generate
(
prompts
,
max_new_tokens
=
output_len
)
end
=
time
.
perf_counter
()
client
=
client
(
model
)
client
.
terminate_server
()
return
end
-
start
def
main
(
args
:
argparse
.
Namespace
):
print
(
args
)
random
.
seed
(
args
.
seed
)
# Sample the requests.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
tokenizer
,
trust_remote_code
=
args
.
trust_remote_code
)
if
args
.
dataset
is
None
:
# Synthesize a prompt with the given input length.
warmup_prompt
=
"hi"
*
10
warmup_requests
=
[(
warmup_prompt
,
10
,
10
)
for
_
in
range
(
1
)]
prompt
=
"hi"
*
(
args
.
input_len
-
1
)
requests
=
[(
prompt
,
args
.
input_len
,
args
.
output_len
)
for
_
in
range
(
args
.
num_prompts
)]
else
:
requests
=
sample_requests
(
args
.
dataset
,
args
.
num_prompts
,
tokenizer
,
args
.
output_len
)
if
args
.
backend
==
"vllm"
:
elapsed_time
=
run_vllm
(
warmup_requests
,
requests
,
args
.
model
,
args
.
tokenizer
,
args
.
quantization
,
args
.
tensor_parallel_size
,
args
.
seed
,
args
.
n
,
args
.
use_beam_search
,
args
.
trust_remote_code
,
args
.
dtype
,
args
.
max_model_len
,
args
.
enforce_eager
,
args
.
kv_cache_dtype
,
args
.
quantization_param_path
,
args
.
device
,
args
.
enable_prefix_caching
,
args
.
enable_chunked_prefill
,
args
.
max_num_batched_tokens
,
args
.
distributed_executor_backend
,
args
.
gpu_memory_utilization
,
args
.
download_dir
)
elif
args
.
backend
==
"hf"
:
assert
args
.
tensor_parallel_size
==
1
elapsed_time
=
run_hf
(
requests
,
args
.
model
,
tokenizer
,
args
.
n
,
args
.
use_beam_search
,
args
.
hf_max_batch_size
,
args
.
trust_remote_code
)
elif
args
.
backend
==
"mii"
:
elapsed_time
=
run_mii
(
requests
,
args
.
model
,
args
.
tensor_parallel_size
,
args
.
output_len
)
else
:
raise
ValueError
(
f
"Unknown backend:
{
args
.
backend
}
"
)
total_num_tokens
=
sum
(
prompt_len
+
output_len
for
_
,
prompt_len
,
output_len
in
requests
)
if
args
.
dataset
is
None
:
total_out_tokens
=
args
.
output_len
*
args
.
num_prompts
else
:
total_out_tokens
=
sum
(
output_len
for
_
,
_
,
output_len
in
requests
)
print
(
f
"Latency:
{
elapsed_time
:.
2
f
}
s"
)
print
(
f
"All Throughput:
{
len
(
requests
)
/
elapsed_time
:.
2
f
}
requests/s, "
f
"
{
total_num_tokens
/
elapsed_time
:.
2
f
}
tokens/s"
)
print
(
f
"Generate Throughput:
{
total_out_tokens
/
elapsed_time
:.
2
f
}
tokens/s"
)
# Output JSON results if specified
if
args
.
output_json
:
results
=
{
"elapsed_time"
:
elapsed_time
,
"num_requests"
:
len
(
requests
),
"total_num_tokens"
:
total_num_tokens
,
"requests_per_second"
:
len
(
requests
)
/
elapsed_time
,
"tokens_per_second"
:
total_num_tokens
/
elapsed_time
,
}
with
open
(
args
.
output_json
,
"w"
)
as
f
:
json
.
dump
(
results
,
f
,
indent
=
4
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Benchmark the throughput."
)
parser
.
add_argument
(
"--backend"
,
type
=
str
,
choices
=
[
"vllm"
,
"hf"
,
"mii"
],
default
=
"vllm"
)
parser
.
add_argument
(
"--dataset"
,
type
=
str
,
default
=
None
,
help
=
"Path to the dataset."
)
parser
.
add_argument
(
"--input-len"
,
type
=
int
,
default
=
None
,
help
=
"Input prompt length for each request"
)
parser
.
add_argument
(
"--output-len"
,
type
=
int
,
default
=
None
,
help
=
"Output length for each request. Overrides the "
"output length from the dataset."
)
parser
.
add_argument
(
"--model"
,
type
=
str
,
default
=
"facebook/opt-125m"
)
parser
.
add_argument
(
"--tokenizer"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--quantization'
,
'-q'
,
choices
=
[
*
QUANTIZATION_METHODS
,
None
],
default
=
None
)
parser
.
add_argument
(
"--tensor-parallel-size"
,
"-tp"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
1
,
help
=
"Number of generated sequences per prompt."
)
parser
.
add_argument
(
"--use-beam-search"
,
action
=
"store_true"
)
parser
.
add_argument
(
'--num-iters-warmup'
,
type
=
int
,
default
=
1
,
help
=
'Number of iterations to run for warmup.'
)
parser
.
add_argument
(
"--num-prompts"
,
type
=
int
,
default
=
1000
,
help
=
"Number of prompts to process."
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--hf-max-batch-size"
,
type
=
int
,
default
=
None
,
help
=
"Maximum batch size for HF backend."
)
parser
.
add_argument
(
'--trust-remote-code'
,
action
=
'store_true'
,
help
=
'trust remote code from huggingface'
)
parser
.
add_argument
(
'--max-model-len'
,
type
=
int
,
default
=
None
,
help
=
'Maximum length of a sequence (including prompt and output). '
'If None, will be derived from the model.'
)
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
default
=
'auto'
,
choices
=
[
'auto'
,
'half'
,
'float16'
,
'bfloat16'
,
'float'
,
'float32'
],
help
=
'data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.'
)
parser
.
add_argument
(
'--gpu-memory-utilization'
,
type
=
float
,
default
=
0.9
,
help
=
'the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.'
)
parser
.
add_argument
(
"--enforce-eager"
,
action
=
"store_true"
,
help
=
"enforce eager execution"
)
parser
.
add_argument
(
'--kv-cache-dtype'
,
type
=
str
,
choices
=
[
'auto'
,
'fp8'
,
'fp8_e5m2'
,
'fp8_e4m3'
],
default
=
"auto"
,
help
=
'Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)'
)
parser
.
add_argument
(
'--quantization-param-path'
,
type
=
str
,
default
=
None
,
help
=
'Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.'
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
"cuda"
,
choices
=
[
"cuda"
,
"cpu"
],
help
=
'device type for vLLM execution, supporting CUDA and CPU.'
)
parser
.
add_argument
(
"--enable-prefix-caching"
,
action
=
'store_true'
,
help
=
"enable automatic prefix caching for vLLM backend."
)
parser
.
add_argument
(
"--enable-chunked-prefill"
,
action
=
'store_true'
,
help
=
"enable chunked prefill for vLLM backend."
)
parser
.
add_argument
(
'--max-num-batched-tokens'
,
type
=
int
,
default
=
None
,
help
=
'maximum number of batched tokens per '
'iteration'
)
parser
.
add_argument
(
'--download-dir'
,
type
=
str
,
default
=
None
,
help
=
'directory to download and load the weights, '
'default to the default cache dir of huggingface'
)
parser
.
add_argument
(
'--output-json'
,
type
=
str
,
default
=
None
,
help
=
'Path to save the throughput results in JSON format.'
)
parser
.
add_argument
(
'--distributed-executor-backend'
,
choices
=
[
'ray'
,
'mp'
],
default
=
None
,
help
=
'Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.'
)
args
=
parser
.
parse_args
()
if
args
.
tokenizer
is
None
:
args
.
tokenizer
=
args
.
model
if
args
.
dataset
is
None
:
assert
args
.
input_len
is
not
None
assert
args
.
output_len
is
not
None
else
:
assert
args
.
input_len
is
None
if
args
.
backend
==
"vllm"
:
if
args
.
hf_max_batch_size
is
not
None
:
raise
ValueError
(
"HF max batch size is only for HF backend."
)
elif
args
.
backend
==
"hf"
:
if
args
.
hf_max_batch_size
is
None
:
raise
ValueError
(
"HF max batch size is required for HF backend."
)
if
args
.
quantization
is
not
None
:
raise
ValueError
(
"Quantization is only for vLLM backend."
)
elif
args
.
backend
==
"mii"
:
if
args
.
dtype
!=
"auto"
:
raise
ValueError
(
"dtype must be auto for MII backend."
)
if
args
.
n
!=
1
:
raise
ValueError
(
"n must be 1 for MII backend."
)
if
args
.
use_beam_search
:
raise
ValueError
(
"Beam search is not supported for MII backend."
)
if
args
.
quantization
is
not
None
:
raise
ValueError
(
"Quantization is only for vLLM backend."
)
if
args
.
hf_max_batch_size
is
not
None
:
raise
ValueError
(
"HF max batch size is only for HF backend."
)
if
args
.
tokenizer
!=
args
.
model
:
raise
ValueError
(
"Tokenizer must be the same as the model for MII "
"backend."
)
main
(
args
)
\ No newline at end of file
vllm/config.py
View file @
7462218e
...
@@ -173,7 +173,7 @@ class ModelConfig:
...
@@ -173,7 +173,7 @@ class ModelConfig:
def
_verify_quantization
(
self
)
->
None
:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
rocm_supported_quantization
=
[
"gptq"
,
"squeezellm"
]
rocm_supported_quantization
=
[
"gptq"
,
"squeezellm"
,
"awq"
]
if
self
.
quantization
is
not
None
:
if
self
.
quantization
is
not
None
:
self
.
quantization
=
self
.
quantization
.
lower
()
self
.
quantization
=
self
.
quantization
.
lower
()
...
@@ -279,6 +279,12 @@ class ModelConfig:
...
@@ -279,6 +279,12 @@ class ModelConfig:
return
self
.
hf_text_config
.
hidden_size
return
self
.
hf_text_config
.
hidden_size
def
get_head_size
(
self
)
->
int
:
def
get_head_size
(
self
)
->
int
:
# TODO remove hard code
if
hasattr
(
self
.
hf_text_config
,
"model_type"
)
and
self
.
hf_text_config
.
model_type
==
'deepseek_v2'
:
# FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256
return
256
if
hasattr
(
self
.
hf_text_config
,
"head_dim"
):
if
hasattr
(
self
.
hf_text_config
,
"head_dim"
):
return
self
.
hf_text_config
.
head_dim
return
self
.
hf_text_config
.
head_dim
# FIXME(woosuk): This may not be true for all models.
# FIXME(woosuk): This may not be true for all models.
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
7462218e
...
@@ -315,4 +315,4 @@ class CustomAllreduce:
...
@@ -315,4 +315,4 @@ class CustomAllreduce:
self
.
_ptr
=
0
self
.
_ptr
=
0
def
__del__
(
self
):
def
__del__
(
self
):
self
.
close
()
self
.
close
()
\ No newline at end of file
vllm/engine/llm_engine.py
View file @
7462218e
...
@@ -232,76 +232,91 @@ class LLMEngine:
...
@@ -232,76 +232,91 @@ class LLMEngine:
load_config
=
load_config
,
load_config
=
load_config
,
)
)
if
not
self
.
model_config
.
embedding_mode
:
init_success
=
False
self
.
_initialize_kv_caches
()
try
:
if
not
self
.
model_config
.
embedding_mode
:
# If usage stat is enabled, collect relevant info.
self
.
_initialize_kv_caches
()
if
is_usage_stats_enabled
():
from
vllm.model_executor.model_loader
import
(
# If usage stat is enabled, collect relevant info.
get_architecture_class_name
)
if
is_usage_stats_enabled
():
usage_message
.
report_usage
(
from
vllm.model_executor.model_loader
import
(
get_architecture_class_name
(
model_config
),
get_architecture_class_name
)
usage_context
,
usage_message
.
report_usage
(
extra_kvs
=
{
get_architecture_class_name
(
model_config
),
# Common configuration
usage_context
,
"dtype"
:
extra_kvs
=
{
str
(
model_config
.
dtype
),
# Common configuration
"tensor_parallel_size"
:
"dtype"
:
parallel_config
.
tensor_parallel_size
,
str
(
model_config
.
dtype
),
"block_size"
:
"tensor_parallel_size"
:
cache_config
.
block_size
,
parallel_config
.
tensor_parallel_size
,
"gpu_memory_utilization"
:
"block_size"
:
cache_config
.
gpu_memory_utilization
,
cache_config
.
block_size
,
"gpu_memory_utilization"
:
# Quantization
cache_config
.
gpu_memory_utilization
,
"quantization"
:
model_config
.
quantization
,
# Quantization
"kv_cache_dtype"
:
"quantization"
:
cache_config
.
cache_dtype
,
model_config
.
quantization
,
"kv_cache_dtype"
:
# Feature flags
cache_config
.
cache_dtype
,
"enable_lora"
:
bool
(
lora_config
),
# Feature flags
"enable_prefix_caching"
:
"enable_lora"
:
cache_config
.
enable_prefix_caching
,
bool
(
lora_config
),
"enforce_eager"
:
"enable_prefix_caching"
:
model_config
.
enforce_eager
,
cache_config
.
enable_prefix_caching
,
"disable_custom_all_reduce"
:
"enforce_eager"
:
parallel_config
.
disable_custom_all_reduce
,
model_config
.
enforce_eager
,
})
"disable_custom_all_reduce"
:
parallel_config
.
disable_custom_all_reduce
,
if
self
.
tokenizer
:
})
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
if
self
.
tokenizer
:
self
.
tokenizer
.
ping
()
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
# Create the scheduler.
self
.
tokenizer
.
ping
()
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
# Create the scheduler.
self
.
scheduler
=
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
)
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
# Metric Logging.
self
.
scheduler
=
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
)
if
self
.
log_stats
:
self
.
stat_logger
=
StatLogger
(
# Metric Logging.
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
,
if
self
.
log_stats
:
labels
=
dict
(
model_name
=
model_config
.
served_model_name
),
self
.
stat_logger
=
StatLogger
(
max_model_len
=
self
.
model_config
.
max_model_len
)
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
,
self
.
stat_logger
.
info
(
"cache_config"
,
self
.
cache_config
)
labels
=
dict
(
model_name
=
model_config
.
served_model_name
),
max_model_len
=
self
.
model_config
.
max_model_len
)
# Create sequence output processor, e.g. for beam search or
self
.
stat_logger
.
info
(
"cache_config"
,
self
.
cache_config
)
# speculative decoding.
self
.
output_processor
=
(
tokenizer_group
=
self
.
get_tokenizer_group
()
SequenceGroupOutputProcessor
.
create_output_processor
(
self
.
scheduler_config
,
def
get_tokenizer_for_seq
(
self
,
self
.
detokenizer
,
sequence
:
Sequence
)
->
"PreTrainedTokenizer"
:
self
.
scheduler
,
return
tokenizer_group
.
get_lora_tokenizer
(
self
.
seq_counter
,
sequence
.
lora_request
)
self
.
get_tokenizer_for_seq
,
stop_checker
=
StopChecker
(
# Create sequence output processor, e.g. for beam search or
self
.
scheduler_config
.
max_model_len
,
# speculative decoding.
self
.
get_tokenizer_for_seq
,
self
.
output_processor
=
(
),
SequenceGroupOutputProcessor
.
create_output_processor
(
))
self
.
scheduler_config
,
self
.
detokenizer
,
self
.
scheduler
,
self
.
seq_counter
,
get_tokenizer_for_seq
,
stop_checker
=
StopChecker
(
self
.
scheduler_config
.
max_model_len
,
get_tokenizer_for_seq
,
),
))
init_success
=
True
finally
:
if
not
init_success
:
# Ensure that model_executor is shut down if LLMEngine init
# failed
self
.
model_executor
.
shutdown
()
def
_initialize_kv_caches
(
self
)
->
None
:
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
"""Initialize the KV cache in the worker(s).
...
@@ -393,10 +408,10 @@ class LLMEngine:
...
@@ -393,10 +408,10 @@ class LLMEngine:
def
get_tokenizer
(
self
)
->
"PreTrainedTokenizer"
:
def
get_tokenizer
(
self
)
->
"PreTrainedTokenizer"
:
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
None
)
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
None
)
def
get_tokenizer_for_seq
(
self
,
#
def get_tokenizer_for_seq(self,
sequence
:
Sequence
)
->
"PreTrainedTokenizer"
:
#
sequence: Sequence) -> "PreTrainedTokenizer":
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
#
return self.get_tokenizer_group().get_lora_tokenizer(
sequence
.
lora_request
)
#
sequence.lora_request)
def
_init_tokenizer
(
self
,
**
tokenizer_init_kwargs
)
->
BaseTokenizerGroup
:
def
_init_tokenizer
(
self
,
**
tokenizer_init_kwargs
)
->
BaseTokenizerGroup
:
init_kwargs
=
dict
(
init_kwargs
=
dict
(
...
@@ -785,7 +800,8 @@ class LLMEngine:
...
@@ -785,7 +800,8 @@ class LLMEngine:
# Log stats.
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
do_log_stats
(
scheduler_outputs
,
output
)
if
not
request_outputs
:
# if not request_outputs:
if
not
self
.
has_unfinished_requests
():
# Stop the execute model loop in parallel workers until there are
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# torch.distributed ops which may otherwise timeout, and unblocks
...
...
vllm/envs.py
View file @
7462218e
...
@@ -9,6 +9,9 @@ if TYPE_CHECKING:
...
@@ -9,6 +9,9 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_FLASH_ATTN_AUTO
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
LOCAL_RANK
:
int
=
0
LOCAL_RANK
:
int
=
0
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
VLLM_ENGINE_ITERATION_TIMEOUT_S
:
int
=
60
VLLM_ENGINE_ITERATION_TIMEOUT_S
:
int
=
60
...
@@ -27,7 +30,7 @@ if TYPE_CHECKING:
...
@@ -27,7 +30,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_
XLA_CACHE_PATH
:
str
=
"~/.vllm/xla_cache/"
VLLM_
FUSED_MOE_CHUNK_SIZE
:
int
=
64
*
1024
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
VLLM_WORKER_MULTIPROC_METHOD
:
str
=
"spawn"
VLLM_WORKER_MULTIPROC_METHOD
:
str
=
"spawn"
VLLM_IMAGE_FETCH_TIMEOUT
:
int
=
5
VLLM_IMAGE_FETCH_TIMEOUT
:
int
=
5
...
@@ -131,7 +134,22 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -131,7 +134,22 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention
# flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN"
:
"VLLM_USE_TRITON_FLASH_ATTN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_OPT_OP"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control if vllm print pa parameters
"VLLM_USE_PA_PRINT_PARAM"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_PA_PRINT_PARAM"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# local rank of the process in the distributed setting, used to determine
# local rank of the process in the distributed setting, used to determine
...
@@ -145,7 +163,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -145,7 +163,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# timeout for each iteration in the engine
# timeout for each iteration in the engine
"VLLM_ENGINE_ITERATION_TIMEOUT_S"
:
"VLLM_ENGINE_ITERATION_TIMEOUT_S"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_ENGINE_ITERATION_TIMEOUT_S"
,
"
6
0"
)),
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_ENGINE_ITERATION_TIMEOUT_S"
,
"
12
0"
)),
# API key for VLLM API server
# API key for VLLM API server
"VLLM_API_KEY"
:
"VLLM_API_KEY"
:
...
@@ -214,15 +232,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -214,15 +232,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_WORKER_MULTIPROC_METHOD"
:
"VLLM_WORKER_MULTIPROC_METHOD"
:
lambda
:
os
.
getenv
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"spawn"
),
lambda
:
os
.
getenv
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"spawn"
),
"VLLM_FUSED_MOE_CHUNK_SIZE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"65536"
)),
# Timeout for fetching images when serving multimodal models
# Timeout for fetching images when serving multimodal models
# Default is 5 seconds
# Default is 5 seconds
"VLLM_IMAGE_FETCH_TIMEOUT"
:
"VLLM_IMAGE_FETCH_TIMEOUT"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_IMAGE_FETCH_TIMEOUT"
,
"5"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_IMAGE_FETCH_TIMEOUT"
,
"5"
)),
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH"
:
lambda
:
os
.
getenv
(
"VLLM_XLA_CACHE_PATH"
,
"~/.vllm/xla_cache/"
),
}
}
# end-env-vars-definition
# end-env-vars-definition
...
@@ -236,4 +252,4 @@ def __getattr__(name):
...
@@ -236,4 +252,4 @@ def __getattr__(name):
def
__dir__
():
def
__dir__
():
return
list
(
environment_variables
.
keys
())
return
list
(
environment_variables
.
keys
())
\ No newline at end of file
vllm/executor/multiproc_worker_utils.py
View file @
7462218e
...
@@ -76,7 +76,8 @@ class ResultHandler(threading.Thread):
...
@@ -76,7 +76,8 @@ class ResultHandler(threading.Thread):
"""Handle results from all workers (in background thread)"""
"""Handle results from all workers (in background thread)"""
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
(
daemon
=
True
)
super
().
__init__
(
daemon
=
False
)
# super().__init__(daemon=True)
self
.
result_queue
=
mp
.
Queue
()
self
.
result_queue
=
mp
.
Queue
()
self
.
tasks
:
Dict
[
uuid
.
UUID
,
Union
[
ResultFuture
,
asyncio
.
Future
]]
=
{}
self
.
tasks
:
Dict
[
uuid
.
UUID
,
Union
[
ResultFuture
,
asyncio
.
Future
]]
=
{}
...
@@ -100,7 +101,8 @@ class WorkerMonitor(threading.Thread):
...
@@ -100,7 +101,8 @@ class WorkerMonitor(threading.Thread):
def
__init__
(
self
,
workers
:
List
[
'ProcessWorkerWrapper'
],
def
__init__
(
self
,
workers
:
List
[
'ProcessWorkerWrapper'
],
result_handler
:
ResultHandler
):
result_handler
:
ResultHandler
):
super
().
__init__
(
daemon
=
True
)
super
().
__init__
(
daemon
=
False
)
# super().__init__(daemon=True)
self
.
workers
=
workers
self
.
workers
=
workers
self
.
result_handler
=
result_handler
self
.
result_handler
=
result_handler
self
.
_close
=
False
self
.
_close
=
False
...
@@ -112,15 +114,31 @@ class WorkerMonitor(threading.Thread):
...
@@ -112,15 +114,31 @@ class WorkerMonitor(threading.Thread):
self
.
_close
=
True
self
.
_close
=
True
# Kill / cleanup all workers
# Kill / cleanup all workers
for
worker
in
self
.
workers
:
# for worker in self.workers:
process
=
worker
.
process
# process = worker.process
if
process
.
sentinel
in
dead_sentinels
:
# if process.sentinel in dead_sentinels:
process
.
join
(
JOIN_TIMEOUT_S
)
# process.join(JOIN_TIMEOUT_S)
if
process
.
exitcode
is
not
None
and
process
.
exitcode
!=
0
:
# if process.exitcode is not None and process.exitcode != 0:
logger
.
error
(
"Worker %s pid %s died, exit code: %s"
,
# logger.error("Worker %s pid %s died, exit code: %s",
process
.
name
,
process
.
pid
,
process
.
exitcode
)
# process.name, process.pid, process.exitcode)
if
not
sys
.
is_finalizing
():
# Kill / cleanup all workers
died_count
=
0
for
worker
in
self
.
workers
:
process
=
worker
.
process
if
process
.
sentinel
in
dead_sentinels
:
process
.
join
(
JOIN_TIMEOUT_S
)
if
process
.
exitcode
is
not
None
and
process
.
exitcode
!=
0
:
died_count
+=
1
logger
.
error
(
"Worker %s pid %s died, exit code: %s"
,
process
.
name
,
process
.
pid
,
process
.
exitcode
)
if
died_count
<
len
(
self
.
workers
):
logger
.
info
(
"Killing remaining local vLLM worker processes"
)
# Cleanup any remaining workers
# Cleanup any remaining workers
logger
.
info
(
"Killing local vLLM worker processes"
)
#
logger.info("Killing local vLLM worker processes")
for
worker
in
self
.
workers
:
for
worker
in
self
.
workers
:
worker
.
kill_worker
()
worker
.
kill_worker
()
# Must be done after worker task queues are all closed
# Must be done after worker task queues are all closed
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment