Commit c0f5fae6 (unverified), authored by Zhengxu Chen, committed by GitHub
Browse files

[compile] Fix aot test failures with torch 2.12. (#37604)


Signed-off-by: zhxchen17 <zhxchen17@fb.com>
parent aa84e43c
...@@ -14,6 +14,7 @@ from unittest.mock import Mock, patch ...@@ -14,6 +14,7 @@ from unittest.mock import Mock, patch
import pytest import pytest
import torch import torch
import vllm.envs as envs
import vllm.model_executor.layers.activation import vllm.model_executor.layers.activation
from vllm.compilation.backends import VllmBackend from vllm.compilation.backends import VllmBackend
from vllm.compilation.caching import ( from vllm.compilation.caching import (
...@@ -162,6 +163,9 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch): ...@@ -162,6 +163,9 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10") @pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_save_and_load_slice(monkeypatch: pytest.MonkeyPatch): def test_save_and_load_slice(monkeypatch: pytest.MonkeyPatch):
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
def foo(x: torch.Tensor): def foo(x: torch.Tensor):
return x[slice(0, x.shape[0])] return x[slice(0, x.shape[0])]
...@@ -172,12 +176,13 @@ def test_save_and_load_slice(monkeypatch: pytest.MonkeyPatch): ...@@ -172,12 +176,13 @@ def test_save_and_load_slice(monkeypatch: pytest.MonkeyPatch):
gm = torch.fx.symbolic_trace(foo) gm = torch.fx.symbolic_trace(foo)
assert "getitem_1 = x[slice(0, getitem, None)]" in gm.code assert "getitem_1 = x[slice(0, getitem, None)]" in gm.code
with use_vllm_config(vllm_config): with use_vllm_config(vllm_config):
payload = VllmSerializableFunction.serialize_compile_artifacts( payload = VllmSerializableFunction.serialize_graph_module(gm)
VllmSerializableFunction(gm, (example_input,), "", foo) fake_mode = FakeTensorMode(shape_env=ShapeEnv())
loaded_gm = VllmSerializableFunction.deserialize_graph_module(
payload, fake_mode
) )
fn = VllmSerializableFunction.deserialize_compile_artifacts(payload)
assert gm.code == fn.graph_module.code assert gm.code == loaded_gm.code
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10") @pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
...@@ -725,6 +730,10 @@ class TestStandaloneCompiledArtifactsIntegration: ...@@ -725,6 +730,10 @@ class TestStandaloneCompiledArtifactsIntegration:
]: ]:
assert cache.get(submod, shape) == shared_data assert cache.get(submod, shape) == shared_data
@pytest.mark.skipif(
envs.VLLM_USE_MEGA_AOT_ARTIFACT,
reason="There's no AOT Autograd run with mega artifact",
)
def test_functorch_config(self): def test_functorch_config(self):
vllm_config = make_vllm_config() vllm_config = make_vllm_config()
example_inputs = (torch.randn(10, 10),) example_inputs = (torch.randn(10, 10),)
......
...@@ -11,6 +11,8 @@ from typing import Any, Literal ...@@ -11,6 +11,8 @@ from typing import Any, Literal
from unittest.mock import patch from unittest.mock import patch
import torch import torch
from torch._subclasses import FakeTensorMode
from torch.fx._graph_pickler import GraphPickler, Options
from torch.utils import _pytree as pytree from torch.utils import _pytree as pytree
import vllm.envs as envs import vllm.envs as envs
...@@ -206,26 +208,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -206,26 +208,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
return self.optimized_call(*args, **kwargs) return self.optimized_call(*args, **kwargs)
@classmethod @classmethod
def serialize_compile_artifacts( def serialize_graph_module(cls, graph_module: torch.fx.GraphModule) -> bytes:
cls, compiled_fn: "VllmSerializableFunction"
) -> bytes:
import sympy import sympy
from torch._subclasses import FakeTensorMode
from torch.fx._graph_pickler import GraphPickler, Options
state = compiled_fn.__dict__.copy()
state.pop("optimized_call")
state.pop("shape_env")
state.pop("vllm_backend", None)
state.pop("_fake_mode", None)
for node in state["graph_module"].graph.nodes:
node.meta.pop("source_fn_stack", None)
node.meta.pop("nn_module_stack", None)
for name, submod in state["graph_module"].named_children():
if hasattr(submod, "graph"):
for node in submod.graph.nodes:
node.meta.pop("source_fn_stack", None)
node.meta.pop("nn_module_stack", None)
graph_reducer_override = GraphPickler.reducer_override graph_reducer_override = GraphPickler.reducer_override
...@@ -242,6 +226,37 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -242,6 +226,37 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
return type(None), () return type(None), ()
return graph_reducer_override(self, obj) return graph_reducer_override(self, obj)
with (
patch.object(GraphPickler, "reducer_override", _graph_reducer_override),
patch_pytree_map_over_slice(),
):
return GraphPickler.dumps(graph_module, Options(ops_filter=None))
@classmethod
def deserialize_graph_module(
    cls, data: bytes, fake_mode: FakeTensorMode
) -> torch.fx.GraphModule:
    """Load a graph module back from its ``GraphPickler`` byte payload.

    Args:
        data: Bytes to hand to ``GraphPickler.loads``.
        fake_mode: Fake tensor mode passed through to ``GraphPickler.loads``.

    Returns:
        The object produced by ``GraphPickler.loads`` (annotated as a
        ``torch.fx.GraphModule``).
    """
    # Unpickling runs inside patch_pytree_map_over_slice(); presumably this
    # adjusts pytree handling of ``slice`` objects -- confirm against the
    # helper's definition.
    slice_patch = patch_pytree_map_over_slice()
    with slice_patch:
        restored = GraphPickler.loads(data, fake_mode)
        return restored
@classmethod
def serialize_compile_artifacts(
cls, compiled_fn: "VllmSerializableFunction"
) -> bytes:
state = compiled_fn.__dict__.copy()
state.pop("optimized_call")
state.pop("shape_env")
state.pop("vllm_backend", None)
state.pop("_fake_mode", None)
for node in state["graph_module"].graph.nodes:
node.meta.pop("source_fn_stack", None)
node.meta.pop("nn_module_stack", None)
for name, submod in state["graph_module"].named_children():
if hasattr(submod, "graph"):
for node in submod.graph.nodes:
node.meta.pop("source_fn_stack", None)
node.meta.pop("nn_module_stack", None)
if state.get("sym_tensor_indices"): if state.get("sym_tensor_indices"):
# put tensor inputs on meta device since their data # put tensor inputs on meta device since their data
# isn't needed, yet we need the meta for make_copy_and_call # isn't needed, yet we need the meta for make_copy_and_call
...@@ -257,14 +272,9 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -257,14 +272,9 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
lambda inp: torch.empty_like(inp, device="meta"), lambda inp: torch.empty_like(inp, device="meta"),
state["example_inputs"], state["example_inputs"],
) )
with (
patch.object(GraphPickler, "reducer_override", _graph_reducer_override), state["graph_module"] = cls.serialize_graph_module(state["graph_module"])
patch_pytree_map_over_slice(), state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
):
state["graph_module"] = GraphPickler.dumps(
state["graph_module"], Options(ops_filter=None)
)
state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
if compiled_fn.vllm_backend: if compiled_fn.vllm_backend:
( (
...@@ -280,14 +290,14 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -280,14 +290,14 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
@classmethod @classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction": def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction":
from torch._guards import TracingContext, tracing from torch._guards import TracingContext, tracing
from torch._subclasses import FakeTensorMode
from torch.fx._graph_pickler import GraphPickler
from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.fx.experimental.symbolic_shapes import ShapeEnv
state = pickle.loads(data) state = pickle.loads(data)
fake_mode = FakeTensorMode(shape_env=ShapeEnv()) fake_mode = FakeTensorMode(shape_env=ShapeEnv())
with patch_pytree_map_over_slice():
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode) state["graph_module"] = cls.deserialize_graph_module(
state["graph_module"], fake_mode
)
state["graph_module"].recompile() state["graph_module"].recompile()
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode) state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment