Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3bff7958
Commit
3bff7958
authored
Mar 18, 2026
by
yangql
Browse files
x接入mla_cat算子仅在nmz和kvcache-fp8情况下生效,默认关闭,开启需要export VLLM_USE_CAT_MLA=1
parent
7306fe81
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
94 additions
and
7 deletions
+94
-7
vllm/envs.py
vllm/envs.py
+6
-1
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+6
-3
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+0
-3
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+29
-0
vllm/v1/attention/ops/flashmla.py
vllm/v1/attention/ops/flashmla.py
+53
-0
No files found.
vllm/envs.py
View file @
3bff7958
...
...
@@ -296,6 +296,7 @@ if TYPE_CHECKING:
VLLM_USE_TOPK_RENORM
:
bool
=
False
VLLM_USE_FUSED_RMS_ROPE
:
bool
=
False
VLLM_USE_FUSED_FILL_RMS_CAT
:
bool
=
False
VLLM_USE_CAT_MLA
:
bool
=
False
VLLM_W8A8_BACKEND
:
int
=
3
VLLM_USE_PP_BALANCE
=
True
VLLM_MOE_ROUTER_CAPTURE
:
bool
=
False
...
...
@@ -1819,7 +1820,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vLLM will use lightop moe_align_block_size
"VLLM_USE_LIGHTOP_MOE_ALIGN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHTOP_MOE_ALIGN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vllm will use fused cat and mla
"VLLM_USE_CAT_MLA"
:
lambda
:
(
os
.
getenv
(
'VLLM_USE_CAT_MLA'
,
'False'
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use opt merge_attn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MERGE_ATTN_STATES_OPT"
,
"True"
).
lower
()
in
...
...
vllm/model_executor/layers/attention/mla_attention.py
View file @
3bff7958
...
...
@@ -2355,9 +2355,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
if
fp8_attention
and
get_gcn_arch_name
()
==
"gfx938"
:
assert
decode_ql_nope
.
shape
[
0
]
==
decode_q_pe
.
shape
[
0
]
assert
decode_ql_nope
.
shape
[
1
]
==
decode_q_pe
.
shape
[
1
]
decode_q
=
self
.
_decode_concat_quant_fp8_op
(
decode_ql_nope
,
decode_q_pe
,
layer
.
_q_scale
)
if
envs
.
VLLM_USE_CAT_MLA
:
decode_q
=
(
decode_ql_nope
,
decode_q_pe
)
else
:
decode_q
=
self
.
_decode_concat_quant_fp8_op
(
decode_ql_nope
,
decode_q_pe
,
layer
.
_q_scale
)
else
:
decode_q
=
(
decode_ql_nope
,
decode_q_pe
)
if
self
.
dcp_world_size
>
1
:
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
3bff7958
...
...
@@ -211,9 +211,6 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP):
self
.
quant_method
=
quant_config
.
get_name
()
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LM_NN'
]
=
'0'
# The AWQ layer of MTP uses BlockInt8W8A8.
if
self
.
quant_method
==
"moe_wna16"
or
self
.
quant_method
==
"awq_marlin"
:
vllm_config
.
quant_config
=
BlockInt8Config
(
is_checkpoint_int8_serialized
=
True
,
weight_block_size
=
[
128
,
128
])
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
3bff7958
...
...
@@ -28,6 +28,7 @@ from vllm.v1.attention.backend import (
MultipleOf
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms.rocm
import
get_gcn_arch_name
from
vllm.v1.attention.backends.utils
import
(
reshape_attn_output_for_spec_decode
,
reshape_query_for_spec_decode
,
...
...
@@ -39,6 +40,7 @@ from vllm.v1.attention.ops.flashmla import (
get_mla_metadata
,
get_mla_metadata_dense_fp8
,
is_flashmla_dense_supported
,
flash_mla_with_kvcache_fp8_with_cat
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm
import
envs
...
...
@@ -249,6 +251,33 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
if
isinstance
(
q
,
tuple
):
q_nope
,
q_pe
=
q
if
envs
.
VLLM_USE_CAT_MLA
and
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
and
get_gcn_arch_name
()
==
"gfx938"
:
assert
isinstance
(
q_nope
,
torch
.
Tensor
)
assert
isinstance
(
q_pe
,
torch
.
Tensor
)
num_decodes
=
attn_metadata
.
num_decodes
q_nope
=
reshape_query_for_spec_decode
(
q_nope
,
num_decodes
)
q_pe
=
reshape_query_for_spec_decode
(
q_pe
,
num_decodes
)
scheduler_metadata
=
attn_metadata
.
decode
.
scheduler_metadata
assert
q_nope
.
shape
[
0
]
==
num_decodes
o
,
lse
=
flash_mla_with_kvcache_fp8_with_cat
(
q_nope
=
q_nope
,
q_pe
=
q_pe
,
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
).
view
(
torch
.
float8_e4m3fn
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
scheduler_metadata
.
tile_scheduler_metadata
,
num_splits
=
scheduler_metadata
.
num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
descale_q
=
layer
.
_q_scale
.
reshape
(
1
),
descale_k
=
layer
.
_k_scale
.
reshape
(
1
),
)
o
=
reshape_attn_output_for_spec_decode
(
o
)
return
o
,
lse
if
envs
.
VLLM_USE_OPT_CAT
and
q_nope
.
shape
[
0
]
<
1024
:
from
vllm.v1.attention.backends.mla.test_concat
import
(
concat_helper_decode
,
...
...
vllm/v1/attention/ops/flashmla.py
View file @
3bff7958
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from
typing
import
Optional
,
Tuple
import
torch
...
...
@@ -211,6 +212,58 @@ def flash_mla_with_kvcache_fp8(
return
out
,
softmax_lse
def flash_mla_with_kvcache_fp8_with_cat(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run the FP8 KV-cache MLA forward kernel that takes the query as two
    separate tensors (`q_nope` and `q_pe`) instead of one concatenated tensor.

    This is a thin wrapper: it fills in the default softmax scale and forwards
    everything to `flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat`.

    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, 512).
        q_pe: (batch_size, seq_len_q, num_heads_q, 64).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by
            get_mla_metadata.
        softmax_scale: float. The scale of QK^T before applying softmax.
            When None, it is computed as
            (q_nope.shape[-1] + q_pe.shape[-1]) ** -0.5.
        causal: bool. Whether to apply causal attention mask.
        descale_q: (batch_size), torch.float32. Descaling factors for Q, used
            for fp8 quantization.
        descale_k: (batch_size), torch.float32. Descaling factors for K, used
            for fp8 quantization.
            NOTE(review): confirm the expected descale shapes against the
            kernel; callers may pass single-element tensors.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        # The default scale uses the combined width of both query parts.
        full_head_dim = q_nope.shape[-1] + q_pe.shape[-1]
        softmax_scale = full_head_dim ** (-0.5)

    attn_out, lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat(
        q_nope,
        q_pe,
        k_cache,
        None,  # fourth positional argument is always passed as None here
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k,
    )
    return attn_out, lse
#
# TODO: Add fake functions
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment