Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d1a6e96d
Unverified
Commit
d1a6e96d
authored
Mar 02, 2026
by
Richard Zou
Committed by
GitHub
Mar 02, 2026
Browse files
[torch.compile] Improve cold and warm start compile tests (#35709)
Signed-off-by:
Richard Zou
<
zou3519@gmail.com
>
parent
2a9e3347
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
82 additions
and
48 deletions
+82
-48
tests/compile/test_cold_start.py
tests/compile/test_cold_start.py
+0
-48
tests/compile/test_startup.py
tests/compile/test_startup.py
+71
-0
tests/conftest.py
tests/conftest.py
+8
-0
vllm/compilation/compiler_interface.py
vllm/compilation/compiler_interface.py
+1
-0
vllm/compilation/counter.py
vllm/compilation/counter.py
+2
-0
No files found.
tests/compile/test_cold_start.py
deleted
100644 → 0
View file @
2a9e3347
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
torch._dynamo.utils
import
counters
from
vllm
import
LLM
from
vllm.config
import
CompilationConfig
,
CompilationMode
,
CUDAGraphMode
def test_moe_compilation_cold_start(monkeypatch, use_fresh_inductor_cache):
    """Check that a cold start only compiles the unique subgraphs of the model.

    vLLM-compile performs one full dynamo capture of the forward pass and
    then splits the graph on the attention operation. For this model
    (32 transformer layers) that yields 33 subgraphs, not counting the
    attention operation itself. Because the layers are identical modulo
    weights, only 3 of those subgraphs are unique, so a cold start is
    expected to produce exactly 3 AOT autograd cache misses and no hits.
    """
    env_overrides = {
        # Run in the same process so PyTorch's internal counters are visible.
        "VLLM_ENABLE_V1_MULTIPROCESSING": "0",
        # Unclear whether AOT compile would affect the numbers; keep it off.
        "VLLM_USE_AOT_COMPILE": "0",
        # Force cold compilation.
        "VLLM_DISABLE_COMPILE_CACHE": "1",
    }
    for env_name, env_value in env_overrides.items():
        monkeypatch.setenv(env_name, env_value)

    counters.clear()

    # Keep a reference so the engine stays alive while the counters are read.
    _llm = LLM(
        model="microsoft/Phi-tiny-MoE-instruct",
        max_model_len=256,
        # Dummy weights, a tiny KV cache and no CUDA graphs make the model
        # loading faster.
        load_format="dummy",
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            cudagraph_mode=CUDAGraphMode.NONE,
        ),
        num_gpu_blocks_override=8,
    )

    # Only the 3 unique subgraphs should have been compiled, and nothing
    # should have been served from the cache.
    aot_stats = counters["aot_autograd"]
    assert aot_stats["autograd_cache_miss"] == 3
    assert aot_stats["autograd_cache_hit"] == 0
tests/compile/test_startup.py
0 → 100644
View file @
d1a6e96d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Cold start and warm start tests for vLLM-compile.
Cold start runs in a forked child (must fork before CUDA init) which
populates on-disk caches and asserts cold-start counters. Warm start
then runs in the parent with clean in-memory state but populated caches.
"""
import
multiprocessing
as
mp
from
torch._dynamo.utils
import
counters
from
vllm.compilation.counter
import
compilation_counter
from
vllm.config
import
CompilationConfig
,
CompilationMode
,
CUDAGraphMode
MODEL
=
"microsoft/Phi-tiny-MoE-instruct"
def _run_vllm(vllm_runner):
    """Start and immediately tear down a vLLM instance for ``MODEL``.

    The runner is entered as a context manager and exited without running
    any requests; the callers only care about the compilation that happens
    during startup.
    """
    compile_cfg = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        cudagraph_mode=CUDAGraphMode.NONE,
    )
    runner_kwargs = dict(
        trust_remote_code=False,
        max_model_len=256,
        max_num_batched_tokens=1024,
        load_format="dummy",
        compilation_config=compile_cfg,
        num_gpu_blocks_override=8,
    )
    with vllm_runner(MODEL, **runner_kwargs):
        # Nothing to do: entering the context is what triggers startup.
        pass
def _cold_start(vllm_runner):
    """Run one cold start and assert the expected compilation counters.

    Used as the target of the forked child process in ``test_moe_startup``.
    Expects 3 compiled artifacts to be saved and none loaded, and checks the
    AOT autograd counters: 33 total, 3 cache misses, 0 cache hits.
    """
    counters.clear()

    expected_counts = {
        "num_compiled_artifacts_saved": 3,
        "num_compiled_artifacts_loaded": 0,
    }
    with compilation_counter.expect(**expected_counts):
        _run_vllm(vllm_runner)

    # Fetch the stats only after the run; ``clear()`` above dropped them.
    aot_stats = counters["aot_autograd"]
    assert aot_stats["total"] == 33
    assert aot_stats["autograd_cache_miss"] == 3
    assert aot_stats["autograd_cache_hit"] == 0
def test_moe_startup(monkeypatch, vllm_runner, fresh_vllm_cache):
    """Exercise a cold start followed by a warm start of vLLM-compile.

    The cold start runs ``_cold_start`` in a forked child process; the warm
    start then runs in this process and is expected to load the compiled
    artifacts from the on-disk cache.
    """
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    # Cold start in a forked child (must fork before CUDA init).
    # This model has 32 identical transformer layers which produce
    # 33 subgraphs after splitting on attention — only 3 are unique.
    child = mp.get_context("fork").Process(
        target=_cold_start,
        args=(vllm_runner,),
    )
    child.start()
    child.join()
    assert child.exitcode == 0, "Cold-start child failed"

    # Warm start — compiled artifacts loaded from disk cache.
    counters.clear()
    expected_counts = {
        "num_compiled_artifacts_loaded": 3,
        # TODO: warm start should not save any artifacts
        # https://github.com/vllm-project/vllm/issues/35708
        "num_compiled_artifacts_saved": 1,
    }
    with compilation_counter.expect(**expected_counts):
        _run_vllm(vllm_runner)

    # Fetch the stats only after the run; ``clear()`` above dropped them.
    aot_stats = counters["aot_autograd"]
    assert aot_stats["total"] == 30
    assert aot_stats["autograd_cache_miss"] == 0
    assert aot_stats["autograd_cache_hit"] == 1
tests/conftest.py
View file @
d1a6e96d
...
@@ -1548,6 +1548,14 @@ def use_fresh_inductor_cache():
...
@@ -1548,6 +1548,14 @@ def use_fresh_inductor_cache():
yield
yield
@pytest.fixture
def fresh_vllm_cache(monkeypatch, use_fresh_inductor_cache):
    """Point ``VLLM_CACHE_ROOT`` at a temporary directory for one test.

    Depends on ``use_fresh_inductor_cache`` so the inductor cache is fresh
    as well. Yields the path of the temporary directory, which is removed
    when the fixture is torn down.
    """
    with tempfile.TemporaryDirectory() as cache_root:
        monkeypatch.setenv("VLLM_CACHE_ROOT", cache_root)
        yield cache_root
@
pytest
.
fixture
(
scope
=
"function"
)
@
pytest
.
fixture
(
scope
=
"function"
)
def
enable_pickle
(
monkeypatch
):
def
enable_pickle
(
monkeypatch
):
"""`LLM.apply_model` requires pickling a function."""
"""`LLM.apply_model` requires pickling a function."""
...
...
vllm/compilation/compiler_interface.py
View file @
d1a6e96d
...
@@ -368,6 +368,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
...
@@ -368,6 +368,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
inductor_compiled_graph
=
torch
.
_inductor
.
CompiledArtifact
.
load
(
inductor_compiled_graph
=
torch
.
_inductor
.
CompiledArtifact
.
load
(
path
=
path
,
format
=
self
.
save_format
path
=
path
,
format
=
self
.
save_format
)
)
compilation_counter
.
num_compiled_artifacts_loaded
+=
1
from
torch._inductor.compile_fx
import
graph_returns_tuple
from
torch._inductor.compile_fx
import
graph_returns_tuple
returns_tuple
=
graph_returns_tuple
(
graph
)
returns_tuple
=
graph_returns_tuple
(
graph
)
...
...
vllm/compilation/counter.py
View file @
d1a6e96d
...
@@ -29,6 +29,8 @@ class CompilationCounter:
...
@@ -29,6 +29,8 @@ class CompilationCounter:
num_cache_entries_updated
:
int
=
0
num_cache_entries_updated
:
int
=
0
# The number of standalone_compile compiled artifacts saved
# The number of standalone_compile compiled artifacts saved
num_compiled_artifacts_saved
:
int
=
0
num_compiled_artifacts_saved
:
int
=
0
# The number of standalone_compile compiled artifacts loaded from cache
num_compiled_artifacts_loaded
:
int
=
0
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
stock_torch_compile_count
:
int
=
0
stock_torch_compile_count
:
int
=
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment