Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bd8da29a
Unverified
Commit
bd8da29a
authored
Feb 03, 2026
by
Matthew Bonanni
Committed by
GitHub
Feb 03, 2026
Browse files
[Bugfix] Fix sparse MLA metadata building (#33579)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
2a99c5a6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
31 deletions
+22
-31
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+22
-31
No files found.
vllm/model_executor/layers/attention/mla_attention.py
View file @
bd8da29a
...
...
@@ -522,22 +522,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
k_c_normed
=
k_c_normed
[:
num_actual_toks
,
...]
k_pe
=
k_pe
[:
num_actual_toks
,
...]
assert
(
attn_metadata
.
num_decodes
is
not
None
and
attn_metadata
.
num_prefills
is
not
None
and
attn_metadata
.
num_decode_tokens
is
not
None
)
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
decode_q
=
q
[:
num_decode_tokens
]
prefill_q
=
q
[
num_decode_tokens
:]
prefill_k_pe
=
k_pe
[
num_decode_tokens
:]
prefill_k_c_normed
=
k_c_normed
[
num_decode_tokens
:]
# write the latent and rope to kv cache
if
kv_cache
.
numel
()
>
0
:
ops
.
concat_and_cache_mla
(
...
...
@@ -555,27 +539,32 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# Sparse MLA impls only support forward_mqa (decode-style attention)
is_sparse_impl
=
isinstance
(
self
.
impl
,
SparseMLAAttentionImpl
)
if
has_prefill
and
not
is_sparse_impl
:
if
is_sparse_impl
:
num_mqa_tokens
=
q
.
size
(
0
)
num_mha_tokens
=
0
else
:
assert
(
attn_metadata
.
num_decodes
is
not
None
and
attn_metadata
.
num_prefills
is
not
None
and
attn_metadata
.
num_decode_tokens
is
not
None
)
num_mqa_tokens
=
attn_metadata
.
num_decode_tokens
num_mha_tokens
=
q
.
size
(
0
)
-
num_mqa_tokens
if
num_mha_tokens
>
0
:
self
.
impl
.
forward_mha
(
prefill_q
,
prefill_
k_c_normed
,
prefill_k_pe
,
q
[
num_mqa_tokens
:]
,
k_c_normed
[
num_mqa_tokens
:]
,
k_pe
[
num_mqa_tokens
:]
,
kv_cache
,
attn_metadata
,
self
.
_k_scale
,
output
=
output
[
num_
decode
_tokens
:],
output
=
output
[
num_
mqa
_tokens
:],
)
if
has_decode
or
(
has_prefill
and
is_sparse_impl
):
# For sparse impl, we always use forward_mqa for all tokens
# For non-sparse impl, we only use forward_mqa for decode tokens
if
is_sparse_impl
:
mqa_q
=
q
mqa_output_slice
=
output
else
:
assert
attn_metadata
.
decode
is
not
None
mqa_q
=
decode_q
mqa_output_slice
=
output
[:
num_decode_tokens
]
if
num_mqa_tokens
>
0
:
mqa_q
=
q
[:
num_mqa_tokens
]
mqa_output_slice
=
output
[:
num_mqa_tokens
]
mqa_q_nope
,
mqa_q_pe
=
mqa_q
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
...
...
@@ -644,6 +633,8 @@ class MLAAttention(nn.Module, AttentionLayerBase):
mqa_q
=
get_dcp_group
().
all_gather
(
mqa_q
,
dim
=
1
)
# call decode attn
if
not
is_sparse_impl
:
assert
attn_metadata
.
decode
is
not
None
attn_out
,
lse
=
self
.
impl
.
forward_mqa
(
mqa_q
,
kv_cache
,
attn_metadata
,
self
)
# correct dcp attn_out with lse.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment