Unverified commit a97954b6, authored by Zhengxu Chen, committed by GitHub
Browse files

[compile] Consistent compiler config for saved/loaded vllm backends. (#35810)


Signed-off-by: zhxchen17 <zhxchen17@fb.com>
parent a911f4dd
...@@ -14,6 +14,7 @@ import pytest ...@@ -14,6 +14,7 @@ import pytest
import torch import torch
import vllm.model_executor.layers.activation import vllm.model_executor.layers.activation
from vllm.compilation.backends import VllmBackend
from vllm.compilation.caching import ( from vllm.compilation.caching import (
StandaloneCompiledArtifacts, StandaloneCompiledArtifacts,
VllmSerializableFunction, VllmSerializableFunction,
...@@ -721,3 +722,44 @@ class TestStandaloneCompiledArtifactsIntegration: ...@@ -721,3 +722,44 @@ class TestStandaloneCompiledArtifactsIntegration:
("mod3", "shape3"), ("mod3", "shape3"),
]: ]:
assert cache.get(submod, shape) == shared_data assert cache.get(submod, shape) == shared_data
def test_functorch_config(self):
    """Check that the functorch config saved with a function is restored on load.

    The function is constructed while ``bundled_autograd_cache=True`` is
    patched in, serialized, and then deserialized and invoked while the
    ambient functorch config has ``bundled_autograd_cache=False``. A stand-in
    for ``VllmBackend.__call__`` records the functorch config that is active
    when the backend gets invoked; the recorded value must be ``True``.
    """
    functorch_config = torch._functorch.config

    vllm_config = make_vllm_config()
    example_inputs = (torch.randn(10, 10),)

    def add_1(x: torch.Tensor):
        return x + 1

    capture = torch._dynamo.functional_export.dynamo_graph_capture_for_export(add_1)
    gm = capture(*example_inputs)

    # Swap in a plain CodeGen and clear the dynamo flatten/unflatten bytecode
    # attributes on the captured module before wrapping it.
    gm.graph._codegen = torch.fx.graph.CodeGen()
    gm._dynamo_bytecode_flatten = None
    gm._dynamo_bytecode_unflatten = None

    # Ambient context: bundled_autograd_cache is off for everything below,
    # except for the construction of ``fn``.
    with (
        functorch_config.patch(bundled_autograd_cache=False),
        set_current_vllm_config(vllm_config),
    ):
        with functorch_config.patch(bundled_autograd_cache=True):
            fn = VllmSerializableFunction(gm, example_inputs, "", add_1)

        payload = VllmSerializableFunction.serialize_compile_artifacts(fn)

        config = None

        def backend(*args, **kwargs) -> VllmSerializableFunction:
            # bundled_autograd_cache should be True here even though the
            # compiler backend is invoked while the ambient context has
            # bundled_autograd_cache=False.
            nonlocal config
            config = functorch_config.save_config_portable()
            return fn

        loaded_fn = VllmSerializableFunction.deserialize_compile_artifacts(payload)
        with patch.object(VllmBackend, "__call__", backend):
            loaded_fn(*example_inputs)

        # The stand-in backend must have run and seen the saved setting.
        assert isinstance(config, dict)
        assert "bundled_autograd_cache" in config
        assert config["bundled_autograd_cache"] is True
...@@ -178,6 +178,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -178,6 +178,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
is_encoder: bool = False, is_encoder: bool = False,
vllm_backend: Any | None = None, vllm_backend: Any | None = None,
sym_tensor_indices: list[int] | None = None, sym_tensor_indices: list[int] | None = None,
aot_autograd_config: dict[str, Any] | None = None,
) -> None: ) -> None:
assert isinstance(graph_module, torch.fx.GraphModule) assert isinstance(graph_module, torch.fx.GraphModule)
self.graph_module = graph_module self.graph_module = graph_module
...@@ -188,6 +189,13 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -188,6 +189,13 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
self.shape_env = None self.shape_env = None
self.vllm_backend = vllm_backend self.vllm_backend = vllm_backend
self.sym_tensor_indices = sym_tensor_indices self.sym_tensor_indices = sym_tensor_indices
import torch._functorch.config as functorch_config
self.aot_autograd_config = (
aot_autograd_config or functorch_config.save_config_portable()
)
sym_input = next( sym_input = next(
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
) )
...@@ -286,6 +294,12 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -286,6 +294,12 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
sym_shape_indices_map = state.pop("sym_shape_indices_map", {}) sym_shape_indices_map = state.pop("sym_shape_indices_map", {})
returns_tuple_map = state.pop("returns_tuple_map", {}) returns_tuple_map = state.pop("returns_tuple_map", {})
saved_aot_autograd_config = state["aot_autograd_config"]
if saved_aot_autograd_config is not None:
functorch_ctx = torch._functorch.config.patch(saved_aot_autograd_config)
else:
functorch_ctx = contextlib.nullcontext()
if envs.VLLM_USE_MEGA_AOT_ARTIFACT: if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
assert standalone_compile_artifacts is not None assert standalone_compile_artifacts is not None
submod_names = standalone_compile_artifacts.submodule_names() submod_names = standalone_compile_artifacts.submodule_names()
...@@ -299,13 +313,14 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -299,13 +313,14 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
num_submods, num_submods,
) )
fn = reconstruct_serializable_fn_from_mega_artifact( with functorch_ctx:
state=state, fn = reconstruct_serializable_fn_from_mega_artifact(
standalone_compile_artifacts=standalone_compile_artifacts, state=state,
vllm_config=get_current_vllm_config(), standalone_compile_artifacts=standalone_compile_artifacts,
sym_shape_indices_map=sym_shape_indices_map, vllm_config=get_current_vllm_config(),
returns_tuple_map=returns_tuple_map, sym_shape_indices_map=sym_shape_indices_map,
) returns_tuple_map=returns_tuple_map,
)
logger.info( logger.info(
"reconstructed serializable fn from standalone compile artifacts" "reconstructed serializable fn from standalone compile artifacts"
...@@ -328,7 +343,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -328,7 +343,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
vllm_backend: VllmBackend = VllmBackend( vllm_backend: VllmBackend = VllmBackend(
vllm_config, state["prefix"], is_encoder vllm_config, state["prefix"], is_encoder
) )
with tracing(TracingContext(fake_mode)): with tracing(TracingContext(fake_mode)), functorch_ctx:
fn.optimized_call = vllm_backend( fn.optimized_call = vllm_backend(
state["graph_module"], compile_inputs state["graph_module"], compile_inputs
).optimized_call ).optimized_call
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment