Unverified commit 98a242ff, authored by Zhengxu Chen, committed by GitHub
Browse files

[compile] Skip FX graph deserialization on loading, further reducing warm compile time. (#40151)


Signed-off-by: zhxchen17 <zhxchen17@fb.com>
parent e4ee48da
......@@ -23,6 +23,10 @@ from torch._logging._internal import trace_structured
from torch.fx._lazy_graph_module import _use_lazy_graph_module
import vllm.envs as envs
from vllm.compilation.codegen import (
compile_execution_fn,
generate_execution_code,
)
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import DynamicShapesType
from vllm.config.utils import Range, hash_factors
......@@ -1244,11 +1248,6 @@ class VllmBackend:
original_split_gm if envs.VLLM_USE_MEGA_AOT_ARTIFACT else self.graph
)
from vllm.compilation.codegen import (
compile_execution_fn,
generate_execution_code,
)
execution_code, submod_names = generate_execution_code(self.split_gm)
# Use getattr to get correct callables: __dict__ has PiecewiseBackend
# instances (from PiecewiseCompileInterpreter), _modules has originals.
......
......@@ -16,6 +16,7 @@ from torch.fx._graph_pickler import GraphPickler, Options
from torch.utils import _pytree as pytree
import vllm.envs as envs
from vllm.compilation.codegen import compile_execution_fn
from vllm.compilation.compiler_interface import get_inductor_factors
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig, get_current_vllm_config
......@@ -176,7 +177,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
def __init__(
self,
graph_module: torch.fx.GraphModule,
graph_module: torch.fx.GraphModule | bytes,
example_inputs: Sequence[Any],
prefix: str,
optimized_call: Callable[..., Any],
......@@ -187,7 +188,6 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
execution_code: str | None = None,
submod_names: list[str] | None = None,
) -> None:
assert isinstance(graph_module, torch.fx.GraphModule)
self.graph_module = graph_module
self.example_inputs = example_inputs
self.prefix = prefix
......@@ -302,10 +302,6 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
state = pickle.loads(data)
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
state["graph_module"] = cls.deserialize_graph_module(
state["graph_module"], fake_mode
)
state["graph_module"].recompile()
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
standalone_compile_artifacts = state.pop("standalone_compile_artifacts", None)
......@@ -331,6 +327,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
vllm_config=get_current_vllm_config(),
sym_shape_indices_map=sym_shape_indices_map,
returns_tuple_map=returns_tuple_map,
fake_mode=fake_mode,
)
logger.info(
......@@ -342,6 +339,11 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
return fn
state["graph_module"] = cls.deserialize_graph_module(
state["graph_module"], fake_mode
)
state["graph_module"].recompile()
# Fall back to standard VllmBackend.
# Use a lazy closure: the backend needs traced_files for cache
# dir computation, but those are only populated after
......@@ -410,6 +412,7 @@ def reconstruct_serializable_fn_from_mega_artifact(
vllm_config: VllmConfig,
sym_shape_indices_map: dict[str, list[int]],
returns_tuple_map: dict[str, bool],
fake_mode: FakeTensorMode,
) -> "VllmSerializableFunction":
"""Construct a VllmSerializableFunction from cached inductor artifacts.
......@@ -452,7 +455,6 @@ def reconstruct_serializable_fn_from_mega_artifact(
prefix = state["prefix"]
is_encoder = state.get("is_encoder", False)
split_gm = state["graph_module"]
compilation_config = vllm_config.compilation_config
standalone_compile_artifacts.load_all()
......@@ -476,13 +478,16 @@ def reconstruct_serializable_fn_from_mega_artifact(
)
# spot check that cached submodules exist in the graph structure
graph_children = {name for name, _ in split_gm.named_children()}
# If an old cache is used, this will fail, but that's fine because
# we will just catch this error and regenerate the cache.
graph_children = set(state["submod_names"])
missing = set(piecewise_submod_names) - graph_children
assert not missing, (
f"artifacts reference submodules not in graph: {missing}. "
f"graph has: {sorted(graph_children)}"
)
submod_callables = {}
for i, submod_name in enumerate(piecewise_submod_names):
assert submod_name in sym_shape_indices_map and submod_name in returns_tuple_map
......@@ -511,7 +516,7 @@ def reconstruct_serializable_fn_from_mega_artifact(
is_last,
)
split_gm.__dict__[submod_name] = wrapped_backend
submod_callables[submod_name] = wrapped_backend
logger.debug(
"Replaced submodule %s with piecewise backend from cache",
submod_name,
......@@ -521,16 +526,16 @@ def reconstruct_serializable_fn_from_mega_artifact(
execution_code = state.get("execution_code")
submod_names = state.get("submod_names")
if execution_code is not None and submod_names is not None:
from vllm.compilation.codegen import compile_execution_fn
submod_callables = {
name: getattr(split_gm, name) for name, _ in split_gm.named_children()
}
runtime_callable = compile_execution_fn(
execution_code, submod_callables, submod_names
)
else:
runtime_callable = split_gm
logger.warning(
"No execution code found, falling back to graph module execution."
)
runtime_callable = GraphPickler.loads(
state["graph_module"], fake_mode=fake_mode
)
if compilation_config.cudagraph_copy_inputs:
sym_tensor_indices = state["sym_tensor_indices"]
......
......@@ -15,26 +15,14 @@ from typing import Any
import torch.fx
from torch._dynamo.utils import dynamo_timed
from torch._logging import trace_structured
from torch.fx.node import _get_qualified_name
@dynamo_timed("vllm.generate_execution_code")
def generate_execution_code(
def generate_execution_code_with_name(
split_gm: torch.fx.GraphModule,
fn_name: str,
with_submod: bool,
) -> tuple[str, list[str]]:
"""Generate Python source code from a split_gm's stitching graph.
Walks split_gm.graph.nodes and produces a function that calls
submodules via a __vllm_submods__ list, avoiding FX GraphModule overhead
and dict lookup cost.
Args:
split_gm: The split graph module produced by split_graph().
Returns:
A tuple of (code, submod_names) where code is the Python source
and submod_names is the ordered list of submodule target names
corresponding to list indices used in the generated code.
"""
lines: list[str] = []
param_names: list[str] = []
submod_names: list[str] = []
......@@ -43,6 +31,7 @@ def generate_execution_code(
# Build node ordering for liveness analysis.
nodes = list(split_gm.graph.nodes)
node_order = {node: i for i, node in enumerate(nodes)}
inlined_submods: list[str] = []
# For each value-producing node, find the position of its last consumer.
# If the last consumer is the output node, skip (return handles cleanup).
......@@ -65,6 +54,10 @@ def generate_execution_code(
elif node.op == "call_module":
target = node.target
if not with_submod:
raise RuntimeError(
f"call_module is not allowed for codegen target {target}."
)
if target not in submod_index:
submod_index[target] = len(submod_names)
submod_names.append(target)
......@@ -74,13 +67,32 @@ def generate_execution_code(
f"{k}={_node_ref(v)}" for k, v in node.kwargs.items()
)
all_args = ", ".join(filter(None, [args_str, kwargs_str]))
lines.append(f" {node.name} = __vllm_submods__[{idx}]({all_args})")
elif node.op == "call_function" and node.target is operator.getitem:
source = _node_ref(node.args[0])
index = node.args[1]
assert isinstance(index, int)
lines.append(f" {node.name} = {source}[{index}]")
submod = getattr(split_gm, target)
if isinstance(submod, torch.fx.GraphModule):
callable_name = f"__vllm_inlined_submods__{idx}"
inlined_code, _ = generate_execution_code_with_name(
submod, callable_name, with_submod=False
)
inlined_submods.append(inlined_code)
else:
callable_name = f"__vllm_submods__[{idx}]"
lines.append(f" {node.name} = {callable_name}({all_args})")
elif node.op == "call_function":
if node.target is operator.getitem:
source = _node_ref(node.args[0])
index = node.args[1]
assert isinstance(index, int)
lines.append(f" {node.name} = {source}[{index}]")
else:
args_str = ", ".join(_node_ref(a) for a in node.args)
kwargs_str = ", ".join(
f"{k}={_node_ref(v)}" for k, v in node.kwargs.items()
)
all_args = ", ".join(filter(None, [args_str, kwargs_str]))
lines.append(
f" {node.name} = {_get_qualified_name(node.target)}({all_args})"
)
elif node.op == "output":
assert len(node.args) == 1
......@@ -91,14 +103,44 @@ def generate_execution_code(
raise RuntimeError(f"Unsupported node from codegen: {node.format_node()}")
# Emit del for variables whose last use was this node.
if i in del_after:
if i in del_after and i < len(nodes) - 2:
names = sorted(del_after[i])
lines.append(f" del {', '.join(names)}")
assert len(param_names) > 0
params = ", ".join(param_names)
header = f"def execution_fn({params}, *, __vllm_submods__):"
return "import torch\n" + "\n".join([header] + lines) + "\n", submod_names
header = (
f"\ndef {fn_name}({params}{', *, __vllm_submods__' if with_submod else ''}):"
)
return "".join(inlined_submods) + "\n".join([header] + lines) + "\n", submod_names
@dynamo_timed("vllm.generate_execution_code")
def generate_execution_code(
    split_gm: torch.fx.GraphModule,
) -> tuple[str, list[str]]:
    """Emit standalone Python source for a split graph's stitching logic.

    This is the public entry point for code generation. It delegates to
    ``generate_execution_code_with_name`` to produce a top-level function
    named ``execution_fn`` that receives its submodules through the
    keyword-only ``__vllm_submods__`` list, then prefixes the module
    preamble (``import torch`` / ``import operator``) that the generated
    source relies on.

    Submodules that are plain ``torch.fx.GraphModule`` instances are
    inlined into the generated source as separate functions instead of
    being called through ``__vllm_submods__``, so they do not need to be
    serialized in the artifact.

    Args:
        split_gm: The split graph module produced by split_graph().

    Returns:
        A ``(code, submod_names)`` tuple. ``code`` is the generated Python
        source; ``submod_names`` is the ordered list of submodule target
        names whose positions match the list indices used in ``code``.
    """
    # Imports required by the generated source; emitted once, at the top.
    preamble = "import torch\nimport operator\n"
    body, ordered_submods = generate_execution_code_with_name(
        split_gm,
        fn_name="execution_fn",
        with_submod=True,
    )
    return preamble + body, ordered_submods
@dynamo_timed("vllm.compile_execution_fn")
......@@ -129,11 +171,12 @@ def compile_execution_fn(
namespace: dict[str, Any] = {}
exec(code, namespace) # noqa: S102
fn = namespace["execution_fn"]
# Use .forward() directly to avoid nn.Module.__call__ overhead.
submods_list = [
c.forward if isinstance(c, torch.fx.GraphModule) else c
for c in (submod_callables[name] for name in submod_names)
]
# Using .get() is intentional here because only piecewise backend will
# be stored in submod_callables. The other submodules are inlined and
# we don't need to bind them to the execution function. Instead, we
# should use None as placeholder to ensure the list indices are preserved
# for better debuggability.
submods_list = [submod_callables.get(name) for name in submod_names]
return partial(fn, __vllm_submods__=submods_list)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment