Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4778b426
Unverified
Commit
4778b426
authored
Sep 26, 2025
by
Sage Moore
Committed by
GitHub
Sep 26, 2025
Browse files
Reduce the Cuda Graph memory footprint when running with DBO (#25779)
Signed-off-by:
Sage Moore
<
sage@neuralmagic.com
>
parent
c70ac4b8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
28 deletions
+32
-28
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+20
-28
vllm/v1/worker/gpu_ubatch_wrapper.py
vllm/v1/worker/gpu_ubatch_wrapper.py
+12
-0
No files found.
vllm/v1/worker/gpu_model_runner.py
View file @
4778b426
...
...
@@ -3477,8 +3477,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# We skip EPLB here since we don't want to record dummy metrics
for
num_tokens
in
compilation_cases
:
# We currently only capture ubatched graphs when it's a FULL
# cudagraph and for uniform decode batches.
capture_ubatched_graph
=
self
.
parallel_config
.
enable_dbo
\
# cudagraph, a uniform decode batch, and the number of tokens
# is above the threshold. Otherwise we just capture a non-ubatched
# version of the graph
allow_microbatching
=
self
.
parallel_config
.
enable_dbo
\
and
cudagraph_runtime_mode
==
CUDAGraphMode
.
FULL
\
and
uniform_decode
\
and
check_ubatch_thresholds
(
...
...
@@ -3487,37 +3489,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
uniform_decode
=
uniform_decode
,
)
# Currently we capture both microbatched and non-microbatched
# graphs when capture_ubatched_graph is True, this is because
# occasionally we will be forced out of microbatching due to other
# DP ranks not microbatching (usually caused by an empty second
# microbatch; once we resolve this, we can remove the
# non-microbatched graph capture).
allow_microbatching_options
=
[
True
,
False
]
if
\
capture_ubatched_graph
else
[
False
]
for
allow_microbatching
in
allow_microbatching_options
:
for
_
in
range
(
self
.
compilation_config
.
cudagraph_num_of_warmups
):
# Use CUDAGraphMode.NONE (default) for warmup.
# But be careful, warm up with `NONE` is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
force_attention
=
(
cudagraph_runtime_mode
==
CUDAGraphMode
.
FULL
)
self
.
_dummy_run
(
num_tokens
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
,
force_attention
=
force_attention
,
uniform_decode
=
uniform_decode
,
allow_microbatching
=
allow_microbatching
,
skip_eplb
=
True
,
remove_lora
=
False
)
for
_
in
range
(
self
.
compilation_config
.
cudagraph_num_of_warmups
):
# Use CUDAGraphMode.NONE (default) for warmup.
# But be careful, warm up with `NONE` is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
force_attention
=
(
cudagraph_runtime_mode
==
CUDAGraphMode
.
FULL
)
self
.
_dummy_run
(
num_tokens
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
,
force_attention
=
force_attention
,
uniform_decode
=
uniform_decode
,
allow_microbatching
=
allow_microbatching
,
skip_eplb
=
True
,
remove_lora
=
False
)
self
.
_dummy_run
(
num_tokens
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
uniform_decode
=
uniform_decode
,
allow_microbatching
=
allow_microbatching
,
skip_eplb
=
True
,
remove_lora
=
False
)
self
.
maybe_remove_all_loras
(
self
.
lora_config
)
def
initialize_attn_backend
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
...
...
vllm/v1/worker/gpu_ubatch_wrapper.py
View file @
4778b426
...
...
@@ -330,6 +330,18 @@ class UBatchWrapper:
# If there's no ubatching, just run the runnable object
if
ubatch_slices
is
None
:
# This is to account for the case where ubatching was aborted.
# When we capture full graphs we only capture one graph per shape,
# meaning that if we have a ubatched cudagraph for the current
# num_tokens, we don't have a non-ubatched one. Without this
# check, the cudagraph wrapper will try to capture a cudagraph
# for this shape during a normal run.
if
cudagraph_runtime_mode
is
CUDAGraphMode
.
FULL
:
assert
batch_descriptor
is
not
None
if
batch_descriptor
.
num_tokens
in
self
.
cudagraphs
:
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
if
cudagraph_runtime_mode
in
(
CUDAGraphMode
.
NONE
,
CUDAGraphMode
.
PIECEWISE
):
return
self
.
runnable
(
*
args
,
**
kwargs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment