Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
95be2a7f
Unverified
Commit
95be2a7f
authored
Feb 18, 2026
by
Woosuk Kwon
Committed by
GitHub
Feb 18, 2026
Browse files
[Model Runner V2] Minor simplification for DCP (#34786)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
0e60c925
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
111 additions
and
95 deletions
+111
-95
vllm/v1/worker/gpu/attn_utils.py
vllm/v1/worker/gpu/attn_utils.py
+0
-24
vllm/v1/worker/gpu/block_table.py
vllm/v1/worker/gpu/block_table.py
+30
-40
vllm/v1/worker/gpu/cp_utils.py
vllm/v1/worker/gpu/cp_utils.py
+61
-0
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+2
-18
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+18
-13
No files found.
vllm/v1/worker/gpu/attn_utils.py
View file @
95be2a7f
...
...
@@ -12,7 +12,6 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
)
from
vllm.v1.attention.backends.utils
import
get_dcp_local_seq_lens
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
KVCacheConfig
,
...
...
@@ -144,28 +143,6 @@ def build_slot_mappings_by_layer(
return
slot_mappings_by_layer
def
prepare_dcp_local_seq_lens
(
dcp_local_seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
num_reqs
:
int
,
dcp_size
:
int
,
dcp_rank
:
int
,
cp_kv_cache_interleave_size
:
int
,
)
->
None
:
"""Populate the persistent DCP local seq_lens buffer (CUDA graph safe)."""
if
dcp_size
<=
1
:
return
local_seq_lens
=
get_dcp_local_seq_lens
(
seq_lens
[:
num_reqs
],
dcp_size
=
dcp_size
,
dcp_rank
=
dcp_rank
,
cp_kv_cache_interleave_size
=
cp_kv_cache_interleave_size
,
)
dcp_local_seq_lens
[:
num_reqs
].
copy_
(
local_seq_lens
,
non_blocking
=
True
)
dcp_local_seq_lens
[
num_reqs
:].
zero_
()
def
build_attn_metadata
(
attn_metadata_builders
:
list
[
AttentionMetadataBuilder
],
num_reqs
:
int
,
...
...
@@ -181,7 +158,6 @@ def build_attn_metadata(
dcp_local_seq_lens
:
torch
.
Tensor
|
None
=
None
,
)
->
dict
[
str
,
Any
]:
seq_lens
=
seq_lens
[:
num_reqs
]
if
dcp_local_seq_lens
is
not
None
:
dcp_local_seq_lens
=
dcp_local_seq_lens
[:
num_reqs
]
...
...
vllm/v1/worker/gpu/block_table.py
View file @
95be2a7f
...
...
@@ -4,7 +4,6 @@ from collections.abc import Iterable
import
torch
from
vllm.distributed
import
get_dcp_group
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
PAD_SLOT_ID
...
...
@@ -19,36 +18,29 @@ class BlockTables:
max_num_batched_tokens
:
int
,
max_model_len
:
int
,
device
:
torch
.
device
,
cp_kv_cache_interleave_size
:
int
=
1
,
cp_size
:
int
=
1
,
cp_rank
:
int
=
0
,
cp_interleave
:
int
=
1
,
):
self
.
block_sizes
=
block_sizes
self
.
max_num_reqs
=
max_num_reqs
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
max_model_len
=
max_model_len
self
.
device
=
device
assert
cp_kv_cache_interleave_size
>=
1
self
.
cp_kv_cache_interleave_size
=
cp_kv_cache_interleave_size
try
:
dcp
=
get_dcp_group
()
self
.
dcp_world_size
,
self
.
dcp_rank
=
dcp
.
world_size
,
dcp
.
rank_in_group
except
AssertionError
:
self
.
dcp_world_size
,
self
.
dcp_rank
=
1
,
0
# TODO(wentao): PCP supprot
self
.
total_cp_world_size
=
self
.
dcp_world_size
self
.
total_cp_rank
=
self
.
dcp_rank
self
.
cp_size
=
cp_size
self
.
cp_rank
=
cp_rank
self
.
cp_interleave
=
cp_interleave
self
.
num_kv_cache_groups
=
len
(
self
.
block_sizes
)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self
.
block_tables
:
list
[
StagedWriteTensor
]
=
[]
for
i
in
range
(
self
.
num_kv_cache_groups
):
block_size
=
self
.
block_sizes
[
i
]
# with DCP, a request's KV is sharded across
# ranks, so one physical block on this rank
# corresponds to `block_size * total_cp_world_size`
# tokens in the global (unsharded) sequence.
virtual_block_size
=
block_size
*
self
.
total_cp_world_size
max_num_blocks
=
cdiv
(
self
.
max_model_len
,
virtual_block_size
)
# When using DCP, each request's KV cache is sharded among different ranks.
# As a result, one block on the current rank covers `block_size * cp_size`
# tokens in the full, global (unsharded) sequence.
max_num_blocks
=
cdiv
(
self
.
max_model_len
,
block_size
*
self
.
cp_size
)
block_table
=
StagedWriteTensor
(
(
self
.
max_num_reqs
,
max_num_blocks
),
dtype
=
torch
.
int32
,
...
...
@@ -149,9 +141,9 @@ class BlockTables:
self
.
block_sizes_tensor
,
self
.
slot_mappings
,
self
.
slot_mappings
.
stride
(
0
),
TOTAL_CP_WORLD_SIZE
=
self
.
total_cp_world_size
,
TOTAL_CP_RANK
=
self
.
total_cp_rank
,
CP_
KV_CACHE_
INTERLEAVE
_SIZE
=
self
.
cp_
kv_cache_
interleave
_size
,
self
.
cp_rank
,
CP_SIZE
=
self
.
cp_size
,
CP_INTERLEAVE
=
self
.
cp_interleave
,
PAD_ID
=
PAD_SLOT_ID
,
TRITON_BLOCK_SIZE
=
1024
,
# type: ignore
)
...
...
@@ -204,9 +196,9 @@ def _compute_slot_mappings_kernel(
block_sizes
,
# [num_kv_cache_groups]
slot_mappings_ptr
,
# [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride
,
TOTAL_CP_WORLD_SIZE
:
tl
.
constexpr
,
TOTAL_CP_RANK
:
tl
.
constexpr
,
CP_
KV_CACHE_
INTERLEAVE
_SIZE
:
tl
.
constexpr
,
cp_rank
,
CP_SIZE
:
tl
.
constexpr
,
CP_INTERLEAVE
:
tl
.
constexpr
,
PAD_ID
:
tl
.
constexpr
,
TRITON_BLOCK_SIZE
:
tl
.
constexpr
,
):
...
...
@@ -225,7 +217,6 @@ def _compute_slot_mappings_kernel(
block_table_ptr
=
_load_ptr
(
block_table_ptrs
+
group_id
,
tl
.
int32
)
block_table_stride
=
tl
.
load
(
block_table_strides
+
group_id
)
block_size
=
tl
.
load
(
block_sizes
+
group_id
)
virtual_block_size
=
block_size
*
TOTAL_CP_WORLD_SIZE
req_state_idx
=
tl
.
load
(
idx_mapping
+
batch_idx
)
start_idx
=
tl
.
load
(
query_start_loc
+
batch_idx
)
...
...
@@ -233,26 +224,25 @@ def _compute_slot_mappings_kernel(
for
i
in
range
(
start_idx
,
end_idx
,
TRITON_BLOCK_SIZE
):
offset
=
i
+
tl
.
arange
(
0
,
TRITON_BLOCK_SIZE
)
positions
=
tl
.
load
(
pos
+
offset
,
mask
=
offset
<
end_idx
,
other
=
0
)
block_indices
=
positions
//
virtual_block_size
block_indices
=
positions
//
(
block_size
*
CP_SIZE
)
block_offsets
=
positions
%
(
block_size
*
CP_SIZE
)
block_numbers
=
tl
.
load
(
block_table_ptr
+
req_state_idx
*
block_table_stride
+
block_indices
)
virtual_block_offsets
=
positions
-
block_indices
*
virtual_block_size
# determine whether the token is stored on this CP rank.
is_local
=
(
virtual_block_offsets
//
CP_KV_CACHE_INTERLEAVE_SIZE
)
%
TOTAL_CP_WORLD_SIZE
==
TOTAL_CP_RANK
# mapping virture block offsets to local block offsets.
local_block_offsets
=
(
virtual_block_offsets
//
(
TOTAL_CP_WORLD_SIZE
*
CP_KV_CACHE_INTERLEAVE_SIZE
)
)
*
CP_KV_CACHE_INTERLEAVE_SIZE
+
(
virtual_block_offsets
%
CP_KV_CACHE_INTERLEAVE_SIZE
)
if
CP_SIZE
==
1
:
# Common case: Context parallelism is not used.
slot_ids
=
block_numbers
*
block_size
+
block_offsets
else
:
# Context parallelism is used.
is_local
=
block_offsets
//
CP_INTERLEAVE
%
CP_SIZE
==
cp_rank
rounds
=
block_offsets
//
(
CP_INTERLEAVE
*
CP_SIZE
)
remainder
=
block_offsets
%
CP_INTERLEAVE
local_offsets
=
rounds
*
CP_INTERLEAVE
+
remainder
slot_ids
=
block_numbers
*
block_size
+
local_offsets
slot_ids
=
tl
.
where
(
is_local
,
slot_ids
,
PAD_ID
)
# physical slot index
slot_ids
=
block_numbers
*
block_size
+
local_block_offsets
slot_ids
=
tl
.
where
(
is_local
,
slot_ids
,
PAD_ID
)
tl
.
store
(
slot_mapping_ptr
+
offset
,
slot_ids
,
mask
=
offset
<
end_idx
)
...
...
vllm/v1/worker/gpu/cp_utils.py
0 → 100644
View file @
95be2a7f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.triton_utils
import
tl
,
triton
def
prepare_dcp_local_seq_lens
(
dcp_local_seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
num_reqs
:
int
,
dcp_size
:
int
,
dcp_rank
:
int
,
cp_interleave
:
int
,
)
->
None
:
"""Populate the persistent DCP local seq_lens buffer (CUDA graph safe)."""
if
dcp_size
==
1
:
return
max_num_reqs
=
dcp_local_seq_lens
.
shape
[
0
]
BLOCK_SIZE
=
128
num_blocks
=
triton
.
cdiv
(
max_num_reqs
,
BLOCK_SIZE
)
_dcp_local_seq_lens_kernel
[(
num_blocks
,)](
dcp_local_seq_lens
,
seq_lens
,
dcp_size
,
dcp_rank
,
cp_interleave
,
num_reqs
,
max_num_reqs
,
BLOCK_SIZE
,
)
@
triton
.
jit
def
_dcp_local_seq_lens_kernel
(
out_ptr
,
seq_lens_ptr
,
dcp_size
,
dcp_rank
,
cp_interleave
,
num_reqs
,
max_num_reqs
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
block
=
pid
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
seq_lens
=
tl
.
load
(
seq_lens_ptr
+
block
,
mask
=
block
<
num_reqs
)
# Distribute KV cache among different ranks, in a round-robin manner.
rounds
=
seq_lens
//
(
dcp_size
*
cp_interleave
)
remainder
=
seq_lens
%
(
dcp_size
*
cp_interleave
)
remainder
=
tl
.
maximum
(
remainder
-
dcp_rank
*
cp_interleave
,
0
)
remainder
=
tl
.
minimum
(
remainder
,
cp_interleave
)
local_seq_lens
=
rounds
*
cp_interleave
+
remainder
# For [num_reqs, max_num_reqs), pad with 0
local_seq_lens
=
tl
.
where
(
block
<
num_reqs
,
local_seq_lens
,
0
)
tl
.
store
(
out_ptr
+
block
,
local_seq_lens
,
mask
=
block
<
max_num_reqs
)
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
95be2a7f
...
...
@@ -10,7 +10,6 @@ from tqdm import tqdm
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.distributed
import
get_dcp_group
from
vllm.distributed.parallel_state
import
graph_capture
,
is_global_first_rank
from
vllm.forward_context
import
set_forward_context
from
vllm.v1.attention.backend
import
AttentionMetadataBuilder
...
...
@@ -18,7 +17,6 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
(
build_attn_metadata
,
build_slot_mappings_by_layer
,
prepare_dcp_local_seq_lens
,
)
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.dp_utils
import
make_num_tokens_across_dp
...
...
@@ -259,22 +257,8 @@ def prepare_inputs_to_capture(
input_buffers
.
seq_lens
[:
num_reqs
]
=
num_tokens
input_buffers
.
seq_lens
[
num_reqs
:]
=
0
try
:
dcp_group
=
get_dcp_group
()
dcp_world_size
=
dcp_group
.
world_size
dcp_rank
=
dcp_group
.
rank_in_group
except
AssertionError
:
dcp_world_size
=
1
dcp_rank
=
0
if
dcp_world_size
>
1
:
prepare_dcp_local_seq_lens
(
input_buffers
.
dcp_local_seq_lens
,
input_buffers
.
seq_lens
,
num_reqs
,
dcp_size
=
dcp_world_size
,
dcp_rank
=
dcp_rank
,
cp_kv_cache_interleave_size
=
block_tables
.
cp_kv_cache_interleave_size
,
)
input_buffers
.
dcp_local_seq_lens
[:
num_reqs
]
=
num_tokens
input_buffers
.
dcp_local_seq_lens
[
num_reqs
:]
=
0
input_block_tables
=
[
x
[:
num_reqs
]
for
x
in
block_tables
.
input_block_tables
]
slot_mappings
=
block_tables
.
slot_mappings
[:,
:
num_tokens
]
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
95be2a7f
...
...
@@ -33,10 +33,10 @@ from vllm.v1.worker.gpu.attn_utils import (
get_kv_cache_spec
,
init_attn_backend
,
init_kv_cache
,
prepare_dcp_local_seq_lens
,
)
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.buffer_utils
import
async_copy_to_gpu
from
vllm.v1.worker.gpu.cp_utils
import
prepare_dcp_local_seq_lens
from
vllm.v1.worker.gpu.cudagraph_utils
import
CudaGraphManager
from
vllm.v1.worker.gpu.dp_utils
import
(
get_cudagraph_and_dp_padding
,
...
...
@@ -192,6 +192,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
is_first_pp_rank
=
True
self
.
is_last_pp_rank
=
True
# Decode context parallelism.
self
.
dcp_size
=
self
.
parallel_config
.
decode_context_parallel_size
self
.
use_dcp
=
self
.
dcp_size
>
1
self
.
dcp_rank
=
get_dcp_group
().
rank_in_group
if
self
.
use_dcp
else
0
self
.
cp_interleave
=
self
.
parallel_config
.
cp_kv_cache_interleave_size
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
self
.
max_model_len
=
max_model_len
self
.
req_states
.
max_model_len
=
max_model_len
...
...
@@ -251,9 +257,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_batched_tokens
=
self
.
max_num_tokens
,
max_model_len
=
self
.
max_model_len
,
device
=
self
.
device
,
cp_
kv_cache_interleave
_size
=
(
self
.
parallel_config
.
cp_kv_cache_interleave_size
)
,
cp_
size
=
self
.
dcp
_size
,
cp_rank
=
self
.
dcp_rank
,
cp_interleave
=
self
.
cp_interleave
,
)
self
.
attn_backends
,
self
.
attn_metadata_builders
=
init_attn_backend
(
...
...
@@ -636,18 +642,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
]
dcp_size
=
self
.
parallel_config
.
decode_context_parallel_size
if
dcp_size
>
1
:
if
self
.
use_dcp
:
# Prepare dcp local seq_lens.
prepare_dcp_local_seq_lens
(
self
.
input_buffers
.
dcp_local_seq_lens
,
seq_lens
,
self
.
input_buffers
.
seq_lens
,
num_reqs
,
dcp_size
=
dcp_size
,
dcp_rank
=
get_dcp_group
().
rank_in_group
,
cp_kv_cache_interleave_size
=
(
self
.
parallel_config
.
cp_kv_cache_interleave_size
),
self
.
dcp_size
,
self
.
dcp_rank
,
self
.
cp_interleave
,
)
dcp_local_seq_lens
=
self
.
input_buffers
.
dcp_local_seq_lens
[:
num_reqs
]
# Prepare M-RoPE positions.
if
self
.
uses_mrope
:
...
...
@@ -696,7 +701,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_tables
=
block_tables
,
slot_mappings
=
slot_mappings
,
kv_cache_config
=
self
.
kv_cache_config
,
dcp_local_seq_lens
=
self
.
input_buffers
.
dcp_local_seq_lens
,
dcp_local_seq_lens
=
dcp_local_seq_lens
,
)
input_ids
=
self
.
input_buffers
.
input_ids
[:
num_tokens_after_padding
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment