Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9df8da54
Unverified
Commit
9df8da54
authored
Sep 23, 2025
by
Lucas Wilkinson
Committed by
GitHub
Sep 23, 2025
Browse files
[BugFix] Fix MLA assert with CUTLASS MLA (#25478)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
bf68fd76
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
18 deletions
+46
-18
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+46
-18
No files found.
vllm/v1/attention/backends/mla/common.py
View file @
9df8da54
...
@@ -204,7 +204,7 @@ from vllm.attention.backends.utils import get_mla_dims
...
@@ -204,7 +204,7 @@ from vllm.attention.backends.utils import get_mla_dims
from
vllm.attention.ops.common
import
cp_lse_ag_out_rs
from
vllm.attention.ops.common
import
cp_lse_ag_out_rs
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.attention.utils.fa_utils
import
get_flash_attn_version
from
vllm.attention.utils.fa_utils
import
get_flash_attn_version
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
from
vllm.distributed.parallel_state
import
get_dcp_group
,
is_global_first_rank
from
vllm.distributed.parallel_state
import
get_dcp_group
,
is_global_first_rank
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
@@ -436,6 +436,34 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -436,6 +436,34 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
"""
"""
reorder_batch_threshold
:
ClassVar
[
int
]
=
1
reorder_batch_threshold
:
ClassVar
[
int
]
=
1
@
staticmethod
def
determine_chunked_prefill_workspace_size
(
vllm_config
:
VllmConfig
)
->
int
:
scheduler_config
=
vllm_config
.
scheduler_config
cache_config
=
vllm_config
.
cache_config
model_config
=
vllm_config
.
model_config
chunked_prefill_workspace_size
=
min
(
# Try for 8 full length request or at least 4 pages per-request
max
(
8
*
model_config
.
max_model_len
,
4
*
scheduler_config
.
max_num_seqs
*
cache_config
.
block_size
),
# For long-context models try not to over-allocate limiting
# kv-cache space, limiting it to 64k tokens,
# which would result in the workspace being:
# 2*(576)*(64*1024) = 144mb
# (assuming 576 MLA head dim, and fp16)
# which would result in up-projected context being
# 2*(192*128)*(64*1024) = 3gb
# (assuming 192 QK head dim, 128 heads, and fp16)
64
*
1024
)
# Enforce that we enough for at least 1 page per request
chunked_prefill_workspace_size
=
max
(
chunked_prefill_workspace_size
,
scheduler_config
.
max_num_seqs
*
cache_config
.
block_size
)
return
chunked_prefill_workspace_size
def
__init__
(
self
,
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
layer_names
:
list
[
str
],
...
@@ -448,7 +476,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -448,7 +476,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
scheduler_config
=
vllm_config
.
scheduler_config
scheduler_config
=
vllm_config
.
scheduler_config
self
.
model_config
=
vllm_config
.
model_config
self
.
model_config
=
vllm_config
.
model_config
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
cache_config
=
vllm_config
.
cache_config
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
device
=
device
self
.
device
=
device
...
@@ -468,22 +495,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -468,22 +495,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
if
self
.
aot_schedule
:
if
self
.
aot_schedule
:
self
.
page_size
=
self
.
kv_cache_spec
.
block_size
self
.
page_size
=
self
.
kv_cache_spec
.
block_size
self
.
chunked_prefill_workspace_size
=
min
(
self
.
chunked_prefill_workspace_size
=
\
# Max sure there is enough for 8 full length request or at least
self
.
determine_chunked_prefill_workspace_size
(
vllm_config
)
# 4 pages of cache per request
max
(
8
*
self
.
model_config
.
max_model_len
,
4
*
scheduler_config
.
max_num_seqs
*
cache_config
.
block_size
),
# For long-context models try not to over-allocate limiting
# kv-cache space, limiting it to 64k tokens,
# which would result in the workspace being:
# 2*(576)*(64*1024) = 144mb
# (assuming 576 MLA head dim, and fp16)
# which would result in up-projected context being
# 2*(192*128)*(64*1024) = 3gb
# (assuming 192 QK head dim, 128 heads, and fp16)
64
*
1024
)
assert
self
.
chunked_prefill_workspace_size
>=
\
scheduler_config
.
max_num_seqs
*
cache_config
.
block_size
if
self
.
dcp_world_size
>
1
:
if
self
.
dcp_world_size
>
1
:
# Note(hc): The local kvcache is incomplete when DCP is triggered,
# Note(hc): The local kvcache is incomplete when DCP is triggered,
# an additional kvcache allgather across the DCP group is therefore
# an additional kvcache allgather across the DCP group is therefore
...
@@ -999,6 +1013,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -999,6 +1013,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self
.
dcp_world_size
:
Optional
[
int
]
=
None
self
.
dcp_world_size
:
Optional
[
int
]
=
None
self
.
chunked_prefill_workspace_size
=
\
MLACommonMetadataBuilder
.
determine_chunked_prefill_workspace_size
(
get_current_vllm_config
())
def
_flash_attn_varlen_diff_headdims
(
self
,
def
_flash_attn_varlen_diff_headdims
(
self
,
q
,
q
,
k
,
k
,
...
@@ -1513,6 +1531,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1513,6 +1531,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
" for MLACommonImpl"
)
" for MLACommonImpl"
)
if
attn_metadata
is
None
:
if
attn_metadata
is
None
:
# During the profile run try to simulate to worse case output size
# for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
# since this can be large
_
=
torch
.
empty
(
(
self
.
chunked_prefill_workspace_size
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
),
device
=
k_c_normed
.
device
,
dtype
=
k_c_normed
.
dtype
,
)
# The zero fill is required when used with DP + EP
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# to ensure all ranks within a DP group compute the
# same expert outputs.
# same expert outputs.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment