Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1f0d1845
Unverified
Commit
1f0d1845
authored
Dec 04, 2025
by
Laith Sakka
Committed by
GitHub
Dec 04, 2025
Browse files
[aot_compile]change VLLM backend to read fake args from example_value (#29104)
Signed-off-by:
Laith Sakka
<
lsakka@meta.com
>
parent
c8ab988b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
81 additions
and
10 deletions
+81
-10
tests/compile/test_aot_compile.py
tests/compile/test_aot_compile.py
+66
-0
vllm/compilation/backends.py
vllm/compilation/backends.py
+15
-9
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+0
-1
No files found.
tests/compile/test_aot_compile.py
View file @
1f0d1845
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
multiprocessing
import
tempfile
import
tempfile
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
...
@@ -137,3 +139,67 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
...
@@ -137,3 +139,67 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
artifacts
=
compiled_mod
.
aot_compiled_fn
.
_artifacts
artifacts
=
compiled_mod
.
aot_compiled_fn
.
_artifacts
guards_string
=
artifacts
.
compiled_fn
.
shape_env
.
format_guards
()
guards_string
=
artifacts
.
compiled_fn
.
shape_env
.
format_guards
()
assert
guards_string
==
" - s77 <= 42
\n
- Eq(Mod(s77, 2), 0)"
assert
guards_string
==
" - s77 <= 42
\n
- Eq(Mod(s77, 2), 0)"
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
@
use_vllm_config
(
make_vllm_config
())
def
test_gpt2_cache_hit
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
Test that compiling gpt2 twice results in a cache hit, and count
torch dynamic symbol creations to ensure that make_symbol is
not called on the cache hit.
"""
import
torch.fx.experimental.symbolic_shapes
as
symbolic_shapes_module
from
torch.utils._sympy.symbol
import
make_symbol
from
vllm
import
LLM
create_symbol_counter
=
multiprocessing
.
Value
(
"i"
,
0
)
original_make_symbol
=
make_symbol
@
functools
.
wraps
(
original_make_symbol
)
def
counting_make_symbol
(
prefix
,
idx
,
**
kwargs
):
with
create_symbol_counter
.
get_lock
():
create_symbol_counter
.
value
+=
1
return
original_make_symbol
(
prefix
,
idx
,
**
kwargs
)
symbolic_shapes_module
.
make_symbol
=
counting_make_symbol
try
:
with
monkeypatch
.
context
()
as
m
,
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
m
.
setenv
(
"VLLM_CACHE_ROOT"
,
tmpdirname
)
m
.
setenv
(
"VLLM_USE_AOT_COMPILE"
,
"1"
)
# First compilation - initialize model and generate
llm_model
=
LLM
(
model
=
"gpt2"
,
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
),
max_model_len
=
256
,
)
llm_model
.
generate
(
"Hello, my name is"
)
assert
create_symbol_counter
.
value
==
2
create_symbol_counter
.
value
=
0
# Clean up first model
del
llm_model
# Second compilation - should hit cache
m
.
setenv
(
"VLLM_FORCE_AOT_LOAD"
,
"1"
)
llm_model
=
LLM
(
model
=
"gpt2"
,
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
),
max_model_len
=
256
,
)
llm_model
.
generate
(
"Hello, my name is"
)
assert
create_symbol_counter
.
value
==
0
finally
:
# Restore original method
symbolic_shapes_module
.
make_symbol
=
original_make_symbol
vllm/compilation/backends.py
View file @
1f0d1845
...
@@ -402,6 +402,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
...
@@ -402,6 +402,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
self
.
extra_traceback
=
False
self
.
extra_traceback
=
False
def
run
(
self
,
*
args
):
def
run
(
self
,
*
args
):
# maybe instead just assert inputs are fake?
fake_args
=
[
fake_args
=
[
self
.
fake_mode
.
from_tensor
(
t
)
if
isinstance
(
t
,
torch
.
Tensor
)
else
t
self
.
fake_mode
.
from_tensor
(
t
)
if
isinstance
(
t
,
torch
.
Tensor
)
else
t
for
t
in
args
for
t
in
args
...
@@ -416,11 +417,13 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
...
@@ -416,11 +417,13 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
kwargs
:
dict
[
str
,
Any
],
kwargs
:
dict
[
str
,
Any
],
)
->
Any
:
)
->
Any
:
assert
isinstance
(
target
,
str
)
assert
isinstance
(
target
,
str
)
output
=
super
().
call_module
(
target
,
args
,
kwargs
)
output
=
super
().
call_module
(
target
,
args
,
kwargs
)
if
target
in
self
.
compile_submod_names
:
if
target
in
self
.
compile_submod_names
:
index
=
self
.
compile_submod_names
.
index
(
target
)
index
=
self
.
compile_submod_names
.
index
(
target
)
submod
=
self
.
fetch_attr
(
target
)
submod
=
self
.
fetch_attr
(
target
)
sym_shape_indices
=
[
sym_shape_indices
=
[
i
for
i
,
x
in
enumerate
(
args
)
if
isinstance
(
x
,
torch
.
SymInt
)
i
for
i
,
x
in
enumerate
(
args
)
if
isinstance
(
x
,
torch
.
SymInt
)
]
]
...
@@ -746,11 +749,21 @@ class VllmBackend:
...
@@ -746,11 +749,21 @@ class VllmBackend:
if
not
item
.
is_splitting_graph
if
not
item
.
is_splitting_graph
]
]
# Extract fake values from the graph to use them when needed.
all_fake_values
=
[]
for
i
in
graph
.
graph
.
find_nodes
(
op
=
"placeholder"
):
all_fake_values
.
append
(
i
.
meta
[
"example_value"
])
fake_args
=
[
all_fake_values
[
i
]
if
isinstance
(
t
,
torch
.
Tensor
)
else
t
for
i
,
t
in
enumerate
(
example_inputs
)
]
# propagate the split graph to the piecewise backend,
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
# compile submodules with symbolic shapes
PiecewiseCompileInterpreter
(
PiecewiseCompileInterpreter
(
self
.
split_gm
,
submod_names_to_compile
,
self
.
vllm_config
,
self
self
.
split_gm
,
submod_names_to_compile
,
self
.
vllm_config
,
self
).
run
(
*
example_input
s
)
).
run
(
*
fake_arg
s
)
graph_path
=
os
.
path
.
join
(
local_cache_dir
,
"computation_graph.py"
)
graph_path
=
os
.
path
.
join
(
local_cache_dir
,
"computation_graph.py"
)
if
not
os
.
path
.
exists
(
graph_path
):
if
not
os
.
path
.
exists
(
graph_path
):
...
@@ -780,14 +793,7 @@ class VllmBackend:
...
@@ -780,14 +793,7 @@ class VllmBackend:
)
)
# if we need to copy input buffers for cudagraph
# if we need to copy input buffers for cudagraph
from
torch._guards
import
detect_fake_mode
#
fake_mode
=
detect_fake_mode
()
fake_args
=
[
fake_mode
.
from_tensor
(
t
)
if
isinstance
(
t
,
torch
.
Tensor
)
else
t
for
t
in
example_inputs
]
# index of tensors that have symbolic shapes (batch size)
# index of tensors that have symbolic shapes (batch size)
# for weights and static buffers, they will have concrete shapes.
# for weights and static buffers, they will have concrete shapes.
# symbolic shape only happens for input tensors.
# symbolic shape only happens for input tensors.
...
...
vllm/compilation/decorators.py
View file @
1f0d1845
...
@@ -433,7 +433,6 @@ def _support_torch_compile(
...
@@ -433,7 +433,6 @@ def _support_torch_compile(
return
TorchCompileWithNoGuardsWrapper
.
__call__
(
self
,
*
args
,
**
kwargs
)
return
TorchCompileWithNoGuardsWrapper
.
__call__
(
self
,
*
args
,
**
kwargs
)
# This is the path for the first compilation.
# This is the path for the first compilation.
# the first compilation needs to have dynamic shapes marked
# the first compilation needs to have dynamic shapes marked
_mark_dynamic_inputs
(
_mark_dynamic_inputs
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment