Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ab33d2a6
Unverified
Commit
ab33d2a6
authored
Feb 17, 2026
by
Wentao Ye
Committed by
GitHub
Feb 17, 2026
Browse files
[Feature] Decode Context Parallel support for GPU model runner v2 (#34179)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
be3af2d2
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
117 additions
and
3 deletions
+117
-3
vllm/v1/worker/gpu/attn_utils.py
vllm/v1/worker/gpu/attn_utils.py
+28
-0
vllm/v1/worker/gpu/block_table.py
vllm/v1/worker/gpu/block_table.py
+43
-3
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+20
-0
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+4
-0
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+22
-0
No files found.
vllm/v1/worker/gpu/attn_utils.py
View file @
ab33d2a6
...
...
@@ -12,6 +12,7 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
)
from
vllm.v1.attention.backends.utils
import
get_dcp_local_seq_lens
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
KVCacheConfig
,
...
...
@@ -143,6 +144,28 @@ def build_slot_mappings_by_layer(
return
slot_mappings_by_layer
def prepare_dcp_local_seq_lens(
    dcp_local_seq_lens: torch.Tensor,
    seq_lens: torch.Tensor,
    num_reqs: int,
    dcp_size: int,
    dcp_rank: int,
    cp_kv_cache_interleave_size: int,
) -> None:
    """Fill the persistent DCP local seq_lens buffer in place (CUDA graph safe).

    When ``dcp_size`` is 1 or less the buffer is left untouched. Otherwise the
    first ``num_reqs`` entries of ``dcp_local_seq_lens`` are overwritten with
    the value returned by ``get_dcp_local_seq_lens`` for ``seq_lens[:num_reqs]``,
    and every entry from index ``num_reqs`` onward is reset to zero.

    Args:
        dcp_local_seq_lens: Destination buffer; mutated in place, never rebound.
        seq_lens: Source sequence lengths; only the first ``num_reqs`` are read.
        num_reqs: Number of leading entries to populate.
        dcp_size: Decode-context-parallel world size.
        dcp_rank: Rank forwarded to ``get_dcp_local_seq_lens``.
        cp_kv_cache_interleave_size: Interleave size forwarded to
            ``get_dcp_local_seq_lens``.
    """
    if dcp_size > 1:
        # Views into the caller's tensors: writes below land in the
        # persistent buffer itself rather than in a fresh allocation.
        active_seq_lens = seq_lens[:num_reqs]
        live_rows = dcp_local_seq_lens[:num_reqs]
        padding_rows = dcp_local_seq_lens[num_reqs:]

        per_rank_lens = get_dcp_local_seq_lens(
            active_seq_lens,
            dcp_size=dcp_size,
            dcp_rank=dcp_rank,
            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
        )
        live_rows.copy_(per_rank_lens, non_blocking=True)
        # Clear the tail so entries past num_reqs never hold stale lengths.
        padding_rows.zero_()
def
build_attn_metadata
(
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
num_reqs
:
int
,
...
...
@@ -155,9 +178,13 @@ def build_attn_metadata(
block_tables
:
Sequence
[
torch
.
Tensor
],
slot_mappings
:
torch
.
Tensor
,
kv_cache_config
:
KVCacheConfig
,
dcp_local_seq_lens
:
torch
.
Tensor
|
None
=
None
,
)
->
dict
[
str
,
Any
]:
seq_lens
=
seq_lens
[:
num_reqs
]
if
dcp_local_seq_lens
is
not
None
:
dcp_local_seq_lens
=
dcp_local_seq_lens
[:
num_reqs
]
attn_metadata
:
dict
[
str
,
Any
]
=
{}
kv_cache_groups
=
kv_cache_config
.
kv_cache_groups
for
i
,
kv_cache_spec
in
enumerate
(
kv_cache_groups
):
...
...
@@ -175,6 +202,7 @@ def build_attn_metadata(
block_table_tensor
=
block_table
,
slot_mapping
=
slot_mapping
,
causal
=
True
,
dcp_local_seq_lens
=
dcp_local_seq_lens
,
)
attn_metadata_builder
=
attn_metadata_builders
[
i
]
...
...
vllm/v1/worker/gpu/block_table.py
View file @
ab33d2a6
...
...
@@ -4,6 +4,7 @@ from collections.abc import Iterable
import
torch
from
vllm.distributed
import
get_dcp_group
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
PAD_SLOT_ID
...
...
@@ -18,19 +19,36 @@ class BlockTables:
max_num_batched_tokens
:
int
,
max_model_len
:
int
,
device
:
torch
.
device
,
cp_kv_cache_interleave_size
:
int
=
1
,
):
self
.
block_sizes
=
block_sizes
self
.
max_num_reqs
=
max_num_reqs
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
max_model_len
=
max_model_len
self
.
device
=
device
assert
cp_kv_cache_interleave_size
>=
1
self
.
cp_kv_cache_interleave_size
=
cp_kv_cache_interleave_size
try
:
dcp
=
get_dcp_group
()
self
.
dcp_world_size
,
self
.
dcp_rank
=
dcp
.
world_size
,
dcp
.
rank_in_group
except
AssertionError
:
self
.
dcp_world_size
,
self
.
dcp_rank
=
1
,
0
# TODO(wentao): PCP support
self
.
total_cp_world_size
=
self
.
dcp_world_size
self
.
total_cp_rank
=
self
.
dcp_rank
self
.
num_kv_cache_groups
=
len
(
self
.
block_sizes
)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self
.
block_tables
:
list
[
StagedWriteTensor
]
=
[]
for
i
in
range
(
self
.
num_kv_cache_groups
):
block_size
=
self
.
block_sizes
[
i
]
max_num_blocks
=
cdiv
(
self
.
max_model_len
,
block_size
)
# with DCP, a request's KV is sharded across
# ranks, so one physical block on this rank
# corresponds to `block_size * total_cp_world_size`
# tokens in the global (unsharded) sequence.
virtual_block_size
=
block_size
*
self
.
total_cp_world_size
max_num_blocks
=
cdiv
(
self
.
max_model_len
,
virtual_block_size
)
block_table
=
StagedWriteTensor
(
(
self
.
max_num_reqs
,
max_num_blocks
),
dtype
=
torch
.
int32
,
...
...
@@ -131,6 +149,9 @@ class BlockTables:
self
.
block_sizes_tensor
,
self
.
slot_mappings
,
self
.
slot_mappings
.
stride
(
0
),
TOTAL_CP_WORLD_SIZE
=
self
.
total_cp_world_size
,
TOTAL_CP_RANK
=
self
.
total_cp_rank
,
CP_KV_CACHE_INTERLEAVE_SIZE
=
self
.
cp_kv_cache_interleave_size
,
PAD_ID
=
PAD_SLOT_ID
,
TRITON_BLOCK_SIZE
=
1024
,
# type: ignore
)
...
...
@@ -183,6 +204,9 @@ def _compute_slot_mappings_kernel(
block_sizes
,
# [num_kv_cache_groups]
slot_mappings_ptr
,
# [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride
,
TOTAL_CP_WORLD_SIZE
:
tl
.
constexpr
,
TOTAL_CP_RANK
:
tl
.
constexpr
,
CP_KV_CACHE_INTERLEAVE_SIZE
:
tl
.
constexpr
,
PAD_ID
:
tl
.
constexpr
,
TRITON_BLOCK_SIZE
:
tl
.
constexpr
,
):
...
...
@@ -201,6 +225,7 @@ def _compute_slot_mappings_kernel(
block_table_ptr
=
_load_ptr
(
block_table_ptrs
+
group_id
,
tl
.
int32
)
block_table_stride
=
tl
.
load
(
block_table_strides
+
group_id
)
block_size
=
tl
.
load
(
block_sizes
+
group_id
)
virtual_block_size
=
block_size
*
TOTAL_CP_WORLD_SIZE
req_state_idx
=
tl
.
load
(
idx_mapping
+
batch_idx
)
start_idx
=
tl
.
load
(
query_start_loc
+
batch_idx
)
...
...
@@ -208,11 +233,26 @@ def _compute_slot_mappings_kernel(
for
i
in
range
(
start_idx
,
end_idx
,
TRITON_BLOCK_SIZE
):
offset
=
i
+
tl
.
arange
(
0
,
TRITON_BLOCK_SIZE
)
positions
=
tl
.
load
(
pos
+
offset
,
mask
=
offset
<
end_idx
,
other
=
0
)
block_indices
=
positions
//
block_size
block_indices
=
positions
//
virtual_
block_size
block_numbers
=
tl
.
load
(
block_table_ptr
+
req_state_idx
*
block_table_stride
+
block_indices
)
slot_ids
=
block_numbers
*
block_size
+
positions
%
block_size
virtual_block_offsets
=
positions
-
block_indices
*
virtual_block_size
# determine whether the token is stored on this CP rank.
is_local
=
(
virtual_block_offsets
//
CP_KV_CACHE_INTERLEAVE_SIZE
)
%
TOTAL_CP_WORLD_SIZE
==
TOTAL_CP_RANK
# map virtual block offsets to local block offsets.
local_block_offsets
=
(
virtual_block_offsets
//
(
TOTAL_CP_WORLD_SIZE
*
CP_KV_CACHE_INTERLEAVE_SIZE
)
)
*
CP_KV_CACHE_INTERLEAVE_SIZE
+
(
virtual_block_offsets
%
CP_KV_CACHE_INTERLEAVE_SIZE
)
# physical slot index
slot_ids
=
block_numbers
*
block_size
+
local_block_offsets
slot_ids
=
tl
.
where
(
is_local
,
slot_ids
,
PAD_ID
)
tl
.
store
(
slot_mapping_ptr
+
offset
,
slot_ids
,
mask
=
offset
<
end_idx
)
...
...
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
ab33d2a6
...
...
@@ -10,6 +10,7 @@ from tqdm import tqdm
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.distributed
import
get_dcp_group
from
vllm.distributed.parallel_state
import
graph_capture
,
is_global_first_rank
from
vllm.forward_context
import
set_forward_context
from
vllm.v1.attention.backend
import
AttentionMetadataBuilder
...
...
@@ -17,6 +18,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
(
build_attn_metadata
,
build_slot_mappings_by_layer
,
prepare_dcp_local_seq_lens
,
)
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.dp_utils
import
make_num_tokens_across_dp
...
...
@@ -257,6 +259,23 @@ def prepare_inputs_to_capture(
input_buffers
.
seq_lens
[:
num_reqs
]
=
num_tokens
input_buffers
.
seq_lens
[
num_reqs
:]
=
0
try
:
dcp_group
=
get_dcp_group
()
dcp_world_size
=
dcp_group
.
world_size
dcp_rank
=
dcp_group
.
rank_in_group
except
AssertionError
:
dcp_world_size
=
1
dcp_rank
=
0
if
dcp_world_size
>
1
:
prepare_dcp_local_seq_lens
(
input_buffers
.
dcp_local_seq_lens
,
input_buffers
.
seq_lens
,
num_reqs
,
dcp_size
=
dcp_world_size
,
dcp_rank
=
dcp_rank
,
cp_kv_cache_interleave_size
=
block_tables
.
cp_kv_cache_interleave_size
,
)
input_block_tables
=
[
x
[:
num_reqs
]
for
x
in
block_tables
.
input_block_tables
]
slot_mappings
=
block_tables
.
slot_mappings
[:,
:
num_tokens
]
slot_mappings_by_layer
=
build_slot_mappings_by_layer
(
...
...
@@ -275,5 +294,6 @@ def prepare_inputs_to_capture(
block_tables
=
input_block_tables
,
slot_mappings
=
slot_mappings
,
kv_cache_config
=
kv_cache_config
,
dcp_local_seq_lens
=
input_buffers
.
dcp_local_seq_lens
,
)
return
attn_metadata
,
slot_mappings_by_layer
vllm/v1/worker/gpu/input_batch.py
View file @
ab33d2a6
...
...
@@ -27,6 +27,10 @@ class InputBuffers:
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
seq_lens
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
# DCP: per-request local seq_lens buffer
self
.
dcp_local_seq_lens
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
@
dataclass
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
ab33d2a6
...
...
@@ -11,6 +11,7 @@ import torch.nn as nn
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.distributed.parallel_state
import
(
get_dcp_group
,
get_pp_group
,
prepare_communication_buffer_for_model
,
)
...
...
@@ -24,6 +25,7 @@ from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.outputs
import
DraftTokenIds
,
ModelRunnerOutput
from
vllm.v1.worker.cp_utils
import
check_attention_cp_compatibility
from
vllm.v1.worker.gpu.async_utils
import
AsyncOutput
from
vllm.v1.worker.gpu.attn_utils
import
(
build_attn_metadata
,
...
...
@@ -31,6 +33,7 @@ from vllm.v1.worker.gpu.attn_utils import (
get_kv_cache_spec
,
init_attn_backend
,
init_kv_cache
,
prepare_dcp_local_seq_lens
,
)
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.buffer_utils
import
async_copy_to_gpu
...
...
@@ -248,11 +251,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_batched_tokens
=
self
.
max_num_tokens
,
max_model_len
=
self
.
max_model_len
,
device
=
self
.
device
,
cp_kv_cache_interleave_size
=
(
self
.
parallel_config
.
cp_kv_cache_interleave_size
),
)
self
.
attn_backends
,
self
.
attn_metadata_builders
=
init_attn_backend
(
self
.
kv_cache_config
,
self
.
vllm_config
,
self
.
device
)
check_attention_cp_compatibility
(
self
.
vllm_config
)
if
self
.
do_spec_decode
:
# HACK(woosuk)
self
.
speculator
.
set_attn
(
...
...
@@ -294,6 +301,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_tables
=
block_tables
,
slot_mappings
=
slot_mappings
,
kv_cache_config
=
self
.
kv_cache_config
,
dcp_local_seq_lens
=
self
.
input_buffers
.
dcp_local_seq_lens
,
)
input_batch
.
attn_metadata
=
attn_metadata
input_batch
.
slot_mappings
=
slot_mappings_by_layer
...
...
@@ -627,6 +635,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
]
dcp_size
=
self
.
parallel_config
.
decode_context_parallel_size
if
dcp_size
>
1
:
prepare_dcp_local_seq_lens
(
self
.
input_buffers
.
dcp_local_seq_lens
,
seq_lens
,
num_reqs
,
dcp_size
=
dcp_size
,
dcp_rank
=
get_dcp_group
().
rank_in_group
,
cp_kv_cache_interleave_size
=
(
self
.
parallel_config
.
cp_kv_cache_interleave_size
),
)
# Prepare M-RoPE positions.
if
self
.
uses_mrope
:
self
.
mrope_states
.
prepare_mrope_positions
(
...
...
@@ -674,6 +695,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_tables
=
block_tables
,
slot_mappings
=
slot_mappings
,
kv_cache_config
=
self
.
kv_cache_config
,
dcp_local_seq_lens
=
self
.
input_buffers
.
dcp_local_seq_lens
,
)
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens_after_padding
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment