Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3b736e1c
Unverified
Commit
3b736e1c
authored
Oct 09, 2025
by
Ming Yang
Committed by
GitHub
Oct 09, 2025
Browse files
[Attention][DCP] Support DCP with query length > 1 (MTP) with FA3 (#25049)
Signed-off-by:
Ming Yang
<
minos.future@gmail.com
>
parent
2c1c7dfb
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
45 additions
and
13 deletions
+45
-13
cmake/external_projects/vllm_flash_attn.cmake
cmake/external_projects/vllm_flash_attn.cmake
+1
-1
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+14
-2
vllm/v1/attention/backends/mla/flashattn_mla.py
vllm/v1/attention/backends/mla/flashattn_mla.py
+7
-9
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+2
-0
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+2
-0
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+3
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+2
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+14
-1
No files found.
cmake/external_projects/vllm_flash_attn.cmake
View file @
3b736e1c
...
...
@@ -38,7 +38,7 @@ else()
FetchContent_Declare
(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG
4695e6bed5366c41e28c06cd86170166e4f43d00
GIT_TAG
8f468e7da54a8e2f98abfa7c38636aac91c0cba1
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR
${
CMAKE_BINARY_DIR
}
/vllm-flash-attn
...
...
vllm/v1/attention/backends/mla/common.py
View file @
3b736e1c
...
...
@@ -370,6 +370,7 @@ class CudnnPrefillMetadata(MLACommonPrefillMetadata):
class
MLACommonDecodeMetadata
:
block_table
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
dcp_tot_seq_lens
:
Optional
[
torch
.
Tensor
]
D
=
TypeVar
(
"D"
,
bound
=
MLACommonDecodeMetadata
)
...
...
@@ -682,10 +683,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc_cpu
:
torch
.
Tensor
,
query_start_loc_device
:
torch
.
Tensor
,
num_decode_tokens
:
int
,
dcp_tot_seq_lens_device
:
Optional
[
torch
.
Tensor
],
)
->
MLACommonDecodeMetadata
:
return
MLACommonDecodeMetadata
(
block_table
=
block_table_tensor
,
seq_lens
=
seq_lens_device
,
dcp_tot_seq_lens
=
dcp_tot_seq_lens_device
,
)
def
build_for_cudagraph_capture
(
...
...
@@ -727,6 +730,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
dcp_local_seq_lens
=
common_attn_metadata
.
dcp_local_seq_lens
query_seq_lens_cpu
=
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
]
...
...
@@ -742,7 +746,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# Note(hc): update seq_lens of decode reqs under DCP.
if
self
.
dcp_world_size
>
1
:
seq_lens
[:
num_decodes
]
=
seq_lens
[:
num_decodes
]
//
self
.
dcp_world_size
+
(
assert
dcp_local_seq_lens
is
not
None
dcp_local_seq_lens
[:
num_decodes
]
=
seq_lens
[
:
num_decodes
]
//
self
.
dcp_world_size
+
(
self
.
dcp_rank
<=
(
seq_lens
[:
num_decodes
]
-
1
)
%
self
.
dcp_world_size
)
...
...
@@ -899,10 +906,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
decode_metadata
=
self
.
_build_decode
(
block_table_tensor
=
block_table_tensor
[:
num_decodes
,
...],
seq_lens_cpu
=
seq_lens_cpu
[:
num_decodes
],
seq_lens_device
=
seq_lens
[:
num_decodes
],
seq_lens_device
=
dcp_local_seq_lens
[:
num_decodes
]
if
self
.
dcp_world_size
>
1
and
dcp_local_seq_lens
is
not
None
else
seq_lens
[:
num_decodes
],
query_start_loc_cpu
=
query_start_loc_cpu
[:
num_decodes
+
1
],
query_start_loc_device
=
query_start_loc
[:
num_decodes
+
1
],
num_decode_tokens
=
num_decode_tokens
,
dcp_tot_seq_lens_device
=
seq_lens
[:
num_decodes
]
if
self
.
dcp_world_size
>
1
else
None
,
)
attn_metadata
=
self
.
metadata_cls
(
...
...
vllm/v1/attention/backends/mla/flashattn_mla.py
View file @
3b736e1c
...
...
@@ -17,7 +17,6 @@ from vllm.attention.utils.fa_utils import (
get_flash_attn_version
,
)
from
vllm.config
import
VllmConfig
from
vllm.distributed.parallel_state
import
get_dcp_group
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
...
...
@@ -107,12 +106,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# pre-allocated during capture.
self
.
max_num_splits
=
envs
.
VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
# TODO(lucas): Until we add support for the DCP custom masking we need
# to restrict decodes to q_len == 1 when DCP is enabled.
self
.
reorder_batch_threshold
=
(
1
if
get_dcp_group
().
world_size
>
1
else
self
.
reorder_batch_threshold
)
def
_schedule_decode
(
self
,
num_reqs
,
cu_query_lens
,
max_query_len
,
seqlens
,
max_seq_len
,
causal
):
...
...
@@ -121,7 +114,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
batch_size
=
num_reqs
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_seq_len
,
num_heads_q
=
self
.
num_heads
,
num_heads_q
=
self
.
num_heads
*
self
.
dcp_world_size
,
num_heads_kv
=
1
,
headdim
=
self
.
mla_dims
.
qk_rope_head_dim
,
cache_seqlens
=
seqlens
,
...
...
@@ -142,10 +135,11 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
query_start_loc_cpu
:
torch
.
Tensor
,
query_start_loc_device
:
torch
.
Tensor
,
num_decode_tokens
:
int
,
dcp_tot_seq_lens_device
:
Optional
[
torch
.
Tensor
],
)
->
FlashAttnMLADecodeMetadata
:
query_lens_cpu
=
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
]
max_query_len
=
query_lens_cpu
.
max
().
item
()
max_seq_len
=
seq_lens_
cpu
.
max
().
item
()
max_seq_len
=
seq_lens_
device
.
max
().
item
()
scheduler_metadata
=
self
.
_schedule_decode
(
num_reqs
=
seq_lens_cpu
.
numel
(),
...
...
@@ -188,6 +182,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
max_seq_len
=
max_seq_len
,
scheduler_metadata
=
scheduler_metadata
,
max_num_splits
=
max_num_splits
,
dcp_tot_seq_lens
=
dcp_tot_seq_lens_device
,
)
...
...
@@ -289,6 +284,9 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
fa_version
=
3
,
# only version 3 is supported
scheduler_metadata
=
attn_metadata
.
decode
.
scheduler_metadata
,
num_splits
=
attn_metadata
.
decode
.
max_num_splits
,
cp_world_size
=
self
.
dcp_world_size
,
cp_rank
=
self
.
dcp_rank
,
cp_tot_seqused_k
=
attn_metadata
.
decode
.
dcp_tot_seq_lens
,
)
if
self
.
need_to_return_lse_for_decode
:
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
3b736e1c
...
...
@@ -106,6 +106,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
query_start_loc_cpu
:
torch
.
Tensor
,
query_start_loc_device
:
torch
.
Tensor
,
num_decode_tokens
:
int
,
dcp_tot_seq_lens_device
:
Optional
[
torch
.
Tensor
],
)
->
FlashMLADecodeMetadata
:
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
seq_lens_device
,
...
...
@@ -146,6 +147,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
seq_lens
=
seq_lens_device
,
tile_scheduler_metadata
=
tile_scheduler_metadata
,
num_splits
=
num_splits
,
dcp_tot_seq_lens
=
dcp_tot_seq_lens_device
,
)
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
3b736e1c
...
...
@@ -116,6 +116,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
query_start_loc_cpu
:
torch
.
Tensor
,
query_start_loc_device
:
torch
.
Tensor
,
num_decode_tokens
:
int
,
dcp_tot_seq_lens_device
:
Optional
[
torch
.
Tensor
],
)
->
AiterMLADecodeMetadata
:
page_size
=
self
.
kv_cache_spec
.
block_size
block_table_bounds
=
(
seq_lens_device
+
page_size
-
1
)
//
page_size
...
...
@@ -174,6 +175,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
paged_kv_indices
=
paged_kv_indices
,
paged_kv_last_page_len
=
paged_kv_last_page_len
,
qo_indptr
=
qo_indptr
,
dcp_tot_seq_lens
=
dcp_tot_seq_lens_device
,
)
return
attn_metadata
...
...
vllm/v1/attention/backends/utils.py
View file @
3b736e1c
...
...
@@ -93,6 +93,9 @@ class CommonAttentionMetadata:
# Needed by CrossAttentionBuilder
encoder_seq_lens
:
Optional
[
np
.
ndarray
]
=
None
dcp_local_seq_lens
:
Optional
[
torch
.
Tensor
]
=
None
"""Sequence lengths of the local rank in decode context parallelism world"""
def
slice_query_start_locs
(
query_start_loc
:
torch
.
Tensor
,
...
...
vllm/v1/spec_decode/eagle.py
View file @
3b736e1c
...
...
@@ -597,6 +597,7 @@ class EagleProposer:
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
[
token_indices
],
causal
=
True
,
dcp_local_seq_lens
=
common_attn_metadata
.
dcp_local_seq_lens
,
)
token_indices_to_sample
=
(
...
...
@@ -868,6 +869,7 @@ class EagleProposer:
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
[
token_indices
],
causal
=
True
,
dcp_local_seq_lens
=
common_attn_metadata
.
dcp_local_seq_lens
,
)
return
spec_common_attn_metadata
,
token_indices
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
3b736e1c
...
...
@@ -398,6 +398,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
max_num_reqs
+
1
,
dtype
=
torch
.
int32
)
self
.
seq_lens
=
self
.
_make_buffer
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
if
self
.
dcp_world_size
>
1
:
self
.
dcp_local_seq_lens
=
self
.
_make_buffer
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
# Because inputs_embeds may be bfloat16 and we don't need a numpy
# version of this tensor, avoid a RuntimeError by not creating a
# numpy buffer.
...
...
@@ -581,7 +585,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# NOTE(lucas): currently no backend supports the custom masking
# required for DCP with q_len > 1, so we assert here. Remove this
# assert once the custom mask is support is added to FA3.
if
self
.
dcp_world_size
>
1
:
if
(
self
.
dcp_world_size
>
1
and
envs
.
VLLM_ATTENTION_BACKEND
!=
"FLASH_ATTN_MLA"
):
assert
self
.
reorder_batch_threshold
==
1
,
(
"DCP not support reorder_batch_threshold > 1 now."
)
...
...
@@ -1335,6 +1342,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_logits_indices
=
logits_indices
.
size
(
0
),
causal
=
True
,
encoder_seq_lens
=
encoder_seq_lens
,
dcp_local_seq_lens
=
self
.
dcp_local_seq_lens
.
gpu
[:
num_reqs
]
if
self
.
dcp_world_size
>
1
else
None
,
)
if
self
.
speculative_config
and
spec_decode_common_attn_metadata
is
None
:
...
...
@@ -3310,6 +3320,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_group_id
].
slot_mapping
.
gpu
[:
num_tokens
],
causal
=
True
,
dcp_local_seq_lens
=
self
.
dcp_local_seq_lens
.
gpu
[:
num_reqs
]
if
self
.
dcp_world_size
>
1
else
None
,
)
for
attn_group
in
self
.
attn_groups
[
kv_cache_group_id
]:
if
ubatch_slices
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment