Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c0f5fae6
Unverified
Commit
c0f5fae6
authored
Mar 20, 2026
by
Zhengxu Chen
Committed by
GitHub
Mar 20, 2026
Browse files
[compile] Fix aot test failures with torch 2.12. (#37604)
Signed-off-by:
zhxchen17
<
zhxchen17@fb.com
>
parent
aa84e43c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
35 deletions
+54
-35
tests/compile/test_aot_compile.py
tests/compile/test_aot_compile.py
+13
-4
vllm/compilation/caching.py
vllm/compilation/caching.py
+41
-31
No files found.
tests/compile/test_aot_compile.py
View file @
c0f5fae6
...
@@ -14,6 +14,7 @@ from unittest.mock import Mock, patch
...
@@ -14,6 +14,7 @@ from unittest.mock import Mock, patch
import
pytest
import
pytest
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.model_executor.layers.activation
import
vllm.model_executor.layers.activation
from
vllm.compilation.backends
import
VllmBackend
from
vllm.compilation.backends
import
VllmBackend
from
vllm.compilation.caching
import
(
from
vllm.compilation.caching
import
(
...
@@ -162,6 +163,9 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
...
@@ -162,6 +163,9 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
def
test_save_and_load_slice
(
monkeypatch
:
pytest
.
MonkeyPatch
):
def
test_save_and_load_slice
(
monkeypatch
:
pytest
.
MonkeyPatch
):
from
torch._subclasses
import
FakeTensorMode
from
torch.fx.experimental.symbolic_shapes
import
ShapeEnv
def
foo
(
x
:
torch
.
Tensor
):
def
foo
(
x
:
torch
.
Tensor
):
return
x
[
slice
(
0
,
x
.
shape
[
0
])]
return
x
[
slice
(
0
,
x
.
shape
[
0
])]
...
@@ -172,12 +176,13 @@ def test_save_and_load_slice(monkeypatch: pytest.MonkeyPatch):
...
@@ -172,12 +176,13 @@ def test_save_and_load_slice(monkeypatch: pytest.MonkeyPatch):
gm
=
torch
.
fx
.
symbolic_trace
(
foo
)
gm
=
torch
.
fx
.
symbolic_trace
(
foo
)
assert
"getitem_1 = x[slice(0, getitem, None)]"
in
gm
.
code
assert
"getitem_1 = x[slice(0, getitem, None)]"
in
gm
.
code
with
use_vllm_config
(
vllm_config
):
with
use_vllm_config
(
vllm_config
):
payload
=
VllmSerializableFunction
.
serialize_compile_artifacts
(
payload
=
VllmSerializableFunction
.
serialize_graph_module
(
gm
)
VllmSerializableFunction
(
gm
,
(
example_input
,),
""
,
foo
)
fake_mode
=
FakeTensorMode
(
shape_env
=
ShapeEnv
())
loaded_gm
=
VllmSerializableFunction
.
deserialize_graph_module
(
payload
,
fake_mode
)
)
fn
=
VllmSerializableFunction
.
deserialize_compile_artifacts
(
payload
)
assert
gm
.
code
==
fn
.
graph_module
.
code
assert
gm
.
code
==
loaded_gm
.
code
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
...
@@ -725,6 +730,10 @@ class TestStandaloneCompiledArtifactsIntegration:
...
@@ -725,6 +730,10 @@ class TestStandaloneCompiledArtifactsIntegration:
]:
]:
assert
cache
.
get
(
submod
,
shape
)
==
shared_data
assert
cache
.
get
(
submod
,
shape
)
==
shared_data
@
pytest
.
mark
.
skipif
(
envs
.
VLLM_USE_MEGA_AOT_ARTIFACT
,
reason
=
"There's no AOT Autograd run with mega artifact"
,
)
def
test_functorch_config
(
self
):
def
test_functorch_config
(
self
):
vllm_config
=
make_vllm_config
()
vllm_config
=
make_vllm_config
()
example_inputs
=
(
torch
.
randn
(
10
,
10
),)
example_inputs
=
(
torch
.
randn
(
10
,
10
),)
...
...
vllm/compilation/caching.py
View file @
c0f5fae6
...
@@ -11,6 +11,8 @@ from typing import Any, Literal
...
@@ -11,6 +11,8 @@ from typing import Any, Literal
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
torch
import
torch
from
torch._subclasses
import
FakeTensorMode
from
torch.fx._graph_pickler
import
GraphPickler
,
Options
from
torch.utils
import
_pytree
as
pytree
from
torch.utils
import
_pytree
as
pytree
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -206,26 +208,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
...
@@ -206,26 +208,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
return
self
.
optimized_call
(
*
args
,
**
kwargs
)
return
self
.
optimized_call
(
*
args
,
**
kwargs
)
@
classmethod
@
classmethod
def
serialize_compile_artifacts
(
def
serialize_graph_module
(
cls
,
graph_module
:
torch
.
fx
.
GraphModule
)
->
bytes
:
cls
,
compiled_fn
:
"VllmSerializableFunction"
)
->
bytes
:
import
sympy
import
sympy
from
torch._subclasses
import
FakeTensorMode
from
torch.fx._graph_pickler
import
GraphPickler
,
Options
state
=
compiled_fn
.
__dict__
.
copy
()
state
.
pop
(
"optimized_call"
)
state
.
pop
(
"shape_env"
)
state
.
pop
(
"vllm_backend"
,
None
)
state
.
pop
(
"_fake_mode"
,
None
)
for
node
in
state
[
"graph_module"
].
graph
.
nodes
:
node
.
meta
.
pop
(
"source_fn_stack"
,
None
)
node
.
meta
.
pop
(
"nn_module_stack"
,
None
)
for
name
,
submod
in
state
[
"graph_module"
].
named_children
():
if
hasattr
(
submod
,
"graph"
):
for
node
in
submod
.
graph
.
nodes
:
node
.
meta
.
pop
(
"source_fn_stack"
,
None
)
node
.
meta
.
pop
(
"nn_module_stack"
,
None
)
graph_reducer_override
=
GraphPickler
.
reducer_override
graph_reducer_override
=
GraphPickler
.
reducer_override
...
@@ -242,6 +226,37 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
...
@@ -242,6 +226,37 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
return
type
(
None
),
()
return
type
(
None
),
()
return
graph_reducer_override
(
self
,
obj
)
return
graph_reducer_override
(
self
,
obj
)
with
(
patch
.
object
(
GraphPickler
,
"reducer_override"
,
_graph_reducer_override
),
patch_pytree_map_over_slice
(),
):
return
GraphPickler
.
dumps
(
graph_module
,
Options
(
ops_filter
=
None
))
@
classmethod
def
deserialize_graph_module
(
cls
,
data
:
bytes
,
fake_mode
:
FakeTensorMode
)
->
torch
.
fx
.
GraphModule
:
with
patch_pytree_map_over_slice
():
return
GraphPickler
.
loads
(
data
,
fake_mode
)
@
classmethod
def
serialize_compile_artifacts
(
cls
,
compiled_fn
:
"VllmSerializableFunction"
)
->
bytes
:
state
=
compiled_fn
.
__dict__
.
copy
()
state
.
pop
(
"optimized_call"
)
state
.
pop
(
"shape_env"
)
state
.
pop
(
"vllm_backend"
,
None
)
state
.
pop
(
"_fake_mode"
,
None
)
for
node
in
state
[
"graph_module"
].
graph
.
nodes
:
node
.
meta
.
pop
(
"source_fn_stack"
,
None
)
node
.
meta
.
pop
(
"nn_module_stack"
,
None
)
for
name
,
submod
in
state
[
"graph_module"
].
named_children
():
if
hasattr
(
submod
,
"graph"
):
for
node
in
submod
.
graph
.
nodes
:
node
.
meta
.
pop
(
"source_fn_stack"
,
None
)
node
.
meta
.
pop
(
"nn_module_stack"
,
None
)
if
state
.
get
(
"sym_tensor_indices"
):
if
state
.
get
(
"sym_tensor_indices"
):
# put tensor inputs on meta device since their data
# put tensor inputs on meta device since their data
# isn't needed, yet we need the meta for make_copy_and_call
# isn't needed, yet we need the meta for make_copy_and_call
...
@@ -257,14 +272,9 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
...
@@ -257,14 +272,9 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
lambda
inp
:
torch
.
empty_like
(
inp
,
device
=
"meta"
),
lambda
inp
:
torch
.
empty_like
(
inp
,
device
=
"meta"
),
state
[
"example_inputs"
],
state
[
"example_inputs"
],
)
)
with
(
patch
.
object
(
GraphPickler
,
"reducer_override"
,
_graph_reducer_override
),
state
[
"graph_module"
]
=
cls
.
serialize_graph_module
(
state
[
"graph_module"
])
patch_pytree_map_over_slice
(),
state
[
"example_inputs"
]
=
GraphPickler
.
dumps
(
state
[
"example_inputs"
])
):
state
[
"graph_module"
]
=
GraphPickler
.
dumps
(
state
[
"graph_module"
],
Options
(
ops_filter
=
None
)
)
state
[
"example_inputs"
]
=
GraphPickler
.
dumps
(
state
[
"example_inputs"
])
if
compiled_fn
.
vllm_backend
:
if
compiled_fn
.
vllm_backend
:
(
(
...
@@ -280,14 +290,14 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
...
@@ -280,14 +290,14 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
@
classmethod
@
classmethod
def
deserialize_compile_artifacts
(
cls
,
data
:
bytes
)
->
"VllmSerializableFunction"
:
def
deserialize_compile_artifacts
(
cls
,
data
:
bytes
)
->
"VllmSerializableFunction"
:
from
torch._guards
import
TracingContext
,
tracing
from
torch._guards
import
TracingContext
,
tracing
from
torch._subclasses
import
FakeTensorMode
from
torch.fx._graph_pickler
import
GraphPickler
from
torch.fx.experimental.symbolic_shapes
import
ShapeEnv
from
torch.fx.experimental.symbolic_shapes
import
ShapeEnv
state
=
pickle
.
loads
(
data
)
state
=
pickle
.
loads
(
data
)
fake_mode
=
FakeTensorMode
(
shape_env
=
ShapeEnv
())
fake_mode
=
FakeTensorMode
(
shape_env
=
ShapeEnv
())
with
patch_pytree_map_over_slice
():
state
[
"graph_module"
]
=
GraphPickler
.
loads
(
state
[
"graph_module"
],
fake_mode
)
state
[
"graph_module"
]
=
cls
.
deserialize_graph_module
(
state
[
"graph_module"
],
fake_mode
)
state
[
"graph_module"
].
recompile
()
state
[
"graph_module"
].
recompile
()
state
[
"example_inputs"
]
=
GraphPickler
.
loads
(
state
[
"example_inputs"
],
fake_mode
)
state
[
"example_inputs"
]
=
GraphPickler
.
loads
(
state
[
"example_inputs"
],
fake_mode
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment