Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
43546076
Commit
43546076
authored
Dec 26, 2025
by
zhuwenwen
Browse files
add VLLM_USE_FLASH_ATTN_FP8 to support fa fp8
parent
1663f34c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
101 additions
and
27 deletions
+101
-27
setup.py
setup.py
+2
-2
vllm/envs.py
vllm/envs.py
+5
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+94
-25
No files found.
setup.py
View file @
43546076
...
@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
...
@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if
sha
is
None
:
if
sha
is
None
:
sha
=
get_sha
(
vllm_root
)
sha
=
get_sha
(
vllm_root
)
if
(
major
,
minor
)
>=
(
'2'
,
'5'
):
if
(
major
,
minor
)
>=
(
'2'
,
'5'
):
version
=
'das.opt
2
.'
+
sha
[:
7
]
version
=
'das.opt
3
.'
+
sha
[:
7
]
else
:
else
:
if
(
major
,
minor
)
>=
(
'2'
,
'5'
):
if
(
major
,
minor
)
>=
(
'2'
,
'5'
):
version
=
'das.opt
2
'
version
=
'das.opt
3
'
# dtk version
# dtk version
...
...
vllm/envs.py
View file @
43546076
...
@@ -145,6 +145,7 @@ if TYPE_CHECKING:
...
@@ -145,6 +145,7 @@ if TYPE_CHECKING:
VLLM_OPTEST_MODELS_PATH
:
str
=
""
VLLM_OPTEST_MODELS_PATH
:
str
=
""
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_OPT_MLA
:
bool
=
False
VLLM_USE_TRITON_OPT_MLA
:
bool
=
False
VLLM_USE_FLASH_ATTN_FP8
:
bool
=
False
VLLM_USE_FLASH_MLA
:
bool
=
False
VLLM_USE_FLASH_MLA
:
bool
=
False
VLLM_USE_FLASH_MLA_FP8
:
bool
=
False
VLLM_USE_FLASH_MLA_FP8
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
...
@@ -1038,6 +1039,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1038,6 +1039,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRITON_OPT_MLA"
:
"VLLM_USE_TRITON_OPT_MLA"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_TRITON_OPT_MLA"
,
"0"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_TRITON_OPT_MLA"
,
"0"
))),
# If set, vLLM will use FLASH ATTN fp8 attention optimizations.
"VLLM_USE_FLASH_ATTN_FP8"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASH_ATTN_FP8"
,
"0"
))),
# If set, vLLM will use FLASH MLA attention optimizations.
# If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA"
:
"VLLM_USE_FLASH_MLA"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASH_MLA"
,
"1"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASH_MLA"
,
"1"
))),
...
...
vllm/v1/attention/backends/mla/common.py
View file @
43546076
...
@@ -828,6 +828,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -828,6 +828,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q
,
q
,
k
,
k
,
v
,
v
,
q_descale
=
None
,
k_descale
=
None
,
v_descale
=
None
,
return_softmax_lse
=
False
,
return_softmax_lse
=
False
,
softmax_scale
=
None
,
softmax_scale
=
None
,
**
kwargs
):
**
kwargs
):
...
@@ -850,6 +853,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -850,6 +853,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q
=
q
,
q
=
q
,
k
=
k
,
k
=
k
,
v
=
maybe_padded_v
,
v
=
maybe_padded_v
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
**
kwargs
,
)
)
...
@@ -978,19 +984,51 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -978,19 +984,51 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
dim
=-
1
)
attn_output
,
attn_softmax_lse
=
\
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_ATTN_FP8
:
self
.
_flash_attn_varlen_diff_headdims
(
q_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
q
=
q
,
k_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
k
=
k
,
v_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
v
=
v
,
descale_shape
=
(
attn_metadata
.
prefill
.
query_start_loc
.
numel
()
-
1
,
q
.
shape
[
1
])
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
cu_seqlens_k
=
prefill_metadata
.
chunked_context
.
cu_seq_lens
[
i
],
k_descale
=
k_descale
.
expand
(
descale_shape
)
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
max_seqlen_k
=
prefill_metadata
.
chunked_context
.
max_seq_lens
[
i
],
q
=
q
.
to
(
torch
.
float8_e4m3fn
)
softmax_scale
=
self
.
scale
,
k
=
k
.
to
(
torch
.
float8_e4m3fn
)
causal
=
False
,
# Context is unmasked
v
=
v
.
to
(
torch
.
float8_e4m3fn
)
return_softmax_lse
=
True
,
)
attn_output
,
attn_softmax_lse
=
\
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
chunked_context
.
cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
chunked_context
.
max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
True
,
)
else
:
attn_output
,
attn_softmax_lse
=
\
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
chunked_context
.
cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
chunked_context
.
max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
q_descale
=
None
,
k_descale
=
None
,
v_descale
=
None
,
return_softmax_lse
=
True
,
)
if
output
is
None
:
if
output
is
None
:
output
=
attn_output
output
=
attn_output
...
@@ -1043,18 +1081,49 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1043,18 +1081,49 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
else
:
else
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
output
=
self
.
_flash_attn_varlen_diff_headdims
(
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_ATTN_FP8
:
q
=
q
,
q_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
k
=
k
,
k_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
v
=
v
,
v_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
cu_seqlens_q
=
attn_metadata
.
prefill
.
query_start_loc
,
descale_shape
=
(
attn_metadata
.
prefill
.
query_start_loc
.
numel
()
-
1
,
q
.
shape
[
1
])
cu_seqlens_k
=
attn_metadata
.
prefill
.
query_start_loc
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
max_seqlen_q
=
attn_metadata
.
prefill
.
max_query_len
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
max_seqlen_k
=
attn_metadata
.
prefill
.
max_query_len
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
softmax_scale
=
self
.
scale
,
causal
=
True
,
q
=
q
.
to
(
torch
.
float8_e4m3fn
)
return_softmax_lse
=
has_context
,
k
=
k
.
to
(
torch
.
float8_e4m3fn
)
)
v
=
v
.
to
(
torch
.
float8_e4m3fn
)
output
=
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
attn_metadata
.
prefill
.
query_start_loc
,
cu_seqlens_k
=
attn_metadata
.
prefill
.
query_start_loc
,
max_seqlen_q
=
attn_metadata
.
prefill
.
max_query_len
,
max_seqlen_k
=
attn_metadata
.
prefill
.
max_query_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
has_context
,
)
else
:
output
=
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
attn_metadata
.
prefill
.
query_start_loc
,
cu_seqlens_k
=
attn_metadata
.
prefill
.
query_start_loc
,
max_seqlen_q
=
attn_metadata
.
prefill
.
max_query_len
,
max_seqlen_k
=
attn_metadata
.
prefill
.
max_query_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
q_descale
=
None
,
k_descale
=
None
,
v_descale
=
None
,
return_softmax_lse
=
has_context
,
)
if
has_context
:
if
has_context
:
suffix_output
,
suffix_lse
=
output
suffix_output
,
suffix_lse
=
output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment