Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3a414595
Unverified
Commit
3a414595
authored
Jan 23, 2026
by
Lucas Wilkinson
Committed by
GitHub
Jan 23, 2026
Browse files
[cudagraphs] Refactor cudagraph capture loop (#32946)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
8518b304
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
117 additions
and
59 deletions
+117
-59
tests/v1/cudagraph/test_cudagraph_dispatch.py
tests/v1/cudagraph/test_cudagraph_dispatch.py
+62
-0
vllm/v1/cudagraph_dispatcher.py
vllm/v1/cudagraph_dispatcher.py
+23
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+32
-59
No files found.
tests/v1/cudagraph/test_cudagraph_dispatch.py
View file @
3a414595
...
...
@@ -173,6 +173,68 @@ class TestCudagraphDispatcher:
else
:
assert
rt_mode
==
CUDAGraphMode
.
NONE
@pytest.mark.parametrize(
    "cudagraph_mode_str,compilation_mode,expected_modes",
    [
        # FULL: a single FULL group and no PIECEWISE group
        ("FULL", CompilationMode.NONE, [CUDAGraphMode.FULL]),
        # PIECEWISE: a single PIECEWISE group
        ("PIECEWISE", CompilationMode.VLLM_COMPILE, [CUDAGraphMode.PIECEWISE]),
        # FULL_DECODE_ONLY: a single FULL group (uniform decode batches)
        ("FULL_DECODE_ONLY", CompilationMode.NONE, [CUDAGraphMode.FULL]),
        # NONE: nothing to capture at all
        ("NONE", CompilationMode.NONE, []),
    ],
)
def test_get_capture_descs(
    self, cudagraph_mode_str, compilation_mode, expected_modes
):
    """Check the grouping and ordering produced by get_capture_descs.

    For each parametrized cudagraph mode the dispatcher is initialized with
    capture sizes [1, 4, 8, 16] and the returned (mode, descriptors) groups
    are checked for: the expected sequence of runtime modes, non-empty
    groups, descending num_tokens order, and a single shared `uniform`
    value within every group.
    """
    compilation_config = CompilationConfig(
        cudagraph_mode=cudagraph_mode_str,
        mode=compilation_mode,
        cudagraph_capture_sizes=[1, 4, 8, 16],
    )
    vllm_config = _create_vllm_config(compilation_config, max_num_seqs=16)

    dispatcher = CudagraphDispatcher(vllm_config)
    dispatcher.initialize_cudagraph_keys(
        cudagraph_mode=compilation_config.cudagraph_mode,
        uniform_decode_query_len=1,
    )

    groups = dispatcher.get_capture_descs()

    # The runtime modes must match the expectation, in order.
    assert [runtime_mode for runtime_mode, _ in groups] == expected_modes

    for mode, descs in groups:
        assert len(descs) > 0, "Each group should have at least one descriptor"

        # Within a group the largest batch must come first.
        token_counts = [desc.num_tokens for desc in descs]
        assert token_counts == sorted(token_counts, reverse=True), (
            f"Descriptors for {mode} should be sorted largest-first"
        )

        # A group never mixes uniform and non-uniform descriptors.
        uniform_flags = {desc.uniform for desc in descs}
        assert len(uniform_flags) == 1, (
            "All descriptors in a group should have the same uniform value"
        )
def test_get_capture_descs_empty_when_not_initialized(self):
    """get_capture_descs must yield [] if cudagraph keys were never set up."""
    compilation_config = CompilationConfig(
        cudagraph_mode="FULL",
        mode=CompilationMode.NONE,
        cudagraph_capture_sizes=[1, 8],
    )
    dispatcher = CudagraphDispatcher(
        _create_vllm_config(compilation_config, max_num_seqs=8)
    )

    # initialize_cudagraph_keys() is deliberately not called here.
    assert dispatcher.get_capture_descs() == []
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"Skip if not cuda"
)
class
TestCUDAGraphWrapper
:
...
...
vllm/v1/cudagraph_dispatcher.py
View file @
3a414595
...
...
@@ -231,3 +231,26 @@ class CudagraphDispatcher:
# finally, just return no cudagraphs and a trivial batch descriptor
return
CUDAGraphMode
.
NONE
,
BatchDescriptor
(
num_tokens
)
def get_capture_descs(self) -> list[tuple[CUDAGraphMode, list[BatchDescriptor]]]:
    """
    Build the list of batch descriptors to capture, grouped by runtime mode.

    Returns:
        A list of (runtime_mode, batch_descriptors) tuples. The PIECEWISE
        group, if any, precedes the FULL group; modes without registered
        keys are omitted. Each group's descriptors are ordered by
        num_tokens, largest first (for memory efficiency). An empty list
        is returned when the keys have not been initialized or the
        configured cudagraph mode is NONE.
    """
    # Guard clauses: nothing to capture in either of these states.
    if not self.keys_initialized:
        return []
    if self.cudagraph_mode == CUDAGraphMode.NONE:
        return []

    groups = []
    # Fixed emission order: PIECEWISE first, then FULL.
    for mode in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL):
        # Largest batches first within each mode.
        ordered = sorted(
            self.cudagraph_keys[mode],
            key=lambda desc: desc.num_tokens,
            reverse=True,
        )
        if ordered:
            groups.append((mode, ordered))
    return groups
vllm/v1/worker/gpu_model_runner.py
View file @
3a414595
...
...
@@ -10,7 +10,6 @@ from collections.abc import Iterator, Sequence
from
contextlib
import
contextmanager
from
copy
import
copy
,
deepcopy
from
functools
import
reduce
from
itertools
import
product
from
typing
import
TYPE_CHECKING
,
Any
,
NamedTuple
,
TypeAlias
,
cast
import
numpy
as
np
...
...
@@ -4839,50 +4838,14 @@ class GPUModelRunner(
set_cudagraph_capturing_enabled
(
True
)
with
freeze_gc
(),
graph_capture
(
device
=
self
.
device
):
start_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
cudagraph_mode
=
self
.
compilation_config
.
cudagraph_mode
assert
cudagraph_mode
is
not
None
if
self
.
lora_config
:
if
self
.
compilation_config
.
cudagraph_specialize_lora
:
lora_cases
=
[
True
,
False
]
else
:
lora_cases
=
[
True
]
else
:
lora_cases
=
[
False
]
if
cudagraph_mode
.
mixed_mode
()
!=
CUDAGraphMode
.
NONE
:
cudagraph_runtime_mode
=
cudagraph_mode
.
mixed_mode
()
# make sure we capture the largest batch size first
compilation_cases
=
list
(
product
(
reversed
(
self
.
cudagraph_batch_sizes
),
lora_cases
)
)
self
.
_capture_cudagraphs
(
compilation_cases
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
uniform_decode
=
False
,
)
# Capture full cudagraph for uniform decode batches if we
# don't already have full mixed prefill-decode cudagraphs.
if
(
cudagraph_mode
.
decode_mode
()
==
CUDAGraphMode
.
FULL
and
cudagraph_mode
.
separate_routine
()
):
max_num_tokens
=
(
self
.
scheduler_config
.
max_num_seqs
*
self
.
uniform_decode_query_len
)
decode_cudagraph_batch_sizes
=
[
x
for
x
in
self
.
cudagraph_batch_sizes
if
max_num_tokens
>=
x
>=
self
.
uniform_decode_query_len
]
compilation_cases_decode
=
list
(
product
(
reversed
(
decode_cudagraph_batch_sizes
),
lora_cases
)
)
for
(
runtime_mode
,
batch_descs
,
)
in
self
.
cudagraph_dispatcher
.
get_capture_descs
():
self
.
_capture_cudagraphs
(
compilation_cases
=
compilation_cases_decode
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
FULL
,
uniform_decode
=
True
,
batch_descriptors
=
batch_descs
,
cudagraph_runtime_mode
=
runtime_mode
,
)
torch
.
cuda
.
synchronize
()
...
...
@@ -4913,19 +4876,32 @@ class GPUModelRunner(
def
_capture_cudagraphs
(
self
,
compilation_cases
:
list
[
tuple
[
int
,
bool
]
],
batch_descriptors
:
list
[
BatchDescriptor
],
cudagraph_runtime_mode
:
CUDAGraphMode
,
uniform_decode
:
bool
,
):
assert
(
cudagraph_runtime_mode
!=
CUDAGraphMode
.
NONE
and
cudagraph_runtime_mode
.
valid_runtime_modes
()
),
f
"Invalid cudagraph runtime mode:
{
cudagraph_runtime_mode
}
"
if
not
batch_descriptors
:
return
uniform_decode
=
batch_descriptors
[
0
].
uniform
force_attention
=
cudagraph_runtime_mode
==
CUDAGraphMode
.
FULL
dummy_run
=
functools
.
partial
(
self
.
_dummy_run
,
uniform_decode
=
uniform_decode
,
skip_eplb
=
True
,
remove_lora
=
False
,
force_attention
=
force_attention
,
)
# Only rank 0 should print progress bar during capture
if
is_global_first_rank
():
compilation_case
s
=
tqdm
(
compilation_case
s
,
batch_descriptor
s
=
tqdm
(
batch_descriptor
s
,
disable
=
not
self
.
load_config
.
use_tqdm_on_load
,
desc
=
"Capturing CUDA graphs ({}, {})"
.
format
(
"decode"
if
uniform_decode
else
"mixed prefill-decode"
,
...
...
@@ -4934,7 +4910,10 @@ class GPUModelRunner(
)
# We skip EPLB here since we don't want to record dummy metrics
for
num_tokens
,
activate_lora
in
compilation_cases
:
for
batch_desc
in
batch_descriptors
:
num_tokens
=
batch_desc
.
num_tokens
activate_lora
=
batch_desc
.
has_lora
# We currently only capture ubatched graphs when its a FULL
# cudagraph, a uniform decode batch, and the number of tokens
# is above the threshold. Otherwise we just capture a non-ubatched
...
...
@@ -4952,28 +4931,22 @@ class GPUModelRunner(
for
_
in
range
(
self
.
compilation_config
.
cudagraph_num_of_warmups
):
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE`is orthogonal to
# But be careful, warm up with `NONE`
is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
force_attention
=
cudagraph_runtime_mode
==
CUDAGraphMode
.
FULL
self
.
_dummy_run
(
dummy_run
(
num_tokens
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
,
force_attention
=
force_attention
,
uniform_decode
=
uniform_decode
,
allow_microbatching
=
allow_microbatching
,
skip_eplb
=
True
,
remove_lora
=
False
,
activate_lora
=
activate_lora
,
)
self
.
_dummy_run
(
# Capture run
dummy_run
(
num_tokens
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
uniform_decode
=
uniform_decode
,
allow_microbatching
=
allow_microbatching
,
skip_eplb
=
True
,
remove_lora
=
False
,
activate_lora
=
activate_lora
,
is_graph_capturing
=
True
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment