Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
18341f14
Commit
18341f14
authored
Feb 24, 2026
by
zhuwenwen
Browse files
support fuse cat + q to fp8 + mla
parent
bac201c9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
96 additions
and
24 deletions
+96
-24
vllm/attention/ops/flashmla.py
vllm/attention/ops/flashmla.py
+55
-0
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+41
-24
No files found.
vllm/attention/ops/flashmla.py
View file @
18341f14
...
...
@@ -277,6 +277,61 @@ def flash_mla_with_kvcache_fp8(
)
return
out
,
softmax_lse
def
flash_mla_with_kvcache_fp8_with_cat
(
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
cache_seqlens
:
torch
.
Tensor
,
head_dim_v
:
int
,
tile_scheduler_metadata
:
torch
.
Tensor
,
num_splits
:
torch
.
Tensor
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
descale_q
:
Optional
[
torch
.
Tensor
]
=
None
,
descale_k
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Arguments:
q_nope: (batch_size, seq_len_q, num_heads_q, 512).
q_pe: (batch_size, seq_len_q, num_heads_q, 64).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if
softmax_scale
is
None
:
softmax_scale
=
(
q_nope
.
shape
[
-
1
]
+
q_pe
.
shape
[
-
1
])
**
(
-
0.5
)
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla_fp8_with_cat
(
q_nope
,
q_pe
,
k_cache
,
None
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
,
descale_q
,
descale_k
,
)
return
out
,
softmax_lse
#
# TODO: Add fake functions
#
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
18341f14
...
...
@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe
,
get_mla_metadata
,
flash_mla_with_kvcache_fp8
,
flash_mla_with_kvcache_fp8_with_cat
,
get_mla_decoding_metadata_dense_fp8
,
is_flashmla_supported
)
from
vllm.logger
import
init_logger
...
...
@@ -181,32 +182,48 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert
attn_metadata
.
decode
is
not
None
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
q_nope
.
shape
[
0
]
<
1024
:
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
q
=
concat_helper_decode
(
q_nope
,
q_pe
,
dim
=
2
)
\
.
unsqueeze
(
1
)
if
envs
.
VLLM_USE_CAT_MLA
:
o
,
_
=
flash_mla_with_kvcache_fp8_with_cat
(
q_nope
=
q_nope
.
unsqueeze
(
1
),
q_pe
=
q_pe
.
unsqueeze
(
1
),
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
).
view
(
torch
.
float8_e4m3fn
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
attn_metadata
.
decode
.
tile_scheduler_metadata
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
descale_q
=
q_scale
,
descale_k
=
k_scale
,
)
else
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
q_nope
.
shape
[
0
]
<
1024
:
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
q
=
concat_helper_decode
(
q_nope
,
q_pe
,
dim
=
2
)
\
.
unsqueeze
(
1
)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
o
,
_
=
flash_mla_with_kvcache_fp8
(
q
=
q
.
to
(
torch
.
float8_e4m3fn
),
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
).
view
(
torch
.
float8_e4m3fn
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
attn_metadata
.
decode
.
tile_scheduler_metadata
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
descale_q
=
q_scale
,
descale_k
=
k_scale
,
)
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
o
,
_
=
flash_mla_with_kvcache_fp8
(
q
=
q
.
to
(
torch
.
float8_e4m3fn
),
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
).
view
(
torch
.
float8_e4m3fn
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
attn_metadata
.
decode
.
tile_scheduler_metadata
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
descale_q
=
q_scale
,
descale_k
=
k_scale
,
)
else
:
if
not
envs
.
VLLM_USE_CAT_MLA
or
kv_cache_dtype
==
"fp8_e4m3"
:
if
envs
.
VLLM_USE_OPT_CAT
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment