Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
483463f7
Unverified
Commit
483463f7
authored
Mar 09, 2026
by
Lucas Wilkinson
Committed by
GitHub
Mar 09, 2026
Browse files
[MRV2] Extensible CG dispatch rework (#35959)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
4e571ce6
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
545 additions
and
636 deletions
+545
-636
vllm/config/compilation.py
vllm/config/compilation.py
+3
-0
vllm/v1/worker/gpu/block_table.py
vllm/v1/worker/gpu/block_table.py
+24
-8
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+285
-316
vllm/v1/worker/gpu/dp_utils.py
vllm/v1/worker/gpu/dp_utils.py
+53
-49
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+4
-1
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+76
-57
vllm/v1/worker/gpu/model_states/default.py
vllm/v1/worker/gpu/model_states/default.py
+5
-2
vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
+52
-175
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
+43
-28
No files found.
vllm/config/compilation.py
View file @
483463f7
...
...
@@ -97,6 +97,9 @@ class CUDAGraphMode(enum.Enum):
def
__str__
(
self
)
->
str
:
return
self
.
name
def
__bool__
(
self
)
->
bool
:
return
self
!=
CUDAGraphMode
.
NONE
@
config
class
PassConfig
:
...
...
vllm/v1/worker/gpu/block_table.py
View file @
483463f7
...
...
@@ -104,19 +104,24 @@ class BlockTables:
self
.
num_blocks
.
copy_to_uva
()
def
gather_block_tables
(
self
,
idx_mapping
:
torch
.
Tensor
self
,
idx_mapping
:
torch
.
Tensor
,
num_reqs_padded
:
int
,
)
->
tuple
[
torch
.
Tensor
,
...]:
num_reqs
=
idx_mapping
.
shape
[
0
]
_gather_block_tables_kernel
[(
self
.
num_kv_cache_groups
,
num_reqs
)](
# Launch kernel with num_reqs_padded to fuse zeroing of padded rows.
_gather_block_tables_kernel
[(
self
.
num_kv_cache_groups
,
num_reqs_padded
)](
idx_mapping
,
self
.
block_table_ptrs
,
self
.
input_block_table_ptrs
,
self
.
block_table_strides
,
self
.
num_blocks
.
gpu
,
self
.
num_blocks
.
gpu
.
stride
(
0
),
num_reqs
,
self
.
input_block_tables
[
0
].
shape
[
1
],
# max_num_blocks
BLOCK_SIZE
=
1024
,
# type: ignore
)
return
tuple
(
b
lock_table
[:
num_reqs
]
for
block_table
in
self
.
input_block_tables
)
return
tuple
(
b
t
[:
num_reqs_padded
]
for
bt
in
self
.
input_block_tables
)
def
get_dummy_block_tables
(
self
,
num_reqs
:
int
)
->
tuple
[
torch
.
Tensor
,
...]:
# NOTE(woosuk): The output may be used for CUDA graph capture.
...
...
@@ -130,6 +135,7 @@ class BlockTables:
idx_mapping
:
torch
.
Tensor
,
query_start_loc
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
num_tokens_padded
:
int
,
)
->
torch
.
Tensor
:
num_reqs
=
idx_mapping
.
shape
[
0
]
num_tokens
=
positions
.
shape
[
0
]
...
...
@@ -151,7 +157,7 @@ class BlockTables:
PAD_ID
=
PAD_SLOT_ID
,
TRITON_BLOCK_SIZE
=
1024
,
# type: ignore
)
return
self
.
slot_mappings
[:,
:
num_tokens
]
return
self
.
slot_mappings
[:,
:
num_tokens
_padded
]
def
get_dummy_slot_mappings
(
self
,
num_tokens
:
int
)
->
torch
.
Tensor
:
# Fill the entire slot_mappings tensor, not just the first `num_tokens` entries.
...
...
@@ -173,21 +179,31 @@ def _gather_block_tables_kernel(
block_table_strides
,
# [num_kv_cache_groups]
num_blocks_ptr
,
# [num_kv_cache_groups, max_num_reqs]
num_blocks_stride
,
num_reqs
,
# actual number of requests (for padding)
max_num_blocks
,
# stride for zeroing padded rows
BLOCK_SIZE
:
tl
.
constexpr
,
):
# kv cache group id
group_id
=
tl
.
program_id
(
0
)
batch_idx
=
tl
.
program_id
(
1
)
req_idx
=
tl
.
load
(
batch_idx_to_req_idx
+
batch_idx
)
stride
=
tl
.
load
(
block_table_strides
+
group_id
)
dst_block_table_ptr
=
_load_ptr
(
dst_block_table_ptrs
+
group_id
,
tl
.
int32
)
dst_row_ptr
=
dst_block_table_ptr
+
batch_idx
*
stride
if
batch_idx
>=
num_reqs
:
# Zero out padded rows.
for
i
in
tl
.
range
(
0
,
max_num_blocks
,
BLOCK_SIZE
):
offset
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
tl
.
store
(
dst_row_ptr
+
offset
,
0
,
mask
=
offset
<
max_num_blocks
)
return
req_idx
=
tl
.
load
(
batch_idx_to_req_idx
+
batch_idx
)
group_num_blocks_ptr
=
num_blocks_ptr
+
group_id
*
num_blocks_stride
num_blocks
=
tl
.
load
(
group_num_blocks_ptr
+
req_idx
)
stride
=
tl
.
load
(
block_table_strides
+
group_id
)
src_block_table_ptr
=
_load_ptr
(
src_block_table_ptrs
+
group_id
,
tl
.
int32
)
src_row_ptr
=
src_block_table_ptr
+
req_idx
*
stride
dst_block_table_ptr
=
_load_ptr
(
dst_block_table_ptrs
+
group_id
,
tl
.
int32
)
dst_row_ptr
=
dst_block_table_ptr
+
batch_idx
*
stride
for
i
in
tl
.
range
(
0
,
num_blocks
,
BLOCK_SIZE
):
offset
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
...
...
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
483463f7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections
import
defaultdict
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
from
typing
import
Any
import
torch
...
...
@@ -11,235 +13,260 @@ from vllm.config import VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.distributed.parallel_state
import
graph_capture
,
is_global_first_rank
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.offloader.base
import
get_offloader
from
vllm.
utils.math_utils
import
cdiv
from
vllm.
platforms
import
current_platform
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
build_slot_mappings_by_layer
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.cp_utils
import
prepare_dcp_local_seq_lens
from
vllm.v1.worker.gpu.dp_utils
import
make_num_tokens_across_dp
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
,
InputBuffers
from
vllm.v1.worker.gpu.model_states.interface
import
ModelState
from
vllm.v1.worker.utils
import
AttentionGroup
logger
=
init_logger
(
__name__
)
@
dataclass
(
frozen
=
True
)
class
BatchExecutionDescriptor
:
"""Describes the shape of the batch and CG mode to run; this is used to make shape
matches between the capture and runtime."""
cg_mode
:
CUDAGraphMode
num_tokens
:
int
num_reqs
:
int
|
None
# None means no request padding is needed (PIECEWISE graphs)
uniform_token_count
:
int
|
None
=
None
def
_is_compatible
(
desc
:
BatchExecutionDescriptor
,
num_reqs
:
int
,
num_tokens
:
int
,
uniform_token_count
:
int
|
None
,
)
->
bool
:
# desc.uniform_token_count=None (PIECEWISE) can handle any uniform_token_count
# desc.num_reqs=None means no request padding needed (PIECEWISE)
return
(
(
desc
.
uniform_token_count
is
None
or
desc
.
uniform_token_count
==
uniform_token_count
)
and
(
desc
.
num_reqs
is
None
or
desc
.
num_reqs
>=
num_reqs
)
and
desc
.
num_tokens
>=
num_tokens
)
def
get_uniform_token_count
(
num_reqs
:
int
,
num_tokens
:
int
,
max_query_len
:
int
,
)
->
int
|
None
:
"""
Return the uniform token count if batch is uniform, else None.
A batch is uniform if all requests have the same number of tokens.
"""
if
(
max_query_len
==
num_tokens
//
num_reqs
)
and
(
num_tokens
==
max_query_len
*
num_reqs
):
return
max_query_len
return
None
class
CudaGraphManager
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
use_aux_hidden_state_outputs
:
bool
,
device
:
torch
.
device
,
cudagraph_mode
:
CUDAGraphMode
,
decode_query_len
:
int
,
):
self
.
vllm_config
=
vllm_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
use_aux_hidden_state_outputs
=
use_aux_hidden_state_outputs
self
.
device
=
device
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
self
.
max_num_reqs
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
compilation_config
=
vllm_config
.
compilation_config
assert
self
.
compilation_config
is
not
None
self
.
cudagraph_mode
=
cudagraph_mode
self
.
decode_query_len
=
decode_query_len
self
.
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
self
.
uniform_decode_query_len
=
1
spec_config
=
vllm_config
.
speculative_config
if
spec_config
is
not
None
:
self
.
uniform_decode_query_len
+=
spec_config
.
num_speculative_tokens
self
.
graphs
:
dict
[
BatchExecutionDescriptor
,
torch
.
cuda
.
CUDAGraph
]
=
{}
self
.
pool
=
current_platform
.
get_global_graph_pool
()
if
cudagraph_mode
else
None
self
.
compilation_config
=
vllm_config
.
compilation_config
assert
self
.
compilation_config
is
not
None
self
.
cudagraph_mode
=
self
.
compilation_config
.
cudagraph_mode
self
.
_graphs_captured
=
False
self
.
_candidates
:
list
[
list
[
BatchExecutionDescriptor
]]
=
[]
self
.
_capture_descs
:
dict
[
CUDAGraphMode
,
list
[
BatchExecutionDescriptor
]]
=
{}
self
.
_init_candidates
()
use_uniform_decode_cudagraph
=
(
self
.
cudagraph_mode
.
decode_mode
()
==
CUDAGraphMode
.
FULL
and
self
.
cudagraph_mode
.
separate_routine
()
)
self
.
cudagraph_sizes
,
self
.
uniform_decode_cudagraph_sizes
=
get_cudagraph_sizes
(
self
.
compilation_config
.
cudagraph_capture_sizes
,
self
.
max_num_reqs
,
self
.
max_num_tokens
,
self
.
cudagraph_mode
,
self
.
uniform_decode_query_len
,
use_uniform_decode_cudagraph
,
)
def
_init_candidates
(
self
)
->
None
:
"""Build priority-ordered candidate lists for each token count."""
capture_sizes
=
self
.
compilation_config
.
cudagraph_capture_sizes
if
not
(
self
.
cudagraph_mode
and
capture_sizes
):
return
self
.
graphs
:
dict
[
int
,
torch
.
cuda
.
CUDAGraph
]
=
{}
self
.
pool
=
None
if
self
.
cudagraph_mode
!=
CUDAGraphMode
.
NONE
:
self
.
pool
=
torch
.
cuda
.
graph_pool_handle
()
self
.
hidden_states
:
torch
.
Tensor
|
None
=
None
self
.
aux_hidden_states
:
list
[
torch
.
Tensor
]
=
[]
capture_sizes
=
sorted
(
capture_sizes
)
max_decode_tokens
=
self
.
max_num_reqs
*
self
.
decode_query_len
decode_mode
=
self
.
cudagraph_mode
.
decode_mode
()
mixed_mode
=
self
.
cudagraph_mode
.
mixed_mode
()
separate_decode_routine
=
self
.
cudagraph_mode
.
separate_routine
()
descs_by_token_count
=
defaultdict
(
list
)
descs_by_mode
=
defaultdict
(
list
)
for
num_tokens
in
capture_sizes
:
# Capture uniform decode specfifc graphs if required
# (i.e. separate decode routine)
if
(
separate_decode_routine
and
decode_mode
and
self
.
decode_query_len
<=
num_tokens
<=
max_decode_tokens
):
desc
=
BatchExecutionDescriptor
(
cg_mode
=
decode_mode
,
num_tokens
=
num_tokens
,
num_reqs
=
num_tokens
//
self
.
decode_query_len
,
uniform_token_count
=
self
.
decode_query_len
,
)
descs_by_mode
[
decode_mode
].
append
(
desc
)
descs_by_token_count
[
num_tokens
].
append
(
desc
)
if
mixed_mode
:
# for PIECEWISE graphs there is no limit on requests when replaying
# i.e. no request padding is needed
# so we leave it as None
num_reqs
=
(
min
(
num_tokens
,
self
.
max_num_reqs
)
if
mixed_mode
==
CUDAGraphMode
.
FULL
else
None
)
desc
=
BatchExecutionDescriptor
(
cg_mode
=
mixed_mode
,
num_tokens
=
num_tokens
,
num_reqs
=
num_reqs
,
)
descs_by_mode
[
mixed_mode
].
append
(
desc
)
descs_by_token_count
[
num_tokens
].
append
(
desc
)
if
not
descs_by_token_count
:
return
sorted_padded
=
sorted
(
descs_by_token_count
.
keys
())
self
.
_candidates
=
[[]
for
_
in
range
(
sorted_padded
[
-
1
]
+
1
)]
current_range_start
=
0
for
cg_size
in
sorted_padded
:
for
i
in
range
(
current_range_start
,
cg_size
+
1
):
self
.
_candidates
[
i
]
=
descs_by_token_count
[
cg_size
]
current_range_start
=
cg_size
+
1
for
mode
,
descs
in
descs_by_mode
.
items
():
descs
.
sort
(
key
=
lambda
d
:
d
.
num_tokens
,
reverse
=
True
)
self
.
_capture_descs
[
mode
]
=
descs
def
needs_capture
(
self
)
->
bool
:
return
len
(
self
.
cudagraph_sizes
)
>
0
def
get_cudagraph_size
(
self
,
num_tokens
:
int
,
uniform_decode
:
bool
=
False
)
->
int
|
None
:
if
uniform_decode
and
self
.
uniform_decode_cudagraph_sizes
:
return
self
.
uniform_decode_cudagraph_sizes
.
get
(
num_tokens
)
return
self
.
cudagraph_sizes
.
get
(
num_tokens
)
return
len
(
self
.
_capture_descs
)
>
0
def
capture_graph
(
@
torch
.
inference_mode
()
def
capture
(
self
,
num_tokens
:
int
,
capture_cg_mode
:
CUDAGraphMode
,
model
:
nn
.
Module
,
model_state
:
ModelState
,
input_buffers
:
InputBuffers
,
block_tables
:
BlockTables
,
attn_groups
:
list
[
list
[
AttentionGroup
]],
kv_cache_config
:
KVCacheConfig
,
has_lora
:
bool
=
False
,
uniform_decode
:
bool
=
False
,
create_forward_fn
:
Callable
[
[
BatchExecutionDescriptor
],
Callable
[[
CUDAGraphMode
],
None
]
],
progress_bar_desc
:
str
=
"Capturing CUDA graphs"
,
)
->
None
:
# select and check capture function
assert
capture_cg_mode
in
[
CUDAGraphMode
.
PIECEWISE
,
CUDAGraphMode
.
FULL
],
(
f
"Invalid capture_cudagraph_mode for capture:
{
capture_cg_mode
}
"
)
if
capture_cg_mode
==
CUDAGraphMode
.
PIECEWISE
:
capture_fn
=
self
.
_capture_piecewise_graph
else
:
capture_fn
=
self
.
_capture_full_graph
# prepare inputs
if
uniform_decode
:
num_reqs
=
min
(
cdiv
(
num_tokens
,
self
.
uniform_decode_query_len
),
self
.
max_num_reqs
,
)
else
:
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
model_inputs
=
{
"input_ids"
:
input_buffers
.
input_ids
[:
num_tokens
],
"positions"
:
input_buffers
.
positions
[:
num_tokens
],
# NOTE: Values returned by `prepare_dummy_inputs` will override the
# default values above.
**
model_state
.
prepare_dummy_inputs
(
num_reqs
,
num_tokens
),
}
attn_metadata
,
slot_mappings
=
prepare_inputs_to_capture
(
num_reqs
,
num_tokens
,
model_state
,
input_buffers
,
block_tables
,
attn_groups
,
kv_cache_config
,
)
num_tokens_across_dp
=
make_num_tokens_across_dp
(
self
.
dp_size
,
num_tokens
)
# Warm up.
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_tokens
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
,
num_tokens_across_dp
=
num_tokens_across_dp
,
slot_mapping
=
slot_mappings
,
):
model_output
=
model
(
**
model_inputs
)
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
aux_hidden_states
=
model_output
else
:
hidden_states
=
model_output
aux_hidden_states
=
None
# Allocate output buffers if not already done.
if
self
.
hidden_states
is
None
:
self
.
hidden_states
=
torch
.
empty_like
(
hidden_states
)
if
self
.
use_aux_hidden_state_outputs
and
not
self
.
aux_hidden_states
:
self
.
aux_hidden_states
=
[
torch
.
empty_like
(
x
)
for
x
in
aux_hidden_states
]
capture_fn
(
num_tokens
=
num_tokens
,
num_reqs
=
num_reqs
,
model
=
model
,
model_inputs
=
model_inputs
,
num_tokens_across_dp
=
num_tokens_across_dp
,
attn_metadata
=
attn_metadata
,
slot_mappings
=
slot_mappings
,
has_lora
=
has_lora
,
)
def
_capture_full_graph
(
"""Capture CUDA graphs.
Args:
create_forward_fn: Factory that prepares inputs (OUTSIDE graph) and
returns a function that runs forward with a given CUDAGraphMode.
"""
with
graph_capture
(
device
=
self
.
device
):
# Capture in order: PIECEWISE first, then FULL. PIECEWISE has larger
# activations so FULL activations should fit in already allocated
# buffers in the graph pool.
for
mode
in
[
CUDAGraphMode
.
PIECEWISE
,
CUDAGraphMode
.
FULL
]:
if
mode
not
in
self
.
_capture_descs
:
continue
descs
=
self
.
_capture_descs
[
mode
]
if
is_global_first_rank
():
descs
=
tqdm
(
descs
,
desc
=
f
"
{
progress_bar_desc
}
(
{
mode
.
name
}
)"
)
for
desc
in
descs
:
# Prepare inputs and get forward function
forward_fn
=
create_forward_fn
(
desc
)
# Warmup
forward_fn
(
CUDAGraphMode
.
NONE
)
# Capture
logger
.
debug
(
"CG Capture: mode=%s, batch_desc=%s"
,
desc
.
cg_mode
.
name
,
desc
)
if
desc
.
cg_mode
==
CUDAGraphMode
.
PIECEWISE
:
forward_fn
(
CUDAGraphMode
.
PIECEWISE
)
else
:
assert
desc
not
in
self
.
graphs
,
(
f
"Graph already captured for
{
desc
}
"
)
graph
=
torch
.
cuda
.
CUDAGraph
()
# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
get_offloader
().
sync_prev_onload
()
with
torch
.
cuda
.
graph
(
graph
,
self
.
pool
):
forward_fn
(
CUDAGraphMode
.
NONE
)
# Join offloader's copy stream after forward to avoid
# unjoined stream error. The last layer's start_prefetch
# forks copy_stream, but wait_prefetch only happens in
# the next forward pass.
get_offloader
().
join_after_forward
()
self
.
graphs
[
desc
]
=
graph
self
.
_graphs_captured
=
True
def
dispatch
(
self
,
num_tokens
:
int
,
num_reqs
:
int
,
model
:
nn
.
Module
,
model_inputs
:
dict
[
str
,
torch
.
Tensor
|
None
]
,
num_tokens_across_dp
:
torch
.
Tensor
,
attn_metadata
:
dict
[
str
,
Any
]
|
None
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
None
,
has_lora
:
bool
=
False
,
)
->
None
:
assert
attn_metadata
is
not
None
# Capture the graph.
assert
num_tokens
not
in
self
.
graph
s
graph
=
torch
.
cuda
.
CUDAGraph
(
)
num_tokens
:
int
,
uniform_token_count
:
int
|
None
,
)
->
BatchExecutionDescriptor
:
"""Find matching cudagraph descriptor from priority-ordered candidates."""
if
self
.
_graphs_captured
and
0
<
num_tokens
<
len
(
self
.
_candidates
):
for
desc
in
self
.
_candidates
[
num_tokens
]:
if
_is_compatible
(
desc
,
num_reqs
,
num_tokens
,
uniform_token_count
)
:
return
desc
return
BatchExecutionDescriptor
(
cg_mode
=
CUDAGraphMode
.
NONE
,
num_tokens
=
num_tokens
,
num_reqs
=
num_req
s
)
# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
def
run_fullgraph
(
self
,
desc
:
BatchExecutionDescriptor
):
"""Replay a captured FULL cudagraph."""
assert
desc
.
cg_mode
==
CUDAGraphMode
.
FULL
,
(
f
"Expected FULL mode, got
{
desc
.
cg_mode
}
"
)
assert
desc
in
self
.
graphs
,
f
"No cudagraph for
{
desc
}
"
# Sync offloader before replay - needed when transitioning from
# eager/piecewise to full cudagraph (e.g., prefill → decode).
# The previous eager iteration's start_prefetch may have queued
# H2D copies on copy_stream that the graph's captured events
# cannot see. Without this, replay could overwrite static buffers
# while those copies are still in flight.
get_offloader
().
sync_prev_onload
()
self
.
graphs
[
desc
].
replay
()
class
ModelCudaGraphManager
(
CudaGraphManager
):
"""CudaGraphManager with model-specific capture and hidden state management."""
with
(
set_forward_context
(
attn_metadata
=
attn_metadata
,
vllm_config
=
self
.
vllm_config
,
num_tokens
=
num_tokens
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
,
num_tokens_across_dp
=
num_tokens_across_dp
,
slot_mapping
=
slot_mappings
,
),
torch
.
cuda
.
graph
(
graph
,
self
.
pool
),
):
model_output
=
model
(
**
model_inputs
)
# Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass.
get_offloader
().
join_after_forward
()
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
aux_hidden_states
=
model_output
else
:
hidden_states
=
model_output
aux_hidden_states
=
None
# Copy outputs to the output buffers.
assert
self
.
hidden_states
is
not
None
self
.
hidden_states
[:
num_tokens
]
=
hidden_states
if
self
.
use_aux_hidden_state_outputs
:
for
i
,
aux_hidden
in
enumerate
(
aux_hidden_states
):
self
.
aux_hidden_states
[
i
][:
num_tokens
]
=
aux_hidden
self
.
graphs
[
num_tokens
]
=
graph
def
_capture_piecewise_graph
(
def
__init__
(
self
,
num_tokens
:
int
,
num_reqs
:
int
,
model
:
nn
.
Module
,
model_inputs
:
dict
[
str
,
torch
.
Tensor
|
None
],
num_tokens_across_dp
:
torch
.
Tensor
,
attn_metadata
:
dict
[
str
,
Any
]
|
None
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
None
,
has_lora
:
bool
=
False
,
)
->
None
:
# create batch descriptor for piecewise cudagraph dispatch key
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
num_tokens
,
has_lora
=
has_lora
)
# Capture run - CUDAGraphWrapper inside torch.compile will auto capture.
with
set_forward_context
(
attn_metadata
=
None
,
# piecewise no need attn_metadata
vllm_config
=
self
.
vllm_config
,
num_tokens
=
num_tokens
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
,
num_tokens_across_dp
=
num_tokens_across_dp
,
batch_descriptor
=
batch_descriptor
,
slot_mapping
=
slot_mappings
,
):
model
(
**
model_inputs
)
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
cudagraph_mode
:
CUDAGraphMode
,
decode_query_len
:
int
,
):
super
().
__init__
(
vllm_config
,
device
,
cudagraph_mode
,
decode_query_len
)
self
.
hidden_states
:
torch
.
Tensor
|
None
=
None
self
.
aux_hidden_states
:
list
[
torch
.
Tensor
]
=
[]
self
.
use_aux_hidden_state_outputs
=
False
@
torch
.
inference_mode
()
def
capture
(
self
,
model
:
nn
.
Module
,
...
...
@@ -249,139 +276,81 @@ class CudaGraphManager:
attn_groups
:
list
[
list
[
AttentionGroup
]],
kv_cache_config
:
KVCacheConfig
,
has_lora
:
bool
=
False
,
use_aux_hidden_state_outputs
:
bool
=
False
,
progress_bar_desc
:
str
=
"Capturing CUDA graphs"
,
)
->
None
:
common_kwargs
=
dict
(
device
=
self
.
device
,
capture_fn
=
self
.
capture_graph
,
model
=
model
,
model_state
=
model_state
,
input_buffers
=
input_buffers
,
block_tables
=
block_tables
,
attn_groups
=
attn_groups
,
kv_cache_config
=
kv_cache_config
,
has_lora
=
has_lora
,
)
"""Capture CUDA graphs for model forward pass."""
self
.
use_aux_hidden_state_outputs
=
use_aux_hidden_state_outputs
# Phase 1: Capture for mixed prefill-decode batches if needed.
mixed_mode
=
self
.
cudagraph_mode
.
mixed_mode
()
if
mixed_mode
!=
CUDAGraphMode
.
NONE
:
capture_graphs
(
cudagraph_sizes
=
self
.
cudagraph_sizes
,
capture_cudagraph_mode
=
mixed_mode
,
desc
=
f
"Capturing CUDA graphs (mixed,
{
mixed_mode
.
name
}
)"
,
uniform_decode
=
False
,
**
common_kwargs
,
def
create_forward_fn
(
desc
:
BatchExecutionDescriptor
,
)
->
Callable
[[
CUDAGraphMode
],
None
]
:
num_tokens
=
desc
.
num_tokens
num_reqs
=
desc
.
num_reqs
or
min
(
num_tokens
,
self
.
max_num_reqs
)
num_tokens_across_dp
=
(
torch
.
full
((
self
.
dp_size
,),
num_tokens
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
if
self
.
dp_size
>
1
else
None
)
# Phase 2: Capture FULL graphs for uniform decode batches if needed.
# This is only needed if we use a separate routine for decode batches
# and the decode_mode is FULL.
if
self
.
uniform_decode_cudagraph_sizes
:
capture_graphs
(
cudagraph_sizes
=
self
.
uniform_decode_cudagraph_sizes
,
capture_cudagraph_mode
=
CUDAGraphMode
.
FULL
,
desc
=
"Capturing CUDA graphs (decode, FULL)"
,
uniform_decode
=
True
,
**
common_kwargs
,
attn_metadata
,
slot_mappings
=
prepare_inputs_to_capture
(
num_reqs
,
num_tokens
,
model_state
,
input_buffers
,
block_tables
,
attn_groups
,
kv_cache_config
,
)
def
get_cudagraph_runtime_mode
(
self
,
num_reqs
:
int
,
num_tokens
:
int
,
max_query_len
:
int
)
->
tuple
[
CUDAGraphMode
,
int
|
None
]:
is_uniform_decode
=
(
max_query_len
==
self
.
uniform_decode_query_len
)
and
(
num_tokens
==
max_query_len
*
num_reqs
)
cudagraph_size
=
self
.
get_cudagraph_size
(
num_tokens
,
is_uniform_decode
)
if
cudagraph_size
is
None
:
cudagraph_mode
=
CUDAGraphMode
.
NONE
elif
is_uniform_decode
:
cudagraph_mode
=
self
.
cudagraph_mode
.
decode_mode
()
else
:
cudagraph_mode
=
self
.
cudagraph_mode
.
mixed_mode
()
if
(
cudagraph_mode
==
CUDAGraphMode
.
FULL
and
cudagraph_size
is
not
None
and
cudagraph_size
not
in
self
.
graphs
):
# If graph wasn't captured yet, fall back to eager.
# This might happen when the dummy run is called before capture.
cudagraph_mode
=
CUDAGraphMode
.
NONE
cudagraph_size
=
None
return
cudagraph_mode
,
cudagraph_size
def
forward_fn
(
cg_mode
:
CUDAGraphMode
)
->
None
:
batch_descriptor
=
(
BatchDescriptor
(
num_tokens
=
num_tokens
)
if
cg_mode
==
CUDAGraphMode
.
PIECEWISE
else
None
)
with
set_forward_context
(
attn_metadata
if
cg_mode
!=
CUDAGraphMode
.
PIECEWISE
else
None
,
self
.
vllm_config
,
num_tokens
=
num_tokens
,
cudagraph_runtime_mode
=
cg_mode
,
num_tokens_across_dp
=
num_tokens_across_dp
,
slot_mapping
=
slot_mappings
,
batch_descriptor
=
batch_descriptor
,
):
model_inputs
=
{
"input_ids"
:
input_buffers
.
input_ids
[:
num_tokens
],
"positions"
:
input_buffers
.
positions
[:
num_tokens
],
}
model_output
=
model
(
**
model_inputs
)
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
aux_hidden_states
=
model_output
else
:
hidden_states
=
model_output
aux_hidden_states
=
[]
if
self
.
hidden_states
is
None
:
self
.
hidden_states
=
torch
.
empty_like
(
hidden_states
)
if
self
.
use_aux_hidden_state_outputs
and
not
self
.
aux_hidden_states
:
self
.
aux_hidden_states
=
[
torch
.
empty_like
(
x
)
for
x
in
aux_hidden_states
]
self
.
hidden_states
[:
num_tokens
]
=
hidden_states
for
i
,
aux
in
enumerate
(
aux_hidden_states
):
self
.
aux_hidden_states
[
i
][:
num_tokens
]
=
aux
return
forward_fn
super
().
capture
(
create_forward_fn
,
progress_bar_desc
)
def
run_fullgraph
(
self
,
num_tokens
:
int
self
,
desc
:
BatchExecutionDescriptor
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]:
assert
num_tokens
in
self
.
graphs
,
f
"No cudagraph for
{
num_tokens
}
tokens"
# Sync offloader before replay - needed when transitioning from
# eager/piecewise to full cudagraph (e.g., prefill → decode).
# The previous eager iteration's start_prefetch may have queued
# H2D copies on copy_stream that the graph's captured events
# cannot see. Without this, replay could overwrite static buffers
# while those copies are still in flight.
get_offloader
().
sync_prev_onload
()
self
.
graphs
[
num_tokens
].
replay
()
"""Replay a captured FULL cudagraph and return hidden states."""
super
().
run_fullgraph
(
desc
)
assert
self
.
hidden_states
is
not
None
hidden_states
=
self
.
hidden_states
[:
num_tokens
]
hidden_states
=
self
.
hidden_states
[:
desc
.
num_tokens
]
if
not
self
.
use_aux_hidden_state_outputs
:
return
hidden_states
return
hidden_states
,
[
x
[:
num_tokens
]
for
x
in
self
.
aux_hidden_states
]
def
get_cudagraph_sizes
(
capture_sizes
:
list
[
int
]
|
None
,
max_num_reqs
:
int
,
max_num_tokens
:
int
,
cudagraph_mode
:
CUDAGraphMode
,
uniform_decode_query_len
:
int
=
1
,
uniform_decode_cudagraph
:
bool
=
False
,
)
->
tuple
[
dict
[
int
,
int
],
dict
[
int
,
int
]]:
# Support both FULL and PIECEWISE cudagraph modes
if
cudagraph_mode
==
CUDAGraphMode
.
NONE
:
return
{},
{}
if
not
capture_sizes
:
return
{},
{}
capture_sizes
=
sorted
(
capture_sizes
)
if
not
capture_sizes
:
return
{},
{}
cudagraph_sizes
:
dict
[
int
,
int
]
=
{}
for
i
in
range
(
1
,
capture_sizes
[
-
1
]
+
1
):
for
x
in
capture_sizes
:
if
i
<=
x
:
cudagraph_sizes
[
i
]
=
x
break
uniform_decode_cudagraph_sizes
:
dict
[
int
,
int
]
=
{}
if
uniform_decode_cudagraph
:
max_num_tokens
=
max_num_reqs
*
uniform_decode_query_len
uniform_decode_cudagraph_sizes
=
{
k
:
v
for
k
,
v
in
cudagraph_sizes
.
items
()
if
v
<=
max_num_tokens
and
v
>=
uniform_decode_query_len
}
return
cudagraph_sizes
,
uniform_decode_cudagraph_sizes
def
capture_graphs
(
cudagraph_sizes
:
dict
[
int
,
int
],
device
:
torch
.
device
,
capture_fn
:
Callable
,
capture_cudagraph_mode
:
CUDAGraphMode
,
desc
:
str
=
"Capturing CUDA graphs"
,
**
capture_kwargs
,
)
->
None
:
# Capture larger graphs first.
sizes_to_capture
=
sorted
(
set
(
cudagraph_sizes
.
values
()),
reverse
=
True
)
if
is_global_first_rank
():
sizes_to_capture
=
tqdm
(
sizes_to_capture
,
desc
=
desc
)
with
graph_capture
(
device
=
device
):
for
size
in
sizes_to_capture
:
capture_fn
(
size
,
capture_cudagraph_mode
,
**
capture_kwargs
)
return
hidden_states
,
[
x
[:
desc
.
num_tokens
]
for
x
in
self
.
aux_hidden_states
]
def
prepare_inputs_to_capture
(
...
...
vllm/v1/worker/gpu/dp_utils.py
View file @
483463f7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
torch
import
torch.distributed
as
dist
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.distributed.parallel_state
import
get_dp_group
from
vllm.v1.worker.gpu.cudagraph_utils
import
(
BatchExecutionDescriptor
,
CudaGraphManager
,
)
def
make_num_tokens_across_dp
(
dp_size
:
int
,
num_tokens
:
int
)
->
torch
.
Tensor
|
None
:
...
...
@@ -12,66 +19,63 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N
return
torch
.
full
((
dp_size
,),
num_tokens
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
def
get_batch_metadata_across_dp
(
def
sync_cudagraph_and_dp_padding
(
cudagraph_manager
:
CudaGraphManager
,
desired_batch_desc
:
BatchExecutionDescriptor
,
num_tokens
:
int
,
cudagraph_size
:
int
,
cudagraph_runtime_mode
:
int
,
num_reqs
:
int
,
uniform_token_count
:
int
|
None
,
dp_size
:
int
,
dp_rank
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
dp_size
>
1
# Use CPU group to avoid CPU-GPU synchronization.
)
->
tuple
[
BatchExecutionDescriptor
,
torch
.
Tensor
|
None
]:
"""
Coordinates the batch descriptor and DP padding across all ranks.
Returns (synced_batch_desc, num_tokens_across_dp).
"""
assert
dp_size
>
1
,
"DP size must be greater than 1"
group
=
get_dp_group
().
cpu_group
tensor
=
torch
.
zeros
(
3
,
dp_size
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
tensor
[
0
][
dp_rank
]
=
num_tokens
tensor
[
1
][
dp_rank
]
=
cudagraph_siz
e
tensor
[
2
][
dp_rank
]
=
cudagraph_runtime_mode
tensor
[
1
][
dp_rank
]
=
desired_batch_desc
.
cg_mode
.
valu
e
tensor
[
2
][
dp_rank
]
=
uniform_token_count
or
0
# (0 means None)
dist
.
all_reduce
(
tensor
,
group
=
group
)
return
tensor
[
0
],
tensor
[
1
],
tensor
[
2
]
num_tokens_across_dp
=
tensor
[
0
]
cg_mode_across_dp
=
tensor
[
1
]
uniform_token_counts_across_dp
=
tensor
[
2
]
def
get_cudagraph_and_dp_padding
(
num_tokens
:
int
,
cudagraph_size
:
int
|
None
,
cudagraph_runtime_mode
:
int
,
dp_size
:
int
,
dp_rank
:
int
,
)
->
tuple
[
int
,
torch
.
Tensor
|
None
,
int
]:
if
dp_size
==
1
:
if
cudagraph_size
is
not
None
:
return
cudagraph_size
,
None
,
cudagraph_runtime_mode
else
:
return
num_tokens
,
None
,
cudagraph_runtime_mode
if
torch
.
all
(
num_tokens_across_dp
==
0
).
item
():
synced_desc
=
BatchExecutionDescriptor
(
cg_mode
=
CUDAGraphMode
.
NONE
,
num_tokens
=
0
,
num_reqs
=
0
)
return
synced_desc
,
None
# Convert None to -1 for sync (indicates no cudagraph available)
if
num_tokens
==
0
:
cudagraph_size
=
0
elif
cudagraph_size
is
None
:
cudagraph_size
=
-
1
synced_cg_mode
=
CUDAGraphMode
(
int
(
cg_mode_across_dp
.
min
().
item
()))
num_tokens_across_dp
,
cudagraph_size_across_dp
,
cudagraph_mode_across_dp
=
(
get_batch_metadata_across_dp
(
num_tokens
,
cudagraph_size
,
cudagraph_runtime_mode
,
dp_size
,
dp_rank
)
# If any rank wants to run eager, all ranks run eager
if
synced_cg_mode
==
CUDAGraphMode
.
NONE
:
return
BatchExecutionDescriptor
(
cg_mode
=
CUDAGraphMode
.
NONE
,
num_tokens
=
num_tokens
,
num_reqs
=
num_reqs
,
),
num_tokens_across_dp
synced_num_tokens
=
int
(
num_tokens_across_dp
.
max
().
item
())
synced_uniform_token_count
=
uniform_token_counts_across_dp
[
0
]
# If ranks disagree on the uniform token count, or its 0 (means None) set to None
if
synced_uniform_token_count
==
0
or
not
torch
.
all
(
uniform_token_counts_across_dp
==
synced_uniform_token_count
):
synced_uniform_token_count
=
None
# Dispatch for the final synced values, use num_reqs instead of synced_num_reqs
# so we don't perform request padding for PIECEWISE graphs
synced_desc
=
cudagraph_manager
.
dispatch
(
num_reqs
,
synced_num_tokens
,
synced_uniform_token_count
)
if
torch
.
all
(
num_tokens_across_dp
==
0
).
item
():
# All ranks have zero tokens to run.
return
0
,
None
,
0
# Synchronize cudagraph_runtime_mode across ranks by taking the minimum.
synced_cudagraph_mode
=
int
(
cudagraph_mode_across_dp
.
min
().
item
())
# Check if all ranks have valid cudagraph_size.
all_have_cudagraph
=
torch
.
all
(
cudagraph_size_across_dp
!=
-
1
).
item
()
# Update num_tokens_across_dp to reflect padded size.
num_tokens_across_dp
[:]
=
synced_desc
.
num_tokens
if
synced_cudagraph_mode
!=
0
and
all_have_cudagraph
:
# All ranks use cudagraph. Pad to max cudagraph_size.
max_cudagraph_size
=
int
(
cudagraph_size_across_dp
.
max
().
item
())
num_tokens_across_dp
[:]
=
max_cudagraph_size
return
max_cudagraph_size
,
num_tokens_across_dp
,
synced_cudagraph_mode
else
:
# Fall back to eager mode (no cudagraph).
# Either some rank doesn't have cudagraph size or mode is NONE.
synced_cudagraph_mode
=
0
num_tokens_across_dp
=
torch
.
clamp
(
num_tokens_across_dp
,
min
=
1
)
num_tokens_after_padding
=
int
(
num_tokens_across_dp
[
dp_rank
].
item
())
return
num_tokens_after_padding
,
num_tokens_across_dp
,
synced_cudagraph_mode
return
synced_desc
,
num_tokens_across_dp
vllm/v1/worker/gpu/input_batch.py
View file @
483463f7
...
...
@@ -37,6 +37,7 @@ class InputBatch:
# batch_idx -> req_id
req_ids
:
list
[
str
]
num_reqs
:
int
num_reqs_after_padding
:
int
# batch_idx -> req_state_idx
idx_mapping
:
torch
.
Tensor
...
...
@@ -123,6 +124,7 @@ class InputBatch:
return
cls
(
req_ids
=
req_ids
,
num_reqs
=
num_reqs
,
num_reqs_after_padding
=
num_reqs
,
idx_mapping
=
idx_mapping
,
idx_mapping_np
=
idx_mapping_np
,
expanded_idx_mapping
=
expanded_idx_mapping
,
...
...
@@ -330,7 +332,8 @@ def combine_sampled_and_draft_tokens(
cu_num_logits
:
torch
.
Tensor
,
num_logits
:
int
,
)
->
torch
.
Tensor
:
num_reqs
=
seq_lens
.
shape
[
0
]
# use idx_mapping.shape[0] for actual request count
num_reqs
=
idx_mapping
.
shape
[
0
]
num_speculative_steps
=
draft_tokens
.
shape
[
-
1
]
logits_indices
=
torch
.
empty
(
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
483463f7
...
...
@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader import get_model_loader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
SupportedTask
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.mem_utils
import
DeviceMemoryProfiler
,
format_gib
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
...
...
@@ -57,8 +56,12 @@ from vllm.v1.worker.gpu.attn_utils import (
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.buffer_utils
import
async_copy_to_gpu
from
vllm.v1.worker.gpu.cp_utils
import
prepare_dcp_local_seq_lens
from
vllm.v1.worker.gpu.cudagraph_utils
import
CudaGraphManager
from
vllm.v1.worker.gpu.dp_utils
import
get_cudagraph_and_dp_padding
from
vllm.v1.worker.gpu.cudagraph_utils
import
(
BatchExecutionDescriptor
,
ModelCudaGraphManager
,
get_uniform_token_count
,
)
from
vllm.v1.worker.gpu.dp_utils
import
sync_cudagraph_and_dp_padding
from
vllm.v1.worker.gpu.input_batch
import
(
InputBatch
,
InputBuffers
,
...
...
@@ -137,6 +140,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
is_first_pp_rank
=
True
self
.
is_last_pp_rank
=
True
# Data parallelism.
self
.
dp_size
=
self
.
parallel_config
.
data_parallel_size
self
.
dp_rank
=
self
.
parallel_config
.
data_parallel_rank
# Decode context parallelism.
self
.
dcp_size
=
self
.
parallel_config
.
decode_context_parallel_size
self
.
use_dcp
=
self
.
dcp_size
>
1
...
...
@@ -193,10 +200,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
prompt_logprobs_worker
=
PromptLogprobsWorker
(
self
.
max_num_reqs
)
# CUDA graphs.
self
.
cudagraph_manager
=
CudaGraphManager
(
self
.
decode_query_len
=
self
.
num_speculative_steps
+
1
self
.
cudagraph_manager
=
ModelCudaGraphManager
(
self
.
vllm_config
,
self
.
use_aux_hidden_state_outputs
,
self
.
device
,
self
.
compilation_config
.
cudagraph_mode
,
decode_query_len
=
self
.
decode_query_len
,
)
# Structured outputs worker.
self
.
structured_outputs_worker
=
StructuredOutputsWorker
(
...
...
@@ -331,17 +340,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
**
kwargs
,
)
->
tuple
[
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
]:
# Create a dummy scheduler output.
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
if
uniform_decode
:
#
Align tokens to uniform_decode_query_len for cudagraph
#
compatibility across DP ranks.
query_len
=
self
.
cudagraph_manager
.
uniform_decode_query_len
num_reqs
=
min
(
cdiv
(
num_tokens
,
query_len
),
self
.
max_num_reqs
)
num_tokens
=
num_reqs
*
query_len
num_
tokens_per_request
=
[
query_len
]
*
num_reqs
else
:
num_reqs
=
min
(
num_tokens
,
self
.
max_
num_reqs
)
num_tokens_per_request
=
[
num_tokens
//
num_reqs
]
*
num_reqs
num_tokens_per_request
[
-
1
]
+=
num_tokens
%
num_reqs
#
HACK(lucas): for now since the worker is shared between MRV1 and MRV2,
#
and for spec-decode with MTP we want to make sure the dummy runs use
# 1+num_speculative_tokens we use max here, this will likely be eventually
# changed in the worker: https://github.com/vllm-project/vllm/pull/35243
num_tokens
=
max
(
num_tokens
,
self
.
decode_
query_len
)
num_
reqs
=
num_tokens
//
self
.
decode_query_len
assert
num_tokens
%
self
.
decode_query_len
==
0
num_tokens_per_request
=
[
num_tokens
//
num_reqs
]
*
num_reqs
num_tokens_per_request
[
-
1
]
+
=
num_tokens
%
num_reqs
assert
sum
(
num_tokens_per_request
)
==
num_tokens
num_scheduled_tokens
=
{
f
"_dummy_req_
{
i
}
"
:
n
for
i
,
n
in
enumerate
(
num_tokens_per_request
)
...
...
@@ -498,13 +508,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with
self
.
maybe_setup_dummy_loras
(
self
.
lora_config
):
self
.
cudagraph_manager
.
capture
(
model
=
self
.
model
,
model_state
=
self
.
model_state
,
input_buffers
=
self
.
input_buffers
,
block_tables
=
self
.
block_tables
,
attn_groups
=
self
.
attn_groups
,
kv_cache_config
=
self
.
kv_cache_config
,
self
.
model
,
self
.
model_state
,
self
.
input_buffers
,
self
.
block_tables
,
self
.
attn_groups
,
self
.
kv_cache_config
,
has_lora
=
self
.
lora_config
is
not
None
,
use_aux_hidden_state_outputs
=
self
.
use_aux_hidden_state_outputs
,
)
if
self
.
speculator
is
not
None
:
self
.
speculator
.
capture_model
()
...
...
@@ -592,9 +603,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
def
prepare_inputs
(
self
,
scheduler_output
:
SchedulerOutput
,
num_tokens_after_padding
:
int
self
,
scheduler_output
:
SchedulerOutput
,
batch_desc
:
BatchExecutionDescriptor
)
->
InputBatch
:
num_tokens
=
scheduler_output
.
total_num_scheduled_tokens
num_tokens_after_padding
=
batch_desc
.
num_tokens
assert
num_tokens
>
0
num_tokens_per_req
=
scheduler_output
.
num_scheduled_tokens
num_reqs
=
len
(
num_tokens_per_req
)
...
...
@@ -644,6 +656,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
# Get query_start_loc.
# num_reqs_padded is None for PIECEWISE graphs (no request padding needed)
num_reqs_padded
=
batch_desc
.
num_reqs
or
num_reqs
query_start_loc_np
=
np
.
empty
(
self
.
max_num_reqs
+
1
,
dtype
=
np
.
int32
)
query_start_loc_np
[
0
]
=
0
np
.
cumsum
(
num_scheduled_tokens
,
out
=
query_start_loc_np
[
1
:
num_reqs
+
1
])
...
...
@@ -651,8 +665,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np
[
num_reqs
+
1
:]
=
num_tokens
async_copy_to_gpu
(
query_start_loc_np
,
out
=
self
.
input_buffers
.
query_start_loc
)
query_start_loc_np
=
query_start_loc_np
[:
num_reqs
+
1
]
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
query_start_loc_np
=
query_start_loc_np
[:
num_reqs
_padded
+
1
]
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
_padded
+
1
]
# Get prefill tokens if any.
if
self
.
req_states
.
any_prefills
(
idx_mapping_np
):
...
...
@@ -674,7 +688,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
input_buffers
.
positions
,
self
.
input_buffers
.
seq_lens
,
)
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
]
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
_padded
]
dcp_local_seq_lens
=
None
if
self
.
use_dcp
:
...
...
@@ -687,7 +701,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
dcp_rank
,
self
.
cp_interleave
,
)
dcp_local_seq_lens
=
self
.
input_buffers
.
dcp_local_seq_lens
[:
num_reqs
]
dcp_local_seq_lens
=
self
.
input_buffers
.
dcp_local_seq_lens
[:
num_reqs
_padded
]
# Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from.
...
...
@@ -706,6 +720,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return
InputBatch
(
req_ids
=
req_ids
,
num_reqs
=
num_reqs
,
num_reqs_after_padding
=
num_reqs_padded
,
idx_mapping
=
idx_mapping
,
idx_mapping_np
=
idx_mapping_np
,
expanded_idx_mapping
=
expanded_idx_mapping
,
...
...
@@ -729,13 +744,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
prepare_attn
(
self
,
input_batch
:
InputBatch
)
->
tuple
[
tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
]:
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables
=
self
.
block_tables
.
gather_block_tables
(
input_batch
.
idx_mapping
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
# Block tables: num_kv_cache_groups x [num_reqs_padded, max_num_blocks].
block_tables
=
self
.
block_tables
.
gather_block_tables
(
input_batch
.
idx_mapping
,
num_reqs_padded
=
input_batch
.
num_reqs_after_padding
,
)
# Slot mappings: [num_kv_cache_groups, num_tokens_padded].
# Kernel pads beyond num_tokens with PAD_SLOT_ID.
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
input_batch
.
idx_mapping
,
input_batch
.
query_start_loc
,
input_batch
.
positions
,
num_tokens_padded
=
input_batch
.
num_tokens_after_padding
,
)
return
block_tables
,
slot_mappings
...
...
@@ -851,27 +871,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
empty_output
=
self
.
kv_connector
.
no_forward
(
scheduler_output
)
return
empty_output
# Get local cudagraph mode and size.
local_cudagraph_mode
,
local_cudagraph_size
=
(
self
.
cudagraph_manager
.
get_cudagraph_runtime_mode
(
num_reqs
=
len
(
scheduler_output
.
num_scheduled_tokens
),
num_tokens
=
scheduler_output
.
total_num_scheduled_tokens
,
max_query_len
=
max
(
scheduler_output
.
num_scheduled_tokens
.
values
()),
)
# Get batch descriptor and sync across DP ranks.
num_reqs
=
len
(
scheduler_output
.
num_scheduled_tokens
)
num_toks
=
scheduler_output
.
total_num_scheduled_tokens
max_query_len
=
max
(
scheduler_output
.
num_scheduled_tokens
.
values
())
uniform_tok_count
=
get_uniform_token_count
(
num_reqs
,
num_toks
,
max_query_len
)
batch_desc
=
self
.
cudagraph_manager
.
dispatch
(
num_reqs
,
num_toks
,
uniform_tok_count
)
num_tokens_across_dp
=
None
# DP sync: num_tokens + cudagraph_size + cudagraph_mode
num_tokens_after_padding
,
num_tokens_across_dp
,
synced_cudagraph_mode
=
(
get_cudagraph_and_dp_padding
(
scheduler_output
.
total_num_scheduled_tokens
,
local_cudagraph_size
,
local_cudagraph_mode
.
value
,
self
.
parallel_config
.
data_parallel_size
,
self
.
parallel_config
.
data_parallel_rank
,
if
self
.
dp_size
>
1
:
batch_desc
,
num_tokens_across_dp
=
sync_cudagraph_and_dp_padding
(
self
.
cudagraph_manager
,
batch_desc
,
num_toks
,
num_reqs
,
uniform_tok_count
,
self
.
dp_size
,
self
.
dp_rank
,
)
)
cudagraph_runtime_mode
=
CUDAGraphMode
(
synced_cudagraph_mode
)
if
num_tokens_after_padding
==
0
:
if
batch_desc
.
num_tokens
==
0
:
# All DP ranks have zero tokens to run.
empty_output
=
self
.
kv_connector
.
no_forward
(
scheduler_output
)
return
empty_output
...
...
@@ -879,9 +901,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
not
dummy_run
:
# Common case.
# Prepare all the inputs and copy to the input buffers.
input_batch
=
self
.
prepare_inputs
(
scheduler_output
,
num_tokens_after_padding
)
input_batch
=
self
.
prepare_inputs
(
scheduler_output
,
batch_desc
)
block_tables
,
slot_mappings
=
self
.
prepare_attn
(
input_batch
)
if
self
.
lora_config
:
...
...
@@ -894,9 +914,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
_set_active_loras
(
*
lora_inputs
)
else
:
# No actual tokens to run. A dummy run for DP or memory profiling.
num_reqs
=
min
(
num_tokens_after_padding
,
self
.
max_num_reqs
)
input_batch
=
InputBatch
.
make_dummy
(
num_reqs
,
num_tokens_after_padding
,
self
.
input_buffers
batch_desc
.
num_reqs
or
num_reqs
,
batch_desc
.
num_tokens
,
self
.
input_buffers
,
)
if
not
skip_attn_for_dummy_run
:
block_tables
,
slot_mappings
=
self
.
prepare_dummy_attn
(
input_batch
)
...
...
@@ -948,14 +969,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
model_inputs
[
"intermediate_tensors"
]
=
intermediate_tensors
# Run model.
if
cudagraph_runtime
_mode
==
CUDAGraphMode
.
FULL
:
if
batch_desc
.
cg
_mode
==
CUDAGraphMode
.
FULL
:
# Use explicit cudagraph replay for FULL mode.
# NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers.
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
model_output
=
self
.
cudagraph_manager
.
run_fullgraph
(
input_batch
.
num_tokens_after_padding
)
model_output
=
self
.
cudagraph_manager
.
run_fullgraph
(
batch_desc
)
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
aux_hidden_states
=
model_output
else
:
...
...
@@ -972,7 +991,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
input_batch
.
num_tokens_after_padding
,
cudagraph_runtime_mode
=
cudagraph_runtime
_mode
,
cudagraph_runtime_mode
=
batch_desc
.
cg
_mode
,
num_tokens_across_dp
=
num_tokens_across_dp
,
batch_descriptor
=
batch_descriptor
,
slot_mapping
=
slot_mappings_by_layer
,
...
...
vllm/v1/worker/gpu/model_states/default.py
View file @
483463f7
...
...
@@ -142,12 +142,15 @@ class DefaultModelState(ModelState):
attn_groups
:
list
[
list
[
AttentionGroup
]],
kv_cache_config
:
KVCacheConfig
,
)
->
dict
[
str
,
Any
]:
# Use padded sizes - padding is handled by model_runner.prepare_attn.
num_reqs
=
input_batch
.
num_reqs_after_padding
num_tokens
=
input_batch
.
num_tokens_after_padding
query_start_loc_cpu
=
torch
.
from_numpy
(
input_batch
.
query_start_loc_np
)
max_query_len
=
input_batch
.
num_scheduled_tokens
.
max
().
item
()
attn_metadata
=
build_attn_metadata
(
attn_groups
=
attn_groups
,
num_reqs
=
input_batch
.
num_reqs
,
num_tokens
=
input_batch
.
num_tokens
,
num_reqs
=
num_reqs
,
num_tokens
=
num_tokens
,
query_start_loc_gpu
=
input_batch
.
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
max_query_len
=
max_query_len
,
...
...
vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
View file @
483463f7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
typing
import
Any
import
torch
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.model_executor.offloader.base
import
get_offloader
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.cudagraph_utils
import
(
capture_graphs
,
get_c
uda
g
raph
_sizes
,
BatchExecutionDescriptor
,
C
uda
G
raph
Manager
,
prepare_inputs_to_capture
,
)
from
vllm.v1.worker.gpu.dp_utils
import
make_num_tokens_across_dp
from
vllm.v1.worker.gpu.input_batch
import
InputBuffers
from
vllm.v1.worker.gpu.model_states.interface
import
ModelState
from
vllm.v1.worker.utils
import
AttentionGroup
class
EagleCudaGraphManager
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
self
.
vllm_config
=
vllm_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
device
=
device
class
EagleCudaGraphManager
(
CudaGraphManager
):
"""CudaGraphManager for Eagle speculative decoding (FULL mode only)."""
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
self
.
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
self
.
compilation_config
=
vllm_config
.
compilation_config
assert
self
.
compilation_config
is
not
None
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
self
.
cudagraph_mode
=
self
.
compilation_config
.
cudagraph_mode
.
decode_mode
()
# only need to capture uniform decode cudagraph sizes (the 2nd return value)
_
,
self
.
cudagraph_sizes
=
get_cudagraph_sizes
(
self
.
compilation_config
.
cudagraph_capture_sizes
,
self
.
max_num_reqs
,
self
.
max_num_tokens
,
self
.
cudagraph_mode
,
uniform_decode_query_len
=
1
,
uniform_decode_cudagraph
=
True
,
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
cudagraph_mode
:
CUDAGraphMode
,
draft_tokens
:
torch
.
Tensor
,
):
assert
not
cudagraph_mode
.
has_mode
(
CUDAGraphMode
.
PIECEWISE
),
(
"EagleCudaGraphManager does not support PIECEWISE mode yet"
)
self
.
graphs
:
dict
[
int
,
torch
.
cuda
.
CUDAGraph
]
=
{}
self
.
pool
=
None
if
self
.
cudagraph_mode
!=
CUDAGraphMode
.
NONE
:
# Eagle always uses uniform decode with query_len=1
super
().
__init__
(
vllm_config
,
device
,
cudagraph_mode
,
decode_query_len
=
1
)
self
.
draft_tokens
=
draft_tokens
# Use a dedicated pool for Eagle to avoid memory overlap with the main
# model's cudagraph. The base class uses a shared global pool, but Eagle's
# internal allocations (e.g., gumbel_sample temporaries) can conflict with
# the main model's allocations when sharing the same pool.
if
cudagraph_mode
:
self
.
pool
=
torch
.
cuda
.
graph_pool_handle
()
def
get_cudagraph_size
(
self
,
num_tokens
:
int
)
->
int
|
None
:
return
self
.
cudagraph_sizes
.
get
(
num_tokens
)
def
get_cudagraph_runtime_mode
(
self
,
num_tokens
:
int
)
->
tuple
[
CUDAGraphMode
,
int
|
None
]:
cudagraph_size
=
self
.
get_cudagraph_size
(
num_tokens
)
if
cudagraph_size
is
None
:
cudagraph_mode
=
CUDAGraphMode
.
NONE
else
:
cudagraph_mode
=
self
.
cudagraph_mode
if
(
cudagraph_mode
==
CUDAGraphMode
.
FULL
and
cudagraph_size
is
not
None
and
cudagraph_size
not
in
self
.
graphs
):
# If graph wasn't captured yet, fall back to eager.
# This might happen when the dummy run is called before capture.
cudagraph_mode
=
CUDAGraphMode
.
NONE
cudagraph_size
=
None
return
cudagraph_mode
,
cudagraph_size
def
capture_graph
(
def
capture
(
self
,
num_tokens
:
int
,
capture_cg_mode
:
CUDAGraphMode
,
generate_fn
:
Callable
,
model_state
:
ModelState
,
input_buffers
:
InputBuffers
,
block_tables
:
BlockTables
,
attn_groups
:
list
[
list
[
AttentionGroup
]],
kv_cache_config
:
KVCacheConfig
,
progress_bar_desc
:
str
=
"Capturing CUDA graphs"
,
)
->
None
:
assert
capture_cg_mode
in
[
CUDAGraphMode
.
PIECEWISE
,
CUDAGraphMode
.
FULL
],
(
f
"Invalid capture_cudagraph_mode for capture:
{
capture_cg_mode
}
"
)
if
capture_cg_mode
==
CUDAGraphMode
.
PIECEWISE
:
capture_fn
=
self
.
_capture_piecewise_graph
else
:
capture_fn
=
self
.
_capture_full_graph
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
attn_metadata
,
slot_mappings
=
prepare_inputs_to_capture
(
num_reqs
,
num_tokens
,
model_state
,
input_buffers
,
block_tables
,
attn_groups
,
kv_cache_config
,
)
num_tokens_across_dp
=
make_num_tokens_across_dp
(
self
.
dp_size
,
num_tokens
)
# Warm up.
generate_fn
(
num_reqs
,
num_tokens
,
attn_metadata
,
slot_mappings
,
num_tokens_across_dp
,
CUDAGraphMode
.
NONE
,
)
# Capture the graph.
capture_fn
(
num_reqs
=
num_reqs
,
num_tokens
=
num_tokens
,
generate_fn
=
generate_fn
,
attn_metadata
=
attn_metadata
,
slot_mappings
=
slot_mappings
,
num_tokens_across_dp
=
num_tokens_across_dp
,
)
def
_capture_full_graph
(
self
,
num_reqs
:
int
,
num_tokens
:
int
,
generate_fn
:
Callable
,
attn_metadata
:
dict
[
str
,
Any
],
slot_mappings
:
dict
[
str
,
torch
.
Tensor
],
num_tokens_across_dp
:
torch
.
Tensor
,
)
->
None
:
assert
num_tokens
not
in
self
.
graphs
graph
=
torch
.
cuda
.
CUDAGraph
()
# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
get_offloader
().
sync_prev_onload
()
"""Capture CUDA graphs for Eagle speculative decoding (FULL mode only)."""
def
create_forward_fn
(
desc
:
BatchExecutionDescriptor
,
)
->
Callable
[[
CUDAGraphMode
],
None
]:
num_tokens
=
desc
.
num_tokens
num_reqs
=
desc
.
num_reqs
or
min
(
num_tokens
,
self
.
max_num_reqs
)
num_tokens_across_dp
=
(
torch
.
full
((
self
.
dp_size
,),
num_tokens
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
if
self
.
dp_size
>
1
else
None
)
attn_metadata
,
slot_mappings
=
prepare_inputs_to_capture
(
num_reqs
,
num_tokens
,
model_state
,
input_buffers
,
block_tables
,
attn_groups
,
kv_cache_config
,
)
with
torch
.
cuda
.
graph
(
graph
,
self
.
pool
):
generate_fn
(
return
lambda
cg_mode
:
generate_fn
(
num_reqs
,
num_tokens
,
attn_metadata
,
slot_mappings
,
num_tokens_across_dp
,
CUDAGraphMode
.
NONE
,
cg_mode
,
)
# Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass.
get_offloader
().
join_after_forward
()
self
.
graphs
[
num_tokens
]
=
graph
def
_capture_piecewise_graph
(
self
,
num_reqs
:
int
,
num_tokens
:
int
,
generate_fn
:
Callable
,
attn_metadata
:
dict
[
str
,
Any
],
slot_mappings
:
dict
[
str
,
torch
.
Tensor
],
num_tokens_across_dp
:
torch
.
Tensor
,
)
->
None
:
generate_fn
(
num_reqs
,
num_tokens
,
attn_metadata
,
slot_mappings
,
num_tokens_across_dp
,
CUDAGraphMode
.
PIECEWISE
,
)
@
torch
.
inference_mode
()
def
capture
(
self
,
generate_fn
:
Callable
,
model_state
:
ModelState
,
input_buffers
:
InputBuffers
,
block_tables
:
BlockTables
,
attn_groups
:
list
[
list
[
AttentionGroup
]],
kv_cache_config
:
KVCacheConfig
,
)
->
None
:
if
self
.
cudagraph_mode
==
CUDAGraphMode
.
NONE
:
return
capture_graphs
(
self
.
cudagraph_sizes
,
self
.
device
,
self
.
capture_graph
,
capture_cudagraph_mode
=
self
.
cudagraph_mode
,
desc
=
f
"Capturing eagle CUDA graphs (
{
self
.
cudagraph_mode
.
name
}
)"
,
generate_fn
=
generate_fn
,
model_state
=
model_state
,
input_buffers
=
input_buffers
,
block_tables
=
block_tables
,
attn_groups
=
attn_groups
,
kv_cache_config
=
kv_cache_config
,
)
super
().
capture
(
create_forward_fn
,
progress_bar_desc
)
def
run_fullgraph
(
self
,
num_tokens
:
int
)
->
None
:
assert
num_tokens
in
self
.
graphs
# Sync offloader before replay - needed when transitioning from
# eager/piecewise to full cudagraph (e.g., prefill → decode).
# The previous eager iteration's start_prefetch may have queued
# H2D copies on copy_stream that the graph's captured events
# cannot see. Without this, replay could overwrite static buffers
# while those copies are still in flight.
get_offloader
().
sync_prev_onload
()
self
.
graphs
[
num_tokens
].
replay
()
def
run_fullgraph
(
self
,
desc
:
BatchExecutionDescriptor
)
->
torch
.
Tensor
:
"""Replay a captured FULL cudagraph and return draft tokens."""
super
().
run_fullgraph
(
desc
)
return
self
.
draft_tokens
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
View file @
483463f7
...
...
@@ -16,7 +16,7 @@ from vllm.v1.worker.gpu.attn_utils import (
build_slot_mappings_by_layer
,
)
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.dp_utils
import
get
_cudagraph_and_dp_padding
from
vllm.v1.worker.gpu.dp_utils
import
sync
_cudagraph_and_dp_padding
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
,
InputBuffers
from
vllm.v1.worker.gpu.model_states.interface
import
ModelState
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
...
...
@@ -75,7 +75,16 @@ class EagleSpeculator:
device
=
device
,
)
self
.
cudagraph_manager
=
EagleCudaGraphManager
(
vllm_config
,
device
)
# currently we don't support PIECEWISE for Eagle.
cudagraph_mode
=
vllm_config
.
compilation_config
.
cudagraph_mode
if
cudagraph_mode
.
decode_mode
()
==
CUDAGraphMode
.
FULL
:
cudagraph_mode
=
CUDAGraphMode
.
FULL_DECODE_ONLY
else
:
cudagraph_mode
=
CUDAGraphMode
.
NONE
self
.
cudagraph_manager
=
EagleCudaGraphManager
(
vllm_config
,
device
,
cudagraph_mode
,
self
.
draft_tokens
)
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
self
.
model
=
load_eagle_model
(
target_model
,
self
.
vllm_config
)
...
...
@@ -171,7 +180,7 @@ class EagleSpeculator:
)
if
attn_metadata
is
not
None
:
self
.
block_tables
.
compute_slot_mappings
(
idx_mapping
,
query_start_loc
,
pos
idx_mapping
,
query_start_loc
,
pos
,
num_tokens_padded
)
def
capture_model
(
self
)
->
None
:
...
...
@@ -185,6 +194,7 @@ class EagleSpeculator:
self
.
block_tables
,
self
.
attn_groups
,
self
.
kv_cache_config
,
progress_bar_desc
=
"Capturing eagle CUDA graphs"
,
)
@
torch
.
inference_mode
()
...
...
@@ -251,6 +261,7 @@ class EagleSpeculator:
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
num_reqs
=
input_batch
.
num_reqs
num_reqs_padded
=
input_batch
.
num_reqs_after_padding
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
...
...
@@ -292,48 +303,52 @@ class EagleSpeculator:
self
.
max_num_reqs
,
)
if
not
(
dummy_run
and
skip_attn_for_dummy_run
):
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
idx_mapping
,
query_start_loc
,
pos
)
# Get batch descriptor and sync across DP ranks.
# Eagle uses FULL-only mode, dispatch with uniform_token_count=1 for decode
cudagraph_mode
,
cudagraph_size
=
(
self
.
cudagraph_manager
.
get_cudagraph_runtime_mode
(
num_reqs
)
)
num_tokens_padded
,
num_tokens_across_dp
,
synced_cudagraph_mode
=
(
get_cudagraph_and_dp_padding
(
batch_desc
=
self
.
cudagraph_manager
.
dispatch
(
num_reqs
,
num_reqs
,
1
)
num_tokens_across_dp
=
None
if
self
.
dp_size
>
1
:
batch_desc
,
num_tokens_across_dp
=
sync_cudagraph_and_dp_padding
(
self
.
cudagraph_manager
,
batch_desc
,
num_reqs
,
cudagraph_size
,
cudagraph_mode
.
value
,
num_reqs
,
1
,
# uniform_token_count
self
.
dp_size
,
self
.
dp_rank
,
)
)
cudagraph_mode
=
CUDAGraphMode
(
synced_cudagraph_mode
)
if
cudagraph_mode
==
CUDAGraphMode
.
FULL
:
# Run full CUDA graph.
self
.
cudagraph_manager
.
run_fullgraph
(
num_tokens_padded
)
return
self
.
draft_tokens
[:
num_reqs
]
if
not
(
dummy_run
and
skip_attn_for_dummy_run
):
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
idx_mapping
,
query_start_loc
,
pos
,
batch_desc
.
num_tokens
)
if
batch_desc
.
cg_mode
==
CUDAGraphMode
.
FULL
:
return
self
.
cudagraph_manager
.
run_fullgraph
(
batch_desc
)[:
num_reqs
]
# Run eager or piecewise CUDA graph.
attn_metadata_updated
=
None
slot_mappings_updated
=
None
if
not
(
dummy_run
and
skip_attn_for_dummy_run
):
query_start_loc_cpu
=
torch
.
arange
(
num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
"cpu"
num_reqs
_padded
+
1
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
block_tables
=
[
x
[:
num_reqs
]
for
x
in
self
.
block_tables
.
input_block_tables
]
block_tables
=
[
x
[:
num_reqs_padded
]
for
x
in
self
.
block_tables
.
input_block_tables
]
# FIXME(woosuk): This is UNSAFE!!
attn_metadata_updated
=
build_attn_metadata
(
attn_groups
=
self
.
attn_groups
,
num_reqs
=
num_reqs
,
num_tokens
=
num_reqs
,
num_reqs
=
num_reqs
_padded
,
num_tokens
=
num_reqs
_padded
,
query_start_loc_gpu
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
max_query_len
=
1
,
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
],
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
_padded
],
max_seq_len
=
self
.
max_model_len
,
block_tables
=
block_tables
,
slot_mappings
=
slot_mappings
,
...
...
@@ -345,11 +360,11 @@ class EagleSpeculator:
self
.
generate_draft
(
num_reqs
,
num_tokens
_padded
,
batch_desc
.
num_tokens
,
attn_metadata_updated
,
slot_mappings_updated
,
num_tokens_across_dp
=
num_tokens_across_dp
,
cudagraph_runtime_mode
=
cudagraph
_mode
,
cudagraph_runtime_mode
=
batch_desc
.
cg
_mode
,
)
return
self
.
draft_tokens
[:
num_reqs
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment