Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
badaff2d
"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "91b3d190ae86aeec185ec1da663c0dda5da30545"
Commit
badaff2d
authored
Mar 09, 2026
by
wanghl6
Browse files
添加dspk prefill atten前FUSE_CAT_AND_CAST_FP8
parent
004a1ef4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
53 deletions
+75
-53
vllm/envs.py
vllm/envs.py
+4
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+71
-53
No files found.
vllm/envs.py
View file @
badaff2d
...
@@ -219,6 +219,7 @@ if TYPE_CHECKING:
...
@@ -219,6 +219,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_SHARED_EXPERTS_FUSION
:
bool
=
False
VLLM_ENABLE_SHARED_EXPERTS_FUSION
:
bool
=
False
VLLM_USE_MOE_W16A16_TRITON
:
bool
=
False
VLLM_USE_MOE_W16A16_TRITON
:
bool
=
False
VLLM_USE_FUSED_DTBMM
:
bool
=
False
VLLM_USE_FUSED_DTBMM
:
bool
=
False
VLLM_FUSE_CAT_AND_CAST_FP8
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
return
os
.
getenv
(
return
os
.
getenv
(
...
@@ -1404,6 +1405,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1404,6 +1405,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_DTBMM"
:
"VLLM_USE_FUSED_DTBMM"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FUSED_DTBMM"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FUSED_DTBMM"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
"VLLM_FUSE_CAT_AND_CAST_FP8"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_FUSE_CAT_AND_CAST_FP8"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/v1/attention/backends/mla/common.py
View file @
badaff2d
...
@@ -1036,33 +1036,44 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1036,33 +1036,44 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
k_nope
,
v
=
kv_nope
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
use_flash_fp8_arch
=
(
\
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
\
if
envs
.
VLLM_USE_OPT_CAT
:
and
envs
.
VLLM_USE_FLASH_ATTN_FP8
if
k_nope
.
shape
[
0
]
>
1024
:
)
from
vllm.v1.attention.backends.mla.test_concat
import
lightop_concat_prefill_helper
use_fused_fp8_op
=
use_flash_fp8_arch
and
envs
.
VLLM_FUSE_CAT_AND_CAST_FP8
k
=
lightop_concat_prefill_helper
(
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
)),
dim
=
2
)
k_pe_expanded
=
k_pe
.
expand
(
k_pe
.
shape
[
0
],
self
.
num_heads
,
k_pe
.
shape
[
-
1
])
else
:
if
use_fused_fp8_op
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
from
lightop
import
op
dim
=-
1
)
q
,
k
,
v
=
op
.
ds_fused_qkv_cast_fp8
(
q
,
kv_nope
,
k_pe_expanded
,
self
.
qk_nope_head_dim
,
self
.
v_head_dim
)
else
:
else
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
k_nope
,
v
=
kv_nope
\
dim
=-
1
)
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_ATTN_FP8
:
if
envs
.
VLLM_USE_OPT_CAT
:
q_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
if
k_nope
.
shape
[
0
]
>
1024
:
k_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
from
vllm.v1.attention.backends.mla.test_concat
import
lightop_concat_prefill_helper
v_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
k
=
lightop_concat_prefill_helper
(
k_nope
,
k_pe_expanded
,
dim
=
2
)
descale_shape
=
(
attn_metadata
.
prefill
.
query_start_loc
.
numel
()
-
1
,
q
.
shape
[
1
])
else
:
q_descale
=
q_descale
.
expand
(
descale_shape
)
k
=
torch
.
cat
((
k_nope
,
k_pe_expanded
),
dim
=-
1
)
k_descale
=
k_descale
.
expand
(
descale_shape
)
else
:
v_descale
=
v_descale
.
expand
(
descale_shape
)
k
=
torch
.
cat
((
k_nope
,
k_pe_expanded
),
dim
=-
1
)
q
=
q
.
to
(
torch
.
float8_e4m3fn
)
k
=
k
.
to
(
torch
.
float8_e4m3fn
)
if
use_flash_fp8_arch
:
v
=
v
.
to
(
torch
.
float8_e4m3fn
)
q_descale
=
None
k_descale
=
None
v_descale
=
None
if
not
use_fused_fp8_op
:
q
=
q
.
to
(
torch
.
float8_e4m3fn
)
k
=
k
.
to
(
torch
.
float8_e4m3fn
)
v
=
v
.
to
(
torch
.
float8_e4m3fn
)
attn_output
,
attn_softmax_lse
=
\
attn_output
,
attn_softmax_lse
=
\
self
.
_flash_attn_varlen_diff_headdims
(
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
q
=
q
,
...
@@ -1134,32 +1145,41 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1134,32 +1145,41 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
k_nope
,
v
=
kv_nope
\
use_flash_fp8_arch
=
(
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
\
and
envs
.
VLLM_USE_FLASH_ATTN_FP8
if
envs
.
VLLM_USE_OPT_CAT
:
)
if
k_nope
.
shape
[
0
]
>
1024
:
use_fused_fp8_op
=
use_flash_fp8_arch
and
envs
.
VLLM_FUSE_CAT_AND_CAST_FP8
from
vllm.v1.attention.backends.mla.test_concat
import
lightop_concat_prefill_helper
k
=
lightop_concat_prefill_helper
(
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
)),
if
use_fused_fp8_op
:
dim
=
2
)
from
lightop
import
op
else
:
k_pe_expanded
=
k_pe
.
expand
(
k_pe
.
shape
[
0
],
self
.
num_heads
,
k_pe
.
shape
[
-
1
])
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
q
,
k
,
v
=
op
.
ds_fused_qkv_cast_fp8
(
dim
=-
1
)
q
,
kv_nope
,
k_pe_expanded
,
self
.
qk_nope_head_dim
,
self
.
v_head_dim
)
else
:
else
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
k_nope
,
v
=
kv_nope
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
if
envs
.
VLLM_USE_OPT_CAT
:
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_ATTN_FP8
:
if
k_nope
.
shape
[
0
]
>
1024
:
q_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
from
vllm.v1.attention.backends.mla.test_concat
import
lightop_concat_prefill_helper
k_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
k
=
lightop_concat_prefill_helper
(
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
)),
dim
=
2
)
v_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
else
:
descale_shape
=
(
attn_metadata
.
prefill
.
query_start_loc
.
numel
()
-
1
,
q
.
shape
[
1
])
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
q_descale
=
q_descale
.
expand
(
descale_shape
)
else
:
k_descale
=
k_descale
.
expand
(
descale_shape
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
use_flash_fp8_arch
:
q
=
q
.
to
(
torch
.
float8_e4m3fn
)
q_descale
=
None
k
=
k
.
to
(
torch
.
float8_e4m3fn
)
k_descale
=
None
v
=
v
.
to
(
torch
.
float8_e4m3fn
)
v_descale
=
None
if
not
use_fused_fp8_op
:
q
=
q
.
to
(
torch
.
float8_e4m3fn
)
k
=
k
.
to
(
torch
.
float8_e4m3fn
)
v
=
v
.
to
(
torch
.
float8_e4m3fn
)
output
=
self
.
_flash_attn_varlen_diff_headdims
(
output
=
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
q
=
q
,
k
=
k
,
k
=
k
,
...
@@ -1270,7 +1290,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1270,7 +1290,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
has_decode
=
attn_metadata
.
num_decodes
>
0
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
prefill_k_pe
=
k_pe
[
num_decode_tokens
:]
prefill_k_pe
=
k_pe
[
num_decode_tokens
:]
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
decode_q
=
q
[:
num_decode_tokens
]
decode_q
=
q
[:
num_decode_tokens
]
...
@@ -1356,7 +1375,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1356,7 +1375,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
False
,
False
,
1e-6
,
1e-6
,
)
)
if
has_prefill
:
if
has_prefill
:
if
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
if
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
prefill_k_c_normed
=
key_normed
[:
num_actual_toks
,
...]
prefill_k_c_normed
=
key_normed
[:
num_actual_toks
,
...]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment