Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ac28ab22
Commit
ac28ab22
authored
Feb 04, 2026
by
zhuwenwen
Browse files
[perf] add VLLM_USE_FLASH_ATTN_FP8 to use fa fp8 attention
parent
5fe03549
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
85 additions
and
24 deletions
+85
-24
vllm/envs.py
vllm/envs.py
+5
-0
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+80
-24
No files found.
vllm/envs.py
View file @
ac28ab22
...
@@ -258,6 +258,7 @@ if TYPE_CHECKING:
...
@@ -258,6 +258,7 @@ if TYPE_CHECKING:
VLLM_OPTEST_URLS_PORT
:
int
|
None
=
None
VLLM_OPTEST_URLS_PORT
:
int
|
None
=
None
VLLM_OPTEST_MODELS_PATH
:
str
=
""
VLLM_OPTEST_MODELS_PATH
:
str
=
""
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
bool
=
False
VLLM_USE_FLASH_ATTN_FP8
:
bool
=
False
VLLM_USE_QUERY_QUANT
:
bool
=
False
VLLM_USE_QUERY_QUANT
:
bool
=
False
VLLM_USE_FLASH_MLA
:
bool
=
False
VLLM_USE_FLASH_MLA
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
...
@@ -1685,6 +1686,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1685,6 +1686,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_PREFIX_FLASH_ATTN"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_PREFIX_FLASH_ATTN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# If set, vLLM will use FLASH ATTN fp8 attention optimizations.
"VLLM_USE_FLASH_ATTN_FP8"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASH_ATTN_FP8"
,
"0"
))),
# flag to control if vllm should use q quant
# flag to control if vllm should use q quant
"VLLM_USE_QUERY_QUANT"
:
"VLLM_USE_QUERY_QUANT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_QUERY_QUANT"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_QUERY_QUANT"
,
"False"
).
lower
()
in
...
...
vllm/model_executor/layers/attention/mla_attention.py
View file @
ac28ab22
...
@@ -1499,18 +1499,46 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -1499,18 +1499,46 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
def
_run_prefill_new_tokens_fa
(
def
_run_prefill_new_tokens_fa
(
self
,
prefill
:
MLACommonPrefillMetadata
,
q
,
k
,
v
,
return_softmax_lse
self
,
prefill
:
MLACommonPrefillMetadata
,
q
,
k
,
v
,
return_softmax_lse
):
):
return
self
.
_flash_attn_varlen_diff_headdims
(
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_ATTN_FP8
:
q
=
q
,
q_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
k
=
k
,
k_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
v
=
v
,
v_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
cu_seqlens_q
=
prefill
.
query_start_loc
,
descale_shape
=
(
prefill
.
query_start_loc
.
numel
()
-
1
,
q
.
shape
[
1
])
cu_seqlens_k
=
prefill
.
query_start_loc
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
max_seqlen_q
=
prefill
.
max_query_len
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
max_seqlen_k
=
prefill
.
max_query_len
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
softmax_scale
=
self
.
scale
,
q
=
q
.
to
(
torch
.
float8_e4m3fn
)
causal
=
True
,
k
=
k
.
to
(
torch
.
float8_e4m3fn
)
return_softmax_lse
=
return_softmax_lse
,
v
=
v
.
to
(
torch
.
float8_e4m3fn
)
)
return
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill
.
query_start_loc
,
cu_seqlens_k
=
prefill
.
query_start_loc
,
max_seqlen_q
=
prefill
.
max_query_len
,
max_seqlen_k
=
prefill
.
max_query_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
return_softmax_lse
,
)
else
:
return
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill
.
query_start_loc
,
cu_seqlens_k
=
prefill
.
query_start_loc
,
max_seqlen_q
=
prefill
.
max_query_len
,
max_seqlen_k
=
prefill
.
max_query_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
return_softmax_lse
=
return_softmax_lse
,
)
def
_run_prefill_new_tokens_fi
(
def
_run_prefill_new_tokens_fi
(
self
,
prefill
:
MLACommonPrefillMetadata
,
q
,
k
,
v
,
return_softmax_lse
self
,
prefill
:
MLACommonPrefillMetadata
,
q
,
k
,
v
,
return_softmax_lse
...
@@ -1558,18 +1586,46 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -1558,18 +1586,46 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
):
):
assert
prefill
.
chunked_context
is
not
None
assert
prefill
.
chunked_context
is
not
None
return
self
.
_flash_attn_varlen_diff_headdims
(
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_ATTN_FP8
:
q
=
q
,
q_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
k
=
k
,
k_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
v
=
v
,
v_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
cu_seqlens_q
=
prefill
.
query_start_loc
,
descale_shape
=
(
prefill
.
query_start_loc
.
numel
()
-
1
,
q
.
shape
[
1
])
cu_seqlens_k
=
prefill
.
chunked_context
.
cu_seq_lens
[
chunk_idx
],
q_descale
=
q_descale
.
expand
(
descale_shape
)
max_seqlen_q
=
prefill
.
max_query_len
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
max_seqlen_k
=
prefill
.
chunked_context
.
max_seq_lens
[
chunk_idx
],
v_descale
=
v_descale
.
expand
(
descale_shape
)
softmax_scale
=
self
.
scale
,
q
=
q
.
to
(
torch
.
float8_e4m3fn
)
causal
=
False
,
# Context is unmasked
k
=
k
.
to
(
torch
.
float8_e4m3fn
)
return_softmax_lse
=
True
,
v
=
v
.
to
(
torch
.
float8_e4m3fn
)
)
return
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill
.
query_start_loc
,
cu_seqlens_k
=
prefill
.
chunked_context
.
cu_seq_lens
[
chunk_idx
],
max_seqlen_q
=
prefill
.
max_query_len
,
max_seqlen_k
=
prefill
.
chunked_context
.
max_seq_lens
[
chunk_idx
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
True
,
)
else
:
return
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill
.
query_start_loc
,
cu_seqlens_k
=
prefill
.
chunked_context
.
cu_seq_lens
[
chunk_idx
],
max_seqlen_q
=
prefill
.
max_query_len
,
max_seqlen_k
=
prefill
.
chunked_context
.
max_seq_lens
[
chunk_idx
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_softmax_lse
=
True
,
)
def
_run_prefill_context_chunk_fi
(
def
_run_prefill_context_chunk_fi
(
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment