Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d7bee8b6
Commit
d7bee8b6
authored
Mar 12, 2026
by
wanghl6
Browse files
feat: 元宝 prefill融合算子优化
parent
d761561a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
220 additions
and
111 deletions
+220
-111
vllm/_custom_ops.py
vllm/_custom_ops.py
+20
-7
vllm/envs.py
vllm/envs.py
+12
-1
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+1
-1
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+187
-102
No files found.
vllm/_custom_ops.py
View file @
d7bee8b6
...
...
@@ -2184,6 +2184,7 @@ def gather_cache(src_cache: torch.Tensor,
)
->
None
:
#支持"kv cache fp8" 临时方案,带dtype的gather_cache在vllm0.10后会实现。
if
kv_dtype
==
"fp8"
or
kv_dtype
==
"fp8_e5m2"
or
kv_dtype
==
"fp8_e4m3"
:
if
not
envs
.
VLLM_FUSED_GATHER_CACHE_CONVERT_FP8
:
dst_fp8
=
torch
.
empty
(
dst
.
shape
,
dtype
=
torch
.
uint8
,
device
=
dst
.
device
)
#convert_fp8(dst_fp8, dst, scale, kv_dtype)
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst_fp8
,
block_table
,
...
...
@@ -2191,6 +2192,18 @@ def gather_cache(src_cache: torch.Tensor,
#dst_fp8->bf16
# convert_fp8(dst, dst_fp8, scale, kv_dtype)
convert_fp8
(
dst
,
dst_fp8
,
1.0
,
kv_dtype
)
else
:
from
lightop
import
op
op
.
gather_convert_fp8_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
scale
,
kv_dtype
,
seq_starts
)
else
:
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
seq_starts
)
...
...
vllm/envs.py
View file @
d7bee8b6
...
...
@@ -219,7 +219,9 @@ if TYPE_CHECKING:
VLLM_ENABLE_SHARED_EXPERTS_FUSION
:
bool
=
False
VLLM_USE_MOE_W16A16_TRITON
:
bool
=
False
VLLM_USE_FUSED_DTBMM
:
bool
=
False
VLLM_FUSE_CAT_AND_CAST_FP8
:
bool
=
False
VLLM_FUSED_GATHER_CACHE_CONVERT_FP8
:
bool
=
False
VLLM_FUSED_RN_ROPE_INT8_QUANT
:
bool
=
False
def
get_default_cache_root
():
return
os
.
getenv
(
"XDG_CACHE_HOME"
,
...
...
@@ -1404,6 +1406,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_DTBMM"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FUSED_DTBMM"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
"VLLM_FUSE_CAT_AND_CAST_FP8"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_FUSE_CAT_AND_CAST_FP8"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
"VLLM_FUSED_GATHER_CACHE_CONVERT_FP8"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_FUSED_GATHER_CACHE_CONVERT_FP8"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
"VLLM_FUSED_RN_ROPE_INT8_QUANT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_FUSED_RN_ROPE_INT8_QUANT"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
d7bee8b6
...
...
@@ -469,7 +469,7 @@ def apply_int8_linear(
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
if
envs
.
USE_FUSED_RMS_QUANT
and
input_quant_args
is
not
None
:
if
(
envs
.
USE_FUSED_RMS_QUANT
or
envs
.
VLLM_FUSED_RN_ROPE_INT8_QUANT
)
and
input_quant_args
is
not
None
:
assert
len
(
input_quant_args
)
==
2
x_zp
=
None
x_q
,
x_scale
=
input_quant_args
...
...
vllm/v1/attention/backends/mla/common.py
View file @
d7bee8b6
...
...
@@ -1015,6 +1015,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
iters
=
len
(
prefill_metadata
.
chunked_context
.
seq_tot
)
workspace
=
prefill_metadata
.
chunked_context
.
workspace
use_flash_fp8_arch
=
(
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_ATTN_FP8
)
use_fused_fp8_op
=
use_flash_fp8_arch
and
envs
.
VLLM_FUSE_CAT_AND_CAST_FP8
for
i
in
range
(
iters
):
toks
=
prefill_metadata
.
chunked_context
.
seq_tot
[
i
]
...
...
@@ -1029,62 +1035,64 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale
=
kv_scale
,
)
kv_c_normed
=
workspace
[:
toks
]
\
[...,
:
self
.
kv_lora_rank
]
k_pe
=
workspace
[:
toks
]
\
[...,
self
.
kv_lora_rank
:].
unsqueeze
(
1
)
kv_c_normed
=
workspace
[:
toks
][...,
:
self
.
kv_lora_rank
]
k_pe
=
workspace
[:
toks
][...,
self
.
kv_lora_rank
:].
unsqueeze
(
1
)
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
k_nope
,
v
=
kv_nope
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
if
use_fused_fp8_op
:
from
lightop
import
op
k_pe_expanded
=
k_pe
.
expand
(
k_pe
.
shape
[
0
],
self
.
num_heads
,
k_pe
.
shape
[
-
1
])
q_attn
,
k_attn
,
v_attn
=
op
.
ds_fused_qkv_cast_fp8
(
q
,
kv_nope
,
k_pe_expanded
,
self
.
qk_nope_head_dim
,
self
.
v_head_dim
)
else
:
k_nope
,
v_nope
=
kv_nope
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
if
envs
.
VLLM_USE_OPT_CAT
:
if
k_nope
.
shape
[
0
]
>
1024
:
from
vllm.v1.attention.backends.mla.test_concat
import
lightop_concat_prefill_helper
k
=
lightop_concat_prefill_helper
(
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
)),
dim
=
2
)
k_cat
=
lightop_concat_prefill_helper
(
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
)),
dim
=
2
)
else
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
k_cat
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
else
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_ATTN_FP8
:
q_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
k_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
v_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
descale_shape
=
(
attn_metadata
.
prefill
.
query_start_loc
.
numel
()
-
1
,
q
.
shape
[
1
])
q_descale
=
q_descale
.
expand
(
descale_shape
)
k_descale
=
k_descale
.
expand
(
descale_shape
)
v_descale
=
v_descale
.
expand
(
descale_shape
)
q
=
q
.
to
(
torch
.
float8_e4m3fn
)
k
=
k
.
to
(
torch
.
float8_e4m3fn
)
v
=
v
.
to
(
torch
.
float8_e4m3fn
)
k_cat
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
attn_output
,
attn_softmax_lse
=
\
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
if
use_flash_fp8_arch
:
q_attn
=
q
.
to
(
torch
.
float8_e4m3fn
)
k_attn
=
k_cat
.
to
(
torch
.
float8_e4m3fn
)
v_attn
=
v_nope
.
to
(
torch
.
float8_e4m3fn
)
else
:
q_attn
=
q
k_attn
=
k_cat
v_attn
=
v_nope
if
use_flash_fp8_arch
:
attn_output
,
attn_softmax_lse
=
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q_attn
,
k
=
k_attn
,
v
=
v_attn
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
chunked_context
.
cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
chunked_context
.
max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
q_descale
=
None
,
k_descale
=
None
,
v_descale
=
None
,
return_softmax_lse
=
True
,
)
else
:
attn_output
,
attn_softmax_lse
=
\
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
attn_output
,
attn_softmax_lse
=
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q_attn
,
k
=
k_attn
,
v
=
v_attn
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
chunked_context
.
cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
...
...
@@ -1124,6 +1132,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
kv_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
),
kv_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
)
->
torch
.
Tensor
:
assert
attn_metadata
.
prefill
is
not
None
...
...
@@ -1132,34 +1141,54 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
else
:
has_context
=
False
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
if
kv_quant_args
is
not
None
:
kv_nope
=
self
.
kv_b_proj
.
quant_method
.
apply
(
self
.
kv_b_proj
,
kv_c_normed
,
input_quant_args
=
kv_quant_args
)
else
:
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)
if
isinstance
(
kv_nope
,
tuple
):
kv_nope
=
kv_nope
[
0
]
kv_nope
=
kv_nope
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
k_nope
,
v
=
kv_nope
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
# kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
# -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
use_flash_fp8_arch
=
(
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_ATTN_FP8
)
use_fused_fp8_op
=
use_flash_fp8_arch
and
envs
.
VLLM_FUSE_CAT_AND_CAST_FP8
if
use_fused_fp8_op
:
from
lightop
import
op
k_pe_expanded
=
k_pe
.
expand
(
k_pe
.
shape
[
0
],
self
.
num_heads
,
k_pe
.
shape
[
-
1
])
q
,
k
,
v
=
op
.
ds_fused_qkv_cast_fp8
(
q
,
kv_nope
,
k_pe_expanded
,
self
.
qk_nope_head_dim
,
self
.
v_head_dim
)
else
:
k_nope
,
v
=
kv_nope
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
if
envs
.
VLLM_USE_OPT_CAT
:
if
k_nope
.
shape
[
0
]
>
1024
:
from
vllm.v1.attention.backends.mla.test_concat
import
lightop_concat_prefill_helper
k
=
lightop_concat_prefill_helper
(
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
)),
dim
=
2
)
k
=
lightop_concat_prefill_helper
(
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
)),
dim
=
2
)
else
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
else
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_ATTN_FP8
:
q_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
k_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
v_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
descale_shape
=
(
attn_metadata
.
prefill
.
query_start_loc
.
numel
()
-
1
,
q
.
shape
[
1
])
q_descale
=
q_descale
.
expand
(
descale_shape
)
k_descale
=
k_descale
.
expand
(
descale_shape
)
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
use_flash_fp8_arch
:
q
=
q
.
to
(
torch
.
float8_e4m3fn
)
k
=
k
.
to
(
torch
.
float8_e4m3fn
)
v
=
v
.
to
(
torch
.
float8_e4m3fn
)
if
use_flash_fp8_arch
:
output
=
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
...
...
@@ -1170,9 +1199,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
max_seqlen_k
=
attn_metadata
.
prefill
.
max_query_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
q_descale
=
None
,
k_descale
=
None
,
v_descale
=
None
,
return_softmax_lse
=
has_context
,
)
else
:
...
...
@@ -1270,7 +1299,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
prefill_k_pe
=
k_pe
[
num_decode_tokens
:]
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
decode_q
=
q
[:
num_decode_tokens
]
...
...
@@ -1289,6 +1317,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
else
:
kv_cache_dtype_str
=
self
.
kv_cache_dtype
k_c_normed_int8
=
None
k_c_normed_scale
=
None
# write the latent and rope to kv cache
if
kv_cache
.
numel
()
>
0
:
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
...
...
@@ -1301,8 +1331,31 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale
=
layer
.
_k_scale
,
)
else
:
if
envs
.
VLLM_FUSED_RN_ROPE_INT8_QUANT
:
k_c_normed_int8
=
torch
.
empty
((
num_actual_toks
,
k_c_normed
.
size
(
-
1
)),
dtype
=
torch
.
int8
,
device
=
q
.
device
)
k_c_normed_scale
=
torch
.
empty
((
num_actual_toks
,
1
),
dtype
=
torch
.
float32
,
device
=
q
.
device
)
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype_str
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA
:
if
has_prefill
:
if
envs
.
VLLM_FUSED_RN_ROPE_INT8_QUANT
:
from
lightop
import
op
op
.
fused_rms_norm_rope_int8quant_contiguous
(
positions
[:
num_actual_toks
,
...],
q
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
[:
num_actual_toks
,
...],
# normed
k_c_normed_int8
,
k_c_normed_scale
,
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
else
:
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
q
,
...
...
@@ -1340,6 +1393,26 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
False
,
1e-6
,
)
else
:
if
envs
.
VLLM_FUSED_RN_ROPE_INT8_QUANT
:
from
lightop
import
op
op
.
fused_rms_norm_rope_int8quant_contiguous
(
positions
[:
num_actual_toks
,
...],
q
,
k_pe
.
squeeze
(
1
),
k_c_normed
,
# not normed
key_normed
[:
num_actual_toks
,
...],
# normed
k_c_normed_int8
,
k_c_normed_scale
,
weight
,
cos_sin_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache
,
kv_cache_dtype_str
,
1.0
,
False
,
1e-6
,
)
else
:
fused_rms_norm_rope_contiguous
(
positions
[:
num_actual_toks
,
...],
...
...
@@ -1356,14 +1429,26 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
False
,
1e-6
,
)
if
has_prefill
:
curr_kv_quant
=
None
if
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
prefill_k_c_normed
=
key_normed
[:
num_actual_toks
,
...]
prefill_k_c_normed
=
prefill_k_c_normed
[
num_decode_tokens
:]
else
:
prefill_k_c_normed
=
k_c_normed
[
num_decode_tokens
:]
if
envs
.
VLLM_FUSED_RN_ROPE_INT8_QUANT
and
prefill_k_c_normed
is
not
None
:
if
k_c_normed_int8
is
not
None
and
k_c_normed_scale
is
not
None
:
curr_kv_quant
=
[
k_c_normed_int8
[
num_decode_tokens
:],
k_c_normed_scale
[
num_decode_tokens
:]]
output
[
num_decode_tokens
:]
=
self
.
_forward_prefill
(
prefill_q
,
prefill_k_c_normed
,
prefill_k_pe
,
kv_cache
,
attn_metadata
,
kv_scale
=
layer
.
_k_scale
)
prefill_q
,
prefill_k_c_normed
,
prefill_k_pe
,
kv_cache
,
attn_metadata
,
kv_scale
=
layer
.
_k_scale
,
kv_quant_args
=
curr_kv_quant
)
if
has_decode
:
assert
attn_metadata
.
decode
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment