Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
48fc8b1e
Unverified
Commit
48fc8b1e
authored
Nov 19, 2025
by
Lucas Wilkinson
Committed by
GitHub
Nov 19, 2025
Browse files
[BugFix] Fix async-scheduling + FlashAttn MLA (#28990)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
1ffe934c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
10 deletions
+18
-10
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+9
-6
vllm/v1/attention/backends/mla/flashattn_mla.py
vllm/v1/attention/backends/mla/flashattn_mla.py
+1
-1
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+1
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+7
-3
No files found.
vllm/v1/attention/backends/mla/common.py
View file @
48fc8b1e
...
@@ -755,6 +755,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -755,6 +755,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
dcp_local_seq_lens
=
common_attn_metadata
.
dcp_local_seq_lens
dcp_local_seq_lens
=
common_attn_metadata
.
dcp_local_seq_lens
dcp_local_seq_lens_cpu
=
common_attn_metadata
.
dcp_local_seq_lens_cpu
query_seq_lens_cpu
=
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
]
query_seq_lens_cpu
=
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
]
...
@@ -944,18 +945,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -944,18 +945,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
decode_metadata
=
None
decode_metadata
=
None
if
num_decodes
>
0
:
if
num_decodes
>
0
:
dcp_tot_seq_lens_device
=
None
if
self
.
dcp_world_size
>
1
:
dcp_tot_seq_lens_device
=
seq_lens
[:
num_decodes
]
seq_lens_cpu
=
dcp_local_seq_lens_cpu
seq_lens
=
dcp_local_seq_lens
decode_metadata
=
self
.
_build_decode
(
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
block_table_tensor
[:
num_decodes
,
...],
block_table_tensor
=
block_table_tensor
[:
num_decodes
,
...],
seq_lens_cpu
=
seq_lens_cpu
[:
num_decodes
],
seq_lens_cpu
=
seq_lens_cpu
[:
num_decodes
],
seq_lens_device
=
dcp_local_seq_lens
[:
num_decodes
]
seq_lens_device
=
seq_lens
[:
num_decodes
],
if
self
.
dcp_world_size
>
1
and
dcp_local_seq_lens
is
not
None
else
seq_lens
[:
num_decodes
],
query_start_loc_cpu
=
query_start_loc_cpu
[:
num_decodes
+
1
],
query_start_loc_cpu
=
query_start_loc_cpu
[:
num_decodes
+
1
],
query_start_loc_device
=
query_start_loc
[:
num_decodes
+
1
],
query_start_loc_device
=
query_start_loc
[:
num_decodes
+
1
],
num_decode_tokens
=
num_decode_tokens
,
num_decode_tokens
=
num_decode_tokens
,
dcp_tot_seq_lens_device
=
seq_lens
[:
num_decodes
]
dcp_tot_seq_lens_device
=
dcp_tot_seq_lens_device
,
if
self
.
dcp_world_size
>
1
else
None
,
)
)
attn_metadata
=
self
.
metadata_cls
(
attn_metadata
=
self
.
metadata_cls
(
...
...
vllm/v1/attention/backends/mla/flashattn_mla.py
View file @
48fc8b1e
...
@@ -173,7 +173,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
...
@@ -173,7 +173,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
)
->
FlashAttnMLADecodeMetadata
:
)
->
FlashAttnMLADecodeMetadata
:
query_lens_cpu
=
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
]
query_lens_cpu
=
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
]
max_query_len
=
query_lens_cpu
.
max
().
item
()
max_query_len
=
query_lens_cpu
.
max
().
item
()
max_seq_len
=
seq_lens_
device
.
max
().
item
()
max_seq_len
=
seq_lens_
cpu
.
max
().
item
()
# For Flash Attention MLA + full cudagraph
# For Flash Attention MLA + full cudagraph
max_num_splits
=
0
max_num_splits
=
0
...
...
vllm/v1/attention/backends/utils.py
View file @
48fc8b1e
...
@@ -92,6 +92,7 @@ class CommonAttentionMetadata:
...
@@ -92,6 +92,7 @@ class CommonAttentionMetadata:
encoder_seq_lens
:
np
.
ndarray
|
None
=
None
encoder_seq_lens
:
np
.
ndarray
|
None
=
None
dcp_local_seq_lens
:
torch
.
Tensor
|
None
=
None
dcp_local_seq_lens
:
torch
.
Tensor
|
None
=
None
dcp_local_seq_lens_cpu
:
torch
.
Tensor
|
None
=
None
"""Sequence lengths of the local rank in decode context parallelism world"""
"""Sequence lengths of the local rank in decode context parallelism world"""
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
48fc8b1e
...
@@ -1451,9 +1451,12 @@ class GPUModelRunner(
...
@@ -1451,9 +1451,12 @@ class GPUModelRunner(
num_computed_tokens_cpu
=
self
.
input_batch
.
num_computed_tokens_cpu_tensor
[
num_computed_tokens_cpu
=
self
.
input_batch
.
num_computed_tokens_cpu_tensor
[
:
num_reqs
:
num_reqs
]
]
dcp_local_seq_lens
=
(
self
.
dcp_local_seq_lens
.
gpu
[:
num_reqs
]
if
self
.
dcp_world_size
>
1
else
None
dcp_local_seq_lens
,
dcp_local_seq_lens_cpu
=
None
,
None
)
if
self
.
dcp_world_size
>
1
:
dcp_local_seq_lens
=
self
.
dcp_local_seq_lens
.
gpu
[:
num_reqs
]
dcp_local_seq_lens_cpu
=
self
.
dcp_local_seq_lens
.
cpu
[:
num_reqs
]
spec_decode_common_attn_metadata
=
None
spec_decode_common_attn_metadata
=
None
if
for_cudagraph_capture
:
if
for_cudagraph_capture
:
...
@@ -1521,6 +1524,7 @@ class GPUModelRunner(
...
@@ -1521,6 +1524,7 @@ class GPUModelRunner(
causal
=
True
,
causal
=
True
,
encoder_seq_lens
=
encoder_seq_lens
,
encoder_seq_lens
=
encoder_seq_lens
,
dcp_local_seq_lens
=
dcp_local_seq_lens
,
dcp_local_seq_lens
=
dcp_local_seq_lens
,
dcp_local_seq_lens_cpu
=
dcp_local_seq_lens_cpu
,
)
)
if
self
.
speculative_config
and
spec_decode_common_attn_metadata
is
None
:
if
self
.
speculative_config
and
spec_decode_common_attn_metadata
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment