Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
29b35477
Unverified
Commit
29b35477
authored
Feb 27, 2026
by
Zhengxu Chen
Committed by
GitHub
Feb 27, 2026
Browse files
[compile] Fix caching error over pytree slice node. (#35308)
Signed-off-by:
zhxchen17
<
zhxchen17@fb.com
>
parent
b1d9f537
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
2 deletions
+40
-2
tests/compile/test_aot_compile.py
tests/compile/test_aot_compile.py
+21
-0
vllm/compilation/caching.py
vllm/compilation/caching.py
+19
-2
No files found.
tests/compile/test_aot_compile.py
View file @
29b35477
...
@@ -16,6 +16,7 @@ import torch
...
@@ -16,6 +16,7 @@ import torch
import
vllm.model_executor.layers.activation
import
vllm.model_executor.layers.activation
from
vllm.compilation.caching
import
(
from
vllm.compilation.caching
import
(
StandaloneCompiledArtifacts
,
StandaloneCompiledArtifacts
,
VllmSerializableFunction
,
)
)
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
from
vllm.config
import
(
...
@@ -156,6 +157,26 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
...
@@ -156,6 +157,26 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
assert
torch
.
allclose
(
ret
,
expected
)
assert
torch
.
allclose
(
ret
,
expected
)
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10"
)
def test_save_and_load_slice(monkeypatch: pytest.MonkeyPatch):
    """Round-trip a graph that contains a ``slice`` node through serialization.

    Symbolically traces a function that indexes its input with an explicit
    ``slice`` object, serializes the wrapped graph module, deserializes the
    payload, and checks that the restored graph module generates the same
    code as the traced one.
    """

    def foo(x: torch.Tensor):
        return x[slice(0, x.shape[0])]

    config = make_vllm_config()

    sample = torch.randn(10, 10)
    torch._dynamo.mark_dynamic(sample, 0)

    traced = torch.fx.symbolic_trace(foo)
    # Sanity check: the traced code must really contain the slice node,
    # otherwise the round trip below would not exercise it.
    slice_line = "getitem_1 = x[slice(0, getitem, None)]"
    assert slice_line in traced.code

    with use_vllm_config(config):
        wrapped = VllmSerializableFunction(traced, (sample,), "", foo)
        blob = VllmSerializableFunction.serialize_compile_artifacts(wrapped)
        restored = VllmSerializableFunction.deserialize_compile_artifacts(blob)
        assert traced.code == restored.graph_module.code
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
def
test_cache_load_returns_tuple_consistency
(
monkeypatch
:
pytest
.
MonkeyPatch
):
def
test_cache_load_returns_tuple_consistency
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
"""
...
...
vllm/compilation/caching.py
View file @
29b35477
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
import
hashlib
import
hashlib
import
inspect
import
inspect
import
os
import
os
...
@@ -144,6 +145,18 @@ class StandaloneCompiledArtifacts:
...
@@ -144,6 +145,18 @@ class StandaloneCompiledArtifacts:
self
.
loaded_submodule_store
=
{}
self
.
loaded_submodule_store
=
{}
@
contextlib
.
contextmanager
def
patch_pytree_map_over_slice
():
pytree
.
_private_register_pytree_node
(
slice
,
lambda
x
:
([
x
.
start
,
x
.
stop
,
x
.
step
],
None
),
lambda
x
,
c
:
slice
(
*
x
)
)
try
:
yield
finally
:
pytree
.
_deregister_pytree_node
(
slice
)
class
VllmSerializableFunction
(
SerializableCallable
):
# type: ignore[misc]
class
VllmSerializableFunction
(
SerializableCallable
):
# type: ignore[misc]
"""
"""
A wrapper around a compiled function by vllm. It will forward the tensor
A wrapper around a compiled function by vllm. It will forward the tensor
...
@@ -235,7 +248,10 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
...
@@ -235,7 +248,10 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
lambda
inp
:
torch
.
empty_like
(
inp
,
device
=
"meta"
),
lambda
inp
:
torch
.
empty_like
(
inp
,
device
=
"meta"
),
state
[
"example_inputs"
],
state
[
"example_inputs"
],
)
)
with
patch
.
object
(
GraphPickler
,
"reducer_override"
,
_graph_reducer_override
):
with
(
patch
.
object
(
GraphPickler
,
"reducer_override"
,
_graph_reducer_override
),
patch_pytree_map_over_slice
(),
):
state
[
"graph_module"
]
=
GraphPickler
.
dumps
(
state
[
"graph_module"
]
=
GraphPickler
.
dumps
(
state
[
"graph_module"
],
Options
(
ops_filter
=
None
)
state
[
"graph_module"
],
Options
(
ops_filter
=
None
)
)
)
...
@@ -261,7 +277,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
...
@@ -261,7 +277,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
state
=
pickle
.
loads
(
data
)
state
=
pickle
.
loads
(
data
)
fake_mode
=
FakeTensorMode
(
shape_env
=
ShapeEnv
())
fake_mode
=
FakeTensorMode
(
shape_env
=
ShapeEnv
())
state
[
"graph_module"
]
=
GraphPickler
.
loads
(
state
[
"graph_module"
],
fake_mode
)
with
patch_pytree_map_over_slice
():
state
[
"graph_module"
]
=
GraphPickler
.
loads
(
state
[
"graph_module"
],
fake_mode
)
state
[
"graph_module"
].
recompile
()
state
[
"graph_module"
].
recompile
()
state
[
"example_inputs"
]
=
GraphPickler
.
loads
(
state
[
"example_inputs"
],
fake_mode
)
state
[
"example_inputs"
]
=
GraphPickler
.
loads
(
state
[
"example_inputs"
],
fake_mode
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment