Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a97954b6
Unverified
Commit
a97954b6
authored
Mar 05, 2026
by
Zhengxu Chen
Committed by
GitHub
Mar 05, 2026
Browse files
[compile] Consistent compiler config for saved/loaded vllm backends. (#35810)
Signed-off-by:
zhxchen17
<
zhxchen17@fb.com
>
parent
a911f4dd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
65 additions
and
8 deletions
+65
-8
tests/compile/test_aot_compile.py
tests/compile/test_aot_compile.py
+42
-0
vllm/compilation/caching.py
vllm/compilation/caching.py
+23
-8
No files found.
tests/compile/test_aot_compile.py
View file @
a97954b6
...
@@ -14,6 +14,7 @@ import pytest
...
@@ -14,6 +14,7 @@ import pytest
import
torch
import
torch
import
vllm.model_executor.layers.activation
import
vllm.model_executor.layers.activation
from
vllm.compilation.backends
import
VllmBackend
from
vllm.compilation.caching
import
(
from
vllm.compilation.caching
import
(
StandaloneCompiledArtifacts
,
StandaloneCompiledArtifacts
,
VllmSerializableFunction
,
VllmSerializableFunction
,
...
@@ -721,3 +722,44 @@ class TestStandaloneCompiledArtifactsIntegration:
...
@@ -721,3 +722,44 @@ class TestStandaloneCompiledArtifactsIntegration:
(
"mod3"
,
"shape3"
),
(
"mod3"
,
"shape3"
),
]:
]:
assert
cache
.
get
(
submod
,
shape
)
==
shared_data
assert
cache
.
get
(
submod
,
shape
)
==
shared_data
def
test_functorch_config
(
self
):
vllm_config
=
make_vllm_config
()
example_inputs
=
(
torch
.
randn
(
10
,
10
),)
def
add_1
(
x
:
torch
.
Tensor
):
return
x
+
1
gm
=
torch
.
_dynamo
.
functional_export
.
dynamo_graph_capture_for_export
(
add_1
)(
*
example_inputs
)
gm
.
graph
.
_codegen
=
torch
.
fx
.
graph
.
CodeGen
()
gm
.
_dynamo_bytecode_flatten
=
None
gm
.
_dynamo_bytecode_unflatten
=
None
with
(
torch
.
_functorch
.
config
.
patch
(
bundled_autograd_cache
=
False
),
set_current_vllm_config
(
vllm_config
),
):
with
torch
.
_functorch
.
config
.
patch
(
bundled_autograd_cache
=
True
):
fn
=
VllmSerializableFunction
(
gm
,
example_inputs
,
""
,
add_1
)
payload
=
VllmSerializableFunction
.
serialize_compile_artifacts
(
fn
)
config
=
None
def
backend
(
*
args
,
**
kwargs
)
->
VllmSerializableFunction
:
nonlocal
config
# bundled_autograd_cache should be True even compiler backend
# runs with bundled_autograd_cache=False in ambient context.
config
=
torch
.
_functorch
.
config
.
save_config_portable
()
return
fn
loaded_fn
=
VllmSerializableFunction
.
deserialize_compile_artifacts
(
payload
)
with
patch
.
object
(
VllmBackend
,
"__call__"
,
backend
):
loaded_fn
(
*
example_inputs
)
assert
isinstance
(
config
,
dict
)
assert
"bundled_autograd_cache"
in
config
assert
config
[
"bundled_autograd_cache"
]
is
True
vllm/compilation/caching.py
View file @
a97954b6
...
@@ -178,6 +178,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
...
@@ -178,6 +178,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
is_encoder
:
bool
=
False
,
is_encoder
:
bool
=
False
,
vllm_backend
:
Any
|
None
=
None
,
vllm_backend
:
Any
|
None
=
None
,
sym_tensor_indices
:
list
[
int
]
|
None
=
None
,
sym_tensor_indices
:
list
[
int
]
|
None
=
None
,
aot_autograd_config
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
None
:
)
->
None
:
assert
isinstance
(
graph_module
,
torch
.
fx
.
GraphModule
)
assert
isinstance
(
graph_module
,
torch
.
fx
.
GraphModule
)
self
.
graph_module
=
graph_module
self
.
graph_module
=
graph_module
...
@@ -188,6 +189,13 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
...
@@ -188,6 +189,13 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
self
.
shape_env
=
None
self
.
shape_env
=
None
self
.
vllm_backend
=
vllm_backend
self
.
vllm_backend
=
vllm_backend
self
.
sym_tensor_indices
=
sym_tensor_indices
self
.
sym_tensor_indices
=
sym_tensor_indices
import
torch._functorch.config
as
functorch_config
self
.
aot_autograd_config
=
(
aot_autograd_config
or
functorch_config
.
save_config_portable
()
)
sym_input
=
next
(
sym_input
=
next
(
(
i
for
i
in
self
.
example_inputs
if
isinstance
(
i
,
torch
.
SymInt
)),
None
(
i
for
i
in
self
.
example_inputs
if
isinstance
(
i
,
torch
.
SymInt
)),
None
)
)
...
@@ -286,6 +294,12 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
...
@@ -286,6 +294,12 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
sym_shape_indices_map
=
state
.
pop
(
"sym_shape_indices_map"
,
{})
sym_shape_indices_map
=
state
.
pop
(
"sym_shape_indices_map"
,
{})
returns_tuple_map
=
state
.
pop
(
"returns_tuple_map"
,
{})
returns_tuple_map
=
state
.
pop
(
"returns_tuple_map"
,
{})
saved_aot_autograd_config
=
state
[
"aot_autograd_config"
]
if
saved_aot_autograd_config
is
not
None
:
functorch_ctx
=
torch
.
_functorch
.
config
.
patch
(
saved_aot_autograd_config
)
else
:
functorch_ctx
=
contextlib
.
nullcontext
()
if
envs
.
VLLM_USE_MEGA_AOT_ARTIFACT
:
if
envs
.
VLLM_USE_MEGA_AOT_ARTIFACT
:
assert
standalone_compile_artifacts
is
not
None
assert
standalone_compile_artifacts
is
not
None
submod_names
=
standalone_compile_artifacts
.
submodule_names
()
submod_names
=
standalone_compile_artifacts
.
submodule_names
()
...
@@ -299,13 +313,14 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
...
@@ -299,13 +313,14 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
num_submods
,
num_submods
,
)
)
fn
=
reconstruct_serializable_fn_from_mega_artifact
(
with
functorch_ctx
:
state
=
state
,
fn
=
reconstruct_serializable_fn_from_mega_artifact
(
standalone_compile_artifacts
=
standalone_compile_artifacts
,
state
=
state
,
vllm_config
=
get_current_vllm_config
(),
standalone_compile_artifacts
=
standalone_compile_artifacts
,
sym_shape_indices_map
=
sym_shape_indices_map
,
vllm_config
=
get_current_vllm_config
(),
returns_tuple_map
=
returns_tuple_map
,
sym_shape_indices_map
=
sym_shape_indices_map
,
)
returns_tuple_map
=
returns_tuple_map
,
)
logger
.
info
(
logger
.
info
(
"reconstructed serializable fn from standalone compile artifacts"
"reconstructed serializable fn from standalone compile artifacts"
...
@@ -328,7 +343,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
...
@@ -328,7 +343,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
vllm_backend
:
VllmBackend
=
VllmBackend
(
vllm_backend
:
VllmBackend
=
VllmBackend
(
vllm_config
,
state
[
"prefix"
],
is_encoder
vllm_config
,
state
[
"prefix"
],
is_encoder
)
)
with
tracing
(
TracingContext
(
fake_mode
)):
with
tracing
(
TracingContext
(
fake_mode
))
,
functorch_ctx
:
fn
.
optimized_call
=
vllm_backend
(
fn
.
optimized_call
=
vllm_backend
(
state
[
"graph_module"
],
compile_inputs
state
[
"graph_module"
],
compile_inputs
).
optimized_call
).
optimized_call
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment