Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ad7c14d5
Commit
ad7c14d5
authored
Feb 11, 2026
by
zhuwenwen
Browse files
support fuse cat + q to fp8 + mla
parent
ab674544
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
26 deletions
+43
-26
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+1
-1
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+42
-25
No files found.
vllm/v1/attention/backends/mla/common.py
View file @
ad7c14d5
...
...
@@ -1318,7 +1318,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
False
,
1e-6
,
)
els
e
:
if
has_decod
e
:
q_tensor
=
torch
.
randn
(
q
.
shape
[
0
],
num_local_heads
,
self
.
qk_nope_head_dim
+
self
.
qk_rope_head_dim
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
q_quant
=
torch
.
empty_like
(
q_tensor
,
dtype
=
torch
.
float8_e4m3fn
,
device
=
q
.
device
)
q_scale
=
torch
.
empty
(
q
.
shape
[
0
],
dtype
=
torch
.
float32
,
device
=
q
.
device
)
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
ad7c14d5
...
...
@@ -186,7 +186,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
o
,
_
=
flash_mla_with_kvcache_fp8_with_cat
(
q_nope
=
q_nope
.
unsqueeze
(
1
),
q_pe
=
q_pe
.
unsqueeze
(
1
),
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
)
.
view
(
torch
.
float8_e4m3fn
)
,
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
...
...
@@ -199,32 +199,49 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
descale_k
=
k_scale
,
)
else
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
q_nope
.
shape
[
0
]
<
1024
:
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
q
=
concat_helper_decode
(
q_nope
,
q_pe
,
dim
=
2
)
\
.
unsqueeze
(
1
)
if
envs
.
VLLM_USE_CAT_MLA
:
o
,
_
=
flash_mla_with_kvcache_fp8_with_cat
(
q_nope
=
q_nope
.
unsqueeze
(
1
),
q_pe
=
q_pe
.
unsqueeze
(
1
),
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
).
view
(
torch
.
float8_e4m3fn
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
attn_metadata
.
decode
.
tile_scheduler_metadata
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
descale_q
=
q_scale
,
descale_k
=
k_scale
,
)
else
:
if
envs
.
VLLM_USE_OPT_CAT
:
if
q_nope
.
shape
[
0
]
<
1024
:
from
vllm.v1.attention.backends.mla.test_concat
import
concat_helper_decode
q
=
concat_helper_decode
(
q_nope
,
q_pe
,
dim
=
2
)
\
.
unsqueeze
(
1
)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
else
:
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
o
,
_
=
flash_mla_with_kvcache_fp8
(
q
=
q
.
to
(
torch
.
float8_e4m3fn
),
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
).
view
(
torch
.
float8_e4m3fn
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
attn_metadata
.
decode
.
tile_scheduler_metadata
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
descale_q
=
q_scale
,
descale_k
=
k_scale
,
)
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
o
,
_
=
flash_mla_with_kvcache_fp8
(
q
=
q
.
to
(
torch
.
float8_e4m3fn
),
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
).
view
(
torch
.
float8_e4m3fn
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
attn_metadata
.
decode
.
tile_scheduler_metadata
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
descale_q
=
q_scale
,
descale_k
=
k_scale
,
)
else
:
if
not
envs
.
VLLM_USE_CAT_MLA
or
kv_cache_dtype
==
"fp8_e4m3"
:
if
envs
.
VLLM_USE_OPT_CAT
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment