Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ed333497
Unverified
Commit
ed333497
authored
Jun 18, 2025
by
Richard Zou
Committed by
GitHub
Jun 19, 2025
Browse files
[BugFix] Fix use_cudagraph=False (#19612)
Signed-off-by:
Richard Zou
<
zou3519@gmail.com
>
parent
d49adea1
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
29 deletions
+35
-29
tests/compile/test_config.py
tests/compile/test_config.py
+21
-24
vllm/compilation/counter.py
vllm/compilation/counter.py
+3
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+11
-5
No files found.
tests/compile/test_config.py
View file @
ed333497
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
import
vllm
from
vllm.compilation.counter
import
compilation_counter
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
set_current_vllm_config
)
from
.piecewise.test_simple
import
SillyModel
from
vllm.config
import
VllmConfig
def
test_use_cudagraphs_dynamic
(
monkeypatch
):
...
...
@@ -22,23 +18,24 @@ def test_use_cudagraphs_dynamic(monkeypatch):
@
pytest
.
mark
.
parametrize
(
"enabled"
,
[
True
,
False
])
def
test_use_cudagraphs
(
enabled
):
def
test_use_cudagraphs
(
vllm_runner
,
monkeypatch
,
enabled
):
assert
vllm
.
envs
.
VLLM_USE_V1
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
use_cudagraph
=
enabled
,
cudagraph_capture_sizes
=
[
100
],
))
with
set_current_vllm_config
(
vllm_config
):
model
=
SillyModel
(
vllm_config
=
vllm_config
,
prefix
=
''
)
inputs
=
torch
.
randn
(
100
,
device
=
"cuda"
)
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
# one graph for the model
num_cudagraph_captured
=
1
if
enabled
else
0
,
):
# first run is warmup
model
(
inputs
)
# second run does CUDAGraphs recording (if enabled)
model
(
inputs
)
# Disable multiprocessing so that the counter is in the same process
monkeypatch
.
setenv
(
'VLLM_ENABLE_V1_MULTIPROCESSING'
,
'0'
)
compilation_config
=
{
"cudagraph_capture_sizes"
:
[
100
],
"use_cudagraph"
:
enabled
,
}
with
(
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
num_gpu_runner_capture_triggers
=
1
if
enabled
else
0
,
num_cudagraph_captured
=
13
if
enabled
else
0
,
),
# loading the model causes compilation (if enabled) to happen
vllm_runner
(
'facebook/opt-125m'
,
compilation_config
=
compilation_config
,
gpu_memory_utilization
=
0.4
)
as
_
):
pass
vllm/compilation/counter.py
View file @
ed333497
...
...
@@ -15,6 +15,9 @@ class CompilationCounter:
# not including the splitting ops
num_piecewise_capturable_graphs_seen
:
int
=
0
num_backend_compilations
:
int
=
0
# Number of gpu_model_runner attempts to trigger CUDAGraphs capture
num_gpu_runner_capture_triggers
:
int
=
0
# Number of CUDAGraphs captured
num_cudagraph_captured
:
int
=
0
# InductorAdapter.compile calls
num_inductor_compiles
:
int
=
0
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
ed333497
...
...
@@ -18,6 +18,7 @@ import vllm.envs as envs
from
vllm.attention
import
AttentionType
,
get_attn_backend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.layer
import
Attention
from
vllm.compilation.counter
import
compilation_counter
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
get_layers_from_vllm_config
)
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
...
...
@@ -200,8 +201,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_sizes
=
[
self
.
cache_config
.
block_size
],
)
self
.
use_cuda_graph
=
(
self
.
compilation_config
.
level
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
==
CompilationLevel
.
PIECEWISE
and
self
.
vllm_config
.
compilation_config
.
use_cudagraph
and
not
self
.
model_config
.
enforce_eager
)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# The convention is different.
...
...
@@ -2058,10 +2061,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
capture_model
(
self
)
->
None
:
if
not
self
.
use_cuda_graph
:
logger
.
warning
(
"Skipping CUDA graph capture. Please add "
"-O %s to use CUDA graphs."
,
CompilationLevel
.
PIECEWISE
)
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
"set -O %s and ensure `use_cudagraph` was not manually set to "
"False"
,
CompilationLevel
.
PIECEWISE
)
return
compilation_counter
.
num_gpu_runner_capture_triggers
+=
1
start_time
=
time
.
perf_counter
()
start_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment