Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
09b6f998
Unverified
Commit
09b6f998
authored
Mar 11, 2026
by
Richard Zou
Committed by
GitHub
Mar 11, 2026
Browse files
[compile] aot_compile should respect VLLM_DISABLE_COMPILE_CACHE (#36358)
Signed-off-by:
Richard Zou
<
zou3519@gmail.com
>
parent
c87fb515
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
182 additions
and
43 deletions
+182
-43
tests/compile/test_aot_compile.py
tests/compile/test_aot_compile.py
+114
-0
vllm/compilation/counter.py
vllm/compilation/counter.py
+6
-0
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+59
-43
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+3
-0
No files found.
tests/compile/test_aot_compile.py
View file @
09b6f998
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
functools
import
functools
import
hashlib
import
hashlib
import
multiprocessing
import
multiprocessing
import
os
import
pickle
import
pickle
import
tempfile
import
tempfile
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
...
@@ -19,6 +20,7 @@ from vllm.compilation.caching import (
...
@@ -19,6 +20,7 @@ from vllm.compilation.caching import (
StandaloneCompiledArtifacts
,
StandaloneCompiledArtifacts
,
VllmSerializableFunction
,
VllmSerializableFunction
,
)
)
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
from
vllm.config
import
(
CompilationConfig
,
CompilationConfig
,
...
@@ -763,3 +765,115 @@ class TestStandaloneCompiledArtifactsIntegration:
...
@@ -763,3 +765,115 @@ class TestStandaloneCompiledArtifactsIntegration:
assert
isinstance
(
config
,
dict
)
assert
isinstance
(
config
,
dict
)
assert
"bundled_autograd_cache"
in
config
assert
"bundled_autograd_cache"
in
config
assert
config
[
"bundled_autograd_cache"
]
is
True
assert
config
[
"bundled_autograd_cache"
]
is
True
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_disable_compile_cache_skips_aot_save(
    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
    """With VLLM_DISABLE_COMPILE_CACHE=1, compiling must not write an AOT artifact.

    The model is AOT-compiled once, its output is compared against the
    reference function, and the cache directory is then scanned for a saved
    ``model`` file.
    """
    for env_name in ("VLLM_DISABLE_COMPILE_CACHE", "VLLM_USE_AOT_COMPILE"):
        monkeypatch.setenv(env_name, "1")
    disable_envs_cache()

    args = (torch.randn(10, 10),)
    expected = reference_fn(*args)
    vllm_config = make_vllm_config()

    # Exactly one AOT compile is expected, with nothing saved or loaded.
    counters = compilation_counter.expect(
        num_aot_compiles=1,
        num_aot_artifacts_saved=0,
        num_aot_artifacts_loaded=0,
    )
    with use_vllm_config(vllm_config), counters:
        compiled_mod = CompiledMod(vllm_config=vllm_config)
        actual = compiled_mod(*args)
        assert torch.allclose(actual, expected)

    # No cached artifact should exist on disk. os.walk yields nothing when
    # the directory is missing, so no separate existence check is needed.
    aot_root = os.path.join(
        fresh_vllm_cache, "torch_compile_cache", "torch_aot_compile"
    )
    for dirpath, _subdirs, filenames in os.walk(aot_root):
        assert "model" not in filenames, (
            f"AOT artifact unexpectedly saved at {os.path.join(dirpath, 'model')}"
        )
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_disable_compile_cache_skips_aot_load(
    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
    """When VLLM_DISABLE_COMPILE_CACHE=1, AOT artifacts must not be loaded.

    Phase 1 populates the on-disk cache with the cache enabled. Phase 2 turns
    the cache off and checks that the model is compiled again instead of
    being restored from the artifact written in phase 1.
    """
    # Phase 1: compile and save with cache enabled.
    # Clear values inherited from the invoking shell/CI first: an ambient
    # VLLM_DISABLE_COMPILE_CACHE would suppress the save this phase expects,
    # and an ambient VLLM_FORCE_AOT_LOAD would turn the initial cache miss
    # into a raised error.
    monkeypatch.delenv("VLLM_DISABLE_COMPILE_CACHE", raising=False)
    monkeypatch.delenv("VLLM_FORCE_AOT_LOAD", raising=False)
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
    disable_envs_cache()

    args = (torch.randn(10, 10),)
    vllm_config = make_vllm_config()

    with (
        use_vllm_config(vllm_config),
        compilation_counter.expect(num_aot_artifacts_saved=1),
    ):
        CompiledMod(vllm_config=vllm_config)(*args)

    # Phase 2: disable cache, compile again — should NOT load from disk
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
    disable_envs_cache()
    torch._dynamo.reset()

    vllm_config = make_vllm_config()
    with (
        use_vllm_config(vllm_config),
        compilation_counter.expect(
            num_aot_compiles=1,
            num_aot_artifacts_saved=0,
            num_aot_artifacts_loaded=0,
        ),
    ):
        mod = CompiledMod(vllm_config=vllm_config)
        mod(*args)
        assert not mod.was_aot_compile_fn_loaded_from_disk
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_aot_counters_on_save_and_load(
    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
    """Verify AOT counters are incremented correctly on save and load.

    Phase 1 performs a fresh compile and expects one compile plus one saved
    artifact. Phase 2 forces loading from the cache and expects one loaded
    artifact with no new compile or save.
    """
    # Start from a known environment: an ambient VLLM_DISABLE_COMPILE_CACHE
    # would suppress the save expected in phase 1, and an ambient
    # VLLM_FORCE_AOT_LOAD would make the initial cache miss raise.
    monkeypatch.delenv("VLLM_DISABLE_COMPILE_CACHE", raising=False)
    monkeypatch.delenv("VLLM_FORCE_AOT_LOAD", raising=False)
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
    disable_envs_cache()

    args = (torch.randn(10, 10),)

    # Phase 1: fresh compile + save
    vllm_config = make_vllm_config()
    with (
        use_vllm_config(vllm_config),
        compilation_counter.expect(
            num_aot_compiles=1,
            num_aot_artifacts_saved=1,
            num_aot_artifacts_loaded=0,
        ),
    ):
        CompiledMod(vllm_config=vllm_config)(*args)

    # Phase 2: load from cache
    monkeypatch.setenv("VLLM_FORCE_AOT_LOAD", "1")
    disable_envs_cache()

    vllm_config = make_vllm_config()
    with (
        use_vllm_config(vllm_config),
        compilation_counter.expect(
            num_aot_compiles=0,
            num_aot_artifacts_saved=0,
            num_aot_artifacts_loaded=1,
        ),
    ):
        CompiledMod(vllm_config=vllm_config)(*args)
vllm/compilation/counter.py
View file @
09b6f998
...
@@ -31,6 +31,12 @@ class CompilationCounter:
...
@@ -31,6 +31,12 @@ class CompilationCounter:
num_compiled_artifacts_saved
:
int
=
0
num_compiled_artifacts_saved
:
int
=
0
# The number of standalone_compile compiled artifacts loaded from cache
# The number of standalone_compile compiled artifacts loaded from cache
num_compiled_artifacts_loaded
:
int
=
0
num_compiled_artifacts_loaded
:
int
=
0
# The number of AOT compile invocations
num_aot_compiles
:
int
=
0
# The number of AOT compiled artifacts saved to disk
num_aot_artifacts_saved
:
int
=
0
# The number of AOT compiled artifacts loaded from disk
num_aot_artifacts_loaded
:
int
=
0
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
stock_torch_compile_count
:
int
=
0
stock_torch_compile_count
:
int
=
0
...
...
vllm/compilation/decorators.py
View file @
09b6f998
...
@@ -266,6 +266,51 @@ def _verify_source_unchanged(
...
@@ -266,6 +266,51 @@ def _verify_source_unchanged(
)
)
def _try_load_aot_compiled_fn(
    model: Any,
    aot_compilation_path: str,
) -> Any | None:
    """Attempt to restore a saved AOT-compiled function for ``model``.

    Args:
        model: Object providing ``vllm_config``, ``compilation_config`` and
            ``forward``; the globals of ``forward`` are handed to the loader.
        aot_compilation_path: Path of the serialized artifact to read.

    Returns:
        The loaded function after source verification and eager artifact
        loading, or ``None`` if any step raised and ``VLLM_FORCE_AOT_LOAD``
        is not set.

    Raises:
        Exception: The original load error, when ``VLLM_FORCE_AOT_LOAD``
            is set.
    """
    try:
        config = model.vllm_config
        with monitor_torch_compile(config):
            with (
                set_current_vllm_config(config),
                open(aot_compilation_path, "rb") as artifact_file,
            ):
                restored_fn = torch.compiler.load_compiled_function(
                    artifact_file, f_globals=model.forward.__globals__
                )
            _verify_source_unchanged(restored_fn.source_info(), config)
            if not model.compilation_config.dynamic_shapes_config.evaluate_guards:
                restored_fn.disable_guard_check()
            # Eagerly load compiled artifacts now that traced_files
            # is populated by _verify_source_unchanged.
            with maybe_use_cudagraph_partition_wrapper(config):
                restored_fn._artifacts.compiled_fn.finalize_loading(config)
            compilation_counter.num_aot_artifacts_loaded += 1
            logger.info(
                "Directly load AOT compilation from path %s", aot_compilation_path
            )
            return restored_fn
    except Exception as exc:
        # Only warn when an artifact is present; a missing file is not
        # reported here.
        if os.path.exists(aot_compilation_path):
            reason = (
                "Compile cache file corrupted."
                if isinstance(exc, EOFError)
                else str(exc)
            )
            logger.warning(
                "Compiling model again due to a load failure from %s, reason: %s",
                aot_compilation_path,
                reason,
            )
        if not envs.VLLM_FORCE_AOT_LOAD:
            return None
        raise exc
def
_support_torch_compile
(
def
_support_torch_compile
(
cls
:
type
[
_T
],
cls
:
type
[
_T
],
dynamic_arg_dims
:
dict
[
str
,
int
|
list
[
int
]],
dynamic_arg_dims
:
dict
[
str
,
int
|
list
[
int
]],
...
@@ -438,51 +483,17 @@ def _support_torch_compile(
...
@@ -438,51 +483,17 @@ def _support_torch_compile(
dp_rank
=
self
.
vllm_config
.
parallel_config
.
data_parallel_index
dp_rank
=
self
.
vllm_config
.
parallel_config
.
data_parallel_index
cache_dir
=
os
.
path
.
join
(
cache_dir
,
f
"rank_
{
rank
}
_
{
dp_rank
}
"
)
cache_dir
=
os
.
path
.
join
(
cache_dir
,
f
"rank_
{
rank
}
_
{
dp_rank
}
"
)
aot_compilation_path
=
os
.
path
.
join
(
cache_dir
,
"model"
)
aot_compilation_path
=
os
.
path
.
join
(
cache_dir
,
"model"
)
try
:
if
not
envs
.
VLLM_DISABLE_COMPILE_CACHE
:
with
monitor_torch_compile
(
self
.
vllm_config
):
loaded_fn
=
_try_load_aot_compiled_fn
(
self
,
aot_compilation_path
)
if
loaded_fn
is
not
None
:
self
.
aot_compiled_fn
=
loaded_fn
self
.
was_aot_compile_fn_loaded_from_disk
=
True
with
(
with
(
set_current_vllm_config
(
self
.
vllm_config
),
monitor_profiling_run
(
),
open
(
aot_compilation_path
,
"rb"
)
as
f
,
maybe_use_cudagraph_partition_wrapper
(
self
.
vllm_config
)
,
):
):
loaded_fn
=
torch
.
compiler
.
load_compiled_function
(
output
=
self
.
aot_compiled_fn
(
self
,
*
args
,
**
kwargs
)
f
,
f_globals
=
self
.
forward
.
__globals__
return
output
)
_verify_source_unchanged
(
loaded_fn
.
source_info
(),
self
.
vllm_config
)
ds_config
=
self
.
compilation_config
.
dynamic_shapes_config
if
not
ds_config
.
evaluate_guards
:
loaded_fn
.
disable_guard_check
()
# Eagerly load compiled artifacts now that traced_files
# is populated by _verify_source_unchanged.
with
maybe_use_cudagraph_partition_wrapper
(
self
.
vllm_config
):
loaded_fn
.
_artifacts
.
compiled_fn
.
finalize_loading
(
self
.
vllm_config
)
self
.
aot_compiled_fn
=
loaded_fn
self
.
was_aot_compile_fn_loaded_from_disk
=
True
except
Exception
as
e
:
if
os
.
path
.
exists
(
aot_compilation_path
):
if
isinstance
(
e
,
EOFError
):
message
=
"Compile cache file corrupted."
else
:
message
=
str
(
e
)
logger
.
warning
(
"Compiling model again due to a load failure from %s, "
"reason: %s"
,
aot_compilation_path
,
message
,
)
if
envs
.
VLLM_FORCE_AOT_LOAD
:
raise
e
if
getattr
(
self
,
"aot_compiled_fn"
,
None
)
is
not
None
:
logger
.
info
(
"Directly load AOT compilation from path %s"
,
aot_compilation_path
)
with
(
monitor_profiling_run
(),
maybe_use_cudagraph_partition_wrapper
(
self
.
vllm_config
),
):
output
=
self
.
aot_compiled_fn
(
self
,
*
args
,
**
kwargs
)
return
output
if
self
.
compiled
:
if
self
.
compiled
:
assert
(
assert
(
...
@@ -570,6 +581,7 @@ def _support_torch_compile(
...
@@ -570,6 +581,7 @@ def _support_torch_compile(
self
.
_aot_cache_dir
=
cache_dir
self
.
_aot_cache_dir
=
cache_dir
with
monitor_torch_compile
(
self
.
vllm_config
):
with
monitor_torch_compile
(
self
.
vllm_config
):
self
.
aot_compiled_fn
=
self
.
aot_compile
(
*
args
,
**
kwargs
)
self
.
aot_compiled_fn
=
self
.
aot_compile
(
*
args
,
**
kwargs
)
compilation_counter
.
num_aot_compiles
+=
1
# All compilation is done at this point, save the
# All compilation is done at this point, save the
# AOT artifact.
# AOT artifact.
self
.
save_aot_compiled_function
()
self
.
save_aot_compiled_function
()
...
@@ -593,6 +605,9 @@ def _support_torch_compile(
...
@@ -593,6 +605,9 @@ def _support_torch_compile(
# triggers VllmSerializableFunction.serialize()
# triggers VllmSerializableFunction.serialize()
def
save_aot_compiled_function
(
self
:
type
[
_T
])
->
None
:
def
save_aot_compiled_function
(
self
:
type
[
_T
])
->
None
:
if
envs
.
VLLM_DISABLE_COMPILE_CACHE
:
return
if
self
.
was_aot_compile_fn_loaded_from_disk
:
if
self
.
was_aot_compile_fn_loaded_from_disk
:
logger
.
debug
(
"AOT compiled function was loaded from cache, skipping save"
)
logger
.
debug
(
"AOT compiled function was loaded from cache, skipping save"
)
return
return
...
@@ -608,6 +623,7 @@ def _support_torch_compile(
...
@@ -608,6 +623,7 @@ def _support_torch_compile(
tmp_file
=
f
"
{
self
.
_aot_compilation_path
}
.
{
os
.
getpid
()
}
.tmp"
tmp_file
=
f
"
{
self
.
_aot_compilation_path
}
.
{
os
.
getpid
()
}
.tmp"
self
.
aot_compiled_fn
.
save_compiled_function
(
tmp_file
)
self
.
aot_compiled_fn
.
save_compiled_function
(
tmp_file
)
os
.
replace
(
tmp_file
,
self
.
_aot_compilation_path
)
os
.
replace
(
tmp_file
,
self
.
_aot_compilation_path
)
compilation_counter
.
num_aot_artifacts_saved
+=
1
logger
.
info_once
(
logger
.
info_once
(
"saved AOT compiled function to %s"
,
"saved AOT compiled function to %s"
,
self
.
_aot_compilation_path
,
self
.
_aot_compilation_path
,
...
...
vllm/compilation/wrapper.py
View file @
09b6f998
...
@@ -349,6 +349,9 @@ def reset_compile_wrapper(model: torch.nn.Module) -> None:
...
@@ -349,6 +349,9 @@ def reset_compile_wrapper(model: torch.nn.Module) -> None:
compilation_counter
.
num_cache_entries_updated
=
0
compilation_counter
.
num_cache_entries_updated
=
0
compilation_counter
.
num_compiled_artifacts_saved
=
0
compilation_counter
.
num_compiled_artifacts_saved
=
0
compilation_counter
.
stock_torch_compile_count
=
0
compilation_counter
.
stock_torch_compile_count
=
0
compilation_counter
.
num_aot_compiles
=
0
compilation_counter
.
num_aot_artifacts_saved
=
0
compilation_counter
.
num_aot_artifacts_loaded
=
0
# Clear the AOT compiled function so the model is forced to
# Clear the AOT compiled function so the model is forced to
# recompile on the next call. Without this, decorators.py
# recompile on the next call. Without this, decorators.py
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment