Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
09b6f998
Unverified
Commit
09b6f998
authored
Mar 11, 2026
by
Richard Zou
Committed by
GitHub
Mar 11, 2026
Browse files
[compile] aot_compile should respect VLLM_DISABLE_COMPILE_CACHE (#36358)
Signed-off-by:
Richard Zou
<
zou3519@gmail.com
>
parent
c87fb515
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
182 additions
and
43 deletions
+182
-43
tests/compile/test_aot_compile.py
tests/compile/test_aot_compile.py
+114
-0
vllm/compilation/counter.py
vllm/compilation/counter.py
+6
-0
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+59
-43
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+3
-0
No files found.
tests/compile/test_aot_compile.py
View file @
09b6f998
...
...
@@ -4,6 +4,7 @@
import
functools
import
hashlib
import
multiprocessing
import
os
import
pickle
import
tempfile
from
contextlib
import
contextmanager
...
...
@@ -19,6 +20,7 @@ from vllm.compilation.caching import (
StandaloneCompiledArtifacts
,
VllmSerializableFunction
,
)
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CompilationConfig
,
...
...
@@ -763,3 +765,115 @@ class TestStandaloneCompiledArtifactsIntegration:
assert
isinstance
(
config
,
dict
)
assert
"bundled_autograd_cache"
in
config
assert
config
[
"bundled_autograd_cache"
]
is
True
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10"
)
def test_disable_compile_cache_skips_aot_save(
    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
    """Check that no AOT artifact is written when the compile cache is disabled.

    The model is AOT-compiled once with ``VLLM_DISABLE_COMPILE_CACHE=1`` and
    ``VLLM_USE_AOT_COMPILE=1``. The compilation counters must report one
    compile, zero saves and zero loads, the output must match the eager
    reference, and no file named ``model`` may exist under the AOT cache
    directory.
    """
    for env_name in ("VLLM_DISABLE_COMPILE_CACHE", "VLLM_USE_AOT_COMPILE"):
        monkeypatch.setenv(env_name, "1")
    disable_envs_cache()

    inputs = (torch.randn(10, 10),)
    expected = reference_fn(*inputs)

    vllm_config = make_vllm_config()
    expected_counts = {
        "num_aot_compiles": 1,
        "num_aot_artifacts_saved": 0,
        "num_aot_artifacts_loaded": 0,
    }
    with use_vllm_config(vllm_config), compilation_counter.expect(
        **expected_counts
    ):
        mod = CompiledMod(vllm_config=vllm_config)
        actual = mod(*inputs)
    assert torch.allclose(actual, expected)

    # No cached artifact should exist on disk
    aot_dir = os.path.join(
        fresh_vllm_cache, "torch_compile_cache", "torch_aot_compile"
    )
    if not os.path.isdir(aot_dir):
        return
    for root, _dirs, files in os.walk(aot_dir):
        # The AOT artifact is always written under the file name "model".
        assert "model" not in files, (
            f"AOT artifact unexpectedly saved at {os.path.join(root, 'model')}"
        )
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10"
)
def test_disable_compile_cache_skips_aot_load(
    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
    """Check that an existing AOT artifact is ignored when the cache is disabled.

    Phase 1 compiles with the cache enabled so that an artifact is saved.
    Phase 2 sets ``VLLM_DISABLE_COMPILE_CACHE=1`` and compiles again; the
    counters must show a fresh compile with no save and no load, and the
    module must not report having been loaded from disk.
    """
    inputs = (torch.randn(10, 10),)

    # Phase 1: compile and save with cache enabled
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
    disable_envs_cache()
    vllm_config = make_vllm_config()
    with use_vllm_config(vllm_config), compilation_counter.expect(
        num_aot_artifacts_saved=1
    ):
        CompiledMod(vllm_config=vllm_config)(*inputs)

    # Phase 2: disable cache, compile again — should NOT load from disk
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
    disable_envs_cache()
    torch._dynamo.reset()
    vllm_config = make_vllm_config()
    no_cache_counts = {
        "num_aot_compiles": 1,
        "num_aot_artifacts_saved": 0,
        "num_aot_artifacts_loaded": 0,
    }
    with use_vllm_config(vllm_config), compilation_counter.expect(
        **no_cache_counts
    ):
        mod = CompiledMod(vllm_config=vllm_config)
        mod(*inputs)
    loaded_from_disk = mod.was_aot_compile_fn_loaded_from_disk
    assert not loaded_from_disk
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10"
)
def test_aot_counters_on_save_and_load(
    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
    """Check the AOT counters across a fresh compile and a later cache load.

    The first run must count one compile and one saved artifact. The second
    run, with ``VLLM_FORCE_AOT_LOAD=1``, must count exactly one loaded
    artifact and no compile or save.
    """
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
    disable_envs_cache()

    inputs = (torch.randn(10, 10),)

    def run_once(**expected_counts: int) -> None:
        # Build a fresh config and module, then run one forward pass while
        # checking the counter deltas.
        vllm_config = make_vllm_config()
        with use_vllm_config(vllm_config), compilation_counter.expect(
            **expected_counts
        ):
            CompiledMod(vllm_config=vllm_config)(*inputs)

    # Phase 1: fresh compile + save
    run_once(
        num_aot_compiles=1,
        num_aot_artifacts_saved=1,
        num_aot_artifacts_loaded=0,
    )

    # Phase 2: load from cache
    monkeypatch.setenv("VLLM_FORCE_AOT_LOAD", "1")
    disable_envs_cache()
    run_once(
        num_aot_compiles=0,
        num_aot_artifacts_saved=0,
        num_aot_artifacts_loaded=1,
    )
vllm/compilation/counter.py
View file @
09b6f998
...
...
@@ -31,6 +31,12 @@ class CompilationCounter:
num_compiled_artifacts_saved
:
int
=
0
# The number of standalone_compile compiled artifacts loaded from cache
num_compiled_artifacts_loaded
:
int
=
0
# The number of AOT compile invocations
num_aot_compiles
:
int
=
0
# The number of AOT compiled artifacts saved to disk
num_aot_artifacts_saved
:
int
=
0
# The number of AOT compiled artifacts loaded from disk
num_aot_artifacts_loaded
:
int
=
0
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
stock_torch_compile_count
:
int
=
0
...
...
vllm/compilation/decorators.py
View file @
09b6f998
...
...
@@ -266,6 +266,51 @@ def _verify_source_unchanged(
)
def _try_load_aot_compiled_fn(
    model: Any,
    aot_compilation_path: str,
) -> Any | None:
    """Try to load an AOT-compiled function from disk.

    The artifact at ``aot_compilation_path`` is opened in binary mode and
    handed to ``torch.compiler.load_compiled_function`` together with the
    globals of ``model.forward``. After loading, the recorded source info is
    checked with ``_verify_source_unchanged``, guard checking is disabled
    unless ``dynamic_shapes_config.evaluate_guards`` is set, and the compiled
    artifacts are finalized eagerly. On success the
    ``num_aot_artifacts_loaded`` counter is incremented.

    Args:
        model: Object exposing ``vllm_config``, ``compilation_config`` and
            ``forward``.
        aot_compilation_path: Path of the serialized AOT artifact.

    Returns:
        The loaded callable on success, or ``None`` when loading failed and
        ``VLLM_FORCE_AOT_LOAD`` is not set.

    Raises:
        Exception: The original load failure is re-raised when
            ``VLLM_FORCE_AOT_LOAD`` is set.
    """
    # NOTE(review): the block nesting below was reconstructed from a view in
    # which indentation was lost — confirm it against the repository.
    try:
        with monitor_torch_compile(model.vllm_config):
            with (
                set_current_vllm_config(model.vllm_config),
                open(aot_compilation_path, "rb") as f,
            ):
                loaded_fn = torch.compiler.load_compiled_function(
                    f, f_globals=model.forward.__globals__
                )
            _verify_source_unchanged(loaded_fn.source_info(), model.vllm_config)
            ds_config = model.compilation_config.dynamic_shapes_config
            if not ds_config.evaluate_guards:
                loaded_fn.disable_guard_check()
            # Eagerly load compiled artifacts now that traced_files
            # is populated by _verify_source_unchanged.
            with maybe_use_cudagraph_partition_wrapper(model.vllm_config):
                loaded_fn._artifacts.compiled_fn.finalize_loading(
                    model.vllm_config
                )
            compilation_counter.num_aot_artifacts_loaded += 1
            logger.info(
                "Directly load AOT compilation from path %s", aot_compilation_path
            )
            return loaded_fn
    except Exception as e:
        # Only warn when an artifact actually exists; a missing file falls
        # through to the fallback without a warning.
        if os.path.exists(aot_compilation_path):
            message = (
                "Compile cache file corrupted."
                if isinstance(e, EOFError)
                else str(e)
            )
            logger.warning(
                "Compiling model again due to a load failure from %s, reason: %s",
                aot_compilation_path,
                message,
            )
        if envs.VLLM_FORCE_AOT_LOAD:
            # Bare raise re-raises the active exception without adding an
            # extra traceback entry for this line.
            raise
        return None
def
_support_torch_compile
(
cls
:
type
[
_T
],
dynamic_arg_dims
:
dict
[
str
,
int
|
list
[
int
]],
...
...
@@ -438,51 +483,17 @@ def _support_torch_compile(
dp_rank
=
self
.
vllm_config
.
parallel_config
.
data_parallel_index
cache_dir
=
os
.
path
.
join
(
cache_dir
,
f
"rank_
{
rank
}
_
{
dp_rank
}
"
)
aot_compilation_path
=
os
.
path
.
join
(
cache_dir
,
"model"
)
try
:
with
monitor_torch_compile
(
self
.
vllm_config
):
if
not
envs
.
VLLM_DISABLE_COMPILE_CACHE
:
loaded_fn
=
_try_load_aot_compiled_fn
(
self
,
aot_compilation_path
)
if
loaded_fn
is
not
None
:
self
.
aot_compiled_fn
=
loaded_fn
self
.
was_aot_compile_fn_loaded_from_disk
=
True
with
(
set_current_vllm_config
(
self
.
vllm_config
),
open
(
aot_compilation_path
,
"rb"
)
as
f
,
monitor_profiling_run
(
),
maybe_use_cudagraph_partition_wrapper
(
self
.
vllm_config
)
,
):
loaded_fn
=
torch
.
compiler
.
load_compiled_function
(
f
,
f_globals
=
self
.
forward
.
__globals__
)
_verify_source_unchanged
(
loaded_fn
.
source_info
(),
self
.
vllm_config
)
ds_config
=
self
.
compilation_config
.
dynamic_shapes_config
if
not
ds_config
.
evaluate_guards
:
loaded_fn
.
disable_guard_check
()
# Eagerly load compiled artifacts now that traced_files
# is populated by _verify_source_unchanged.
with
maybe_use_cudagraph_partition_wrapper
(
self
.
vllm_config
):
loaded_fn
.
_artifacts
.
compiled_fn
.
finalize_loading
(
self
.
vllm_config
)
self
.
aot_compiled_fn
=
loaded_fn
self
.
was_aot_compile_fn_loaded_from_disk
=
True
except
Exception
as
e
:
if
os
.
path
.
exists
(
aot_compilation_path
):
if
isinstance
(
e
,
EOFError
):
message
=
"Compile cache file corrupted."
else
:
message
=
str
(
e
)
logger
.
warning
(
"Compiling model again due to a load failure from %s, "
"reason: %s"
,
aot_compilation_path
,
message
,
)
if
envs
.
VLLM_FORCE_AOT_LOAD
:
raise
e
if
getattr
(
self
,
"aot_compiled_fn"
,
None
)
is
not
None
:
logger
.
info
(
"Directly load AOT compilation from path %s"
,
aot_compilation_path
)
with
(
monitor_profiling_run
(),
maybe_use_cudagraph_partition_wrapper
(
self
.
vllm_config
),
):
output
=
self
.
aot_compiled_fn
(
self
,
*
args
,
**
kwargs
)
return
output
output
=
self
.
aot_compiled_fn
(
self
,
*
args
,
**
kwargs
)
return
output
if
self
.
compiled
:
assert
(
...
...
@@ -570,6 +581,7 @@ def _support_torch_compile(
self
.
_aot_cache_dir
=
cache_dir
with
monitor_torch_compile
(
self
.
vllm_config
):
self
.
aot_compiled_fn
=
self
.
aot_compile
(
*
args
,
**
kwargs
)
compilation_counter
.
num_aot_compiles
+=
1
# All compilation is done at this point, save the
# AOT artifact.
self
.
save_aot_compiled_function
()
...
...
@@ -593,6 +605,9 @@ def _support_torch_compile(
# triggers VllmSerializableFunction.serialize()
def
save_aot_compiled_function
(
self
:
type
[
_T
])
->
None
:
if
envs
.
VLLM_DISABLE_COMPILE_CACHE
:
return
if
self
.
was_aot_compile_fn_loaded_from_disk
:
logger
.
debug
(
"AOT compiled function was loaded from cache, skipping save"
)
return
...
...
@@ -608,6 +623,7 @@ def _support_torch_compile(
tmp_file
=
f
"
{
self
.
_aot_compilation_path
}
.
{
os
.
getpid
()
}
.tmp"
self
.
aot_compiled_fn
.
save_compiled_function
(
tmp_file
)
os
.
replace
(
tmp_file
,
self
.
_aot_compilation_path
)
compilation_counter
.
num_aot_artifacts_saved
+=
1
logger
.
info_once
(
"saved AOT compiled function to %s"
,
self
.
_aot_compilation_path
,
...
...
vllm/compilation/wrapper.py
View file @
09b6f998
...
...
@@ -349,6 +349,9 @@ def reset_compile_wrapper(model: torch.nn.Module) -> None:
compilation_counter
.
num_cache_entries_updated
=
0
compilation_counter
.
num_compiled_artifacts_saved
=
0
compilation_counter
.
stock_torch_compile_count
=
0
compilation_counter
.
num_aot_compiles
=
0
compilation_counter
.
num_aot_artifacts_saved
=
0
compilation_counter
.
num_aot_artifacts_loaded
=
0
# Clear the AOT compiled function so the model is forced to
# recompile on the next call. Without this, decorators.py
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment