Unverified Commit 951dca80 authored by Zhengxu Chen's avatar Zhengxu Chen Committed by GitHub
Browse files

[compile] Invoke split FX graph by codegen. (#38657)


Signed-off-by: default avatarzhxchen17 <zhxchen17@fb.com>
parent 5f7fab88
......@@ -1263,6 +1263,23 @@ class VllmBackend:
original_split_gm if envs.VLLM_USE_MEGA_AOT_ARTIFACT else self.graph
)
from vllm.compilation.codegen import (
compile_execution_fn,
generate_execution_code,
)
execution_code, submod_names = generate_execution_code(self.split_gm)
# Use getattr to get correct callables: __dict__ has PiecewiseBackend
# instances (from PiecewiseCompileInterpreter), _modules has originals.
# getattr checks __dict__ first, then falls back to _modules.
submod_callables = {
name: getattr(self.split_gm, name)
for name, _ in self.split_gm.named_children()
}
runtime_callable = compile_execution_fn(
execution_code, submod_callables, submod_names
)
if (
self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
or not self.compilation_config.cudagraph_copy_inputs
......@@ -1271,9 +1288,11 @@ class VllmBackend:
graph_to_serialize,
example_inputs,
self.prefix,
self.split_gm,
runtime_callable,
is_encoder=self.is_encoder,
vllm_backend=self,
execution_code=execution_code,
submod_names=submod_names,
)
# index of tensors that have symbolic shapes (batch size)
......@@ -1294,7 +1313,7 @@ class VllmBackend:
copy_and_call = make_copy_and_call(
sym_tensor_indices,
[example_inputs[x].clone() for x in sym_tensor_indices],
self.split_gm,
runtime_callable,
)
return VllmSerializableFunction(
......@@ -1305,4 +1324,6 @@ class VllmBackend:
is_encoder=self.is_encoder,
vllm_backend=self,
sym_tensor_indices=sym_tensor_indices,
execution_code=execution_code,
submod_names=submod_names,
)
......@@ -184,6 +184,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
vllm_backend: Any | None = None,
sym_tensor_indices: list[int] | None = None,
aot_autograd_config: dict[str, Any] | None = None,
execution_code: str | None = None,
submod_names: list[str] | None = None,
) -> None:
assert isinstance(graph_module, torch.fx.GraphModule)
self.graph_module = graph_module
......@@ -194,6 +196,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
self.shape_env = None
self.vllm_backend = vllm_backend
self.sym_tensor_indices = sym_tensor_indices
self.execution_code = execution_code
self.submod_names = submod_names
self._fake_mode: Any | None = None
import torch._functorch.config as functorch_config
......@@ -453,7 +457,7 @@ def reconstruct_serializable_fn_from_mega_artifact(
standalone_compile_artifacts.load_all()
submod_names = standalone_compile_artifacts.submodule_names()
piecewise_submod_names = standalone_compile_artifacts.submodule_names()
compiled_callables: dict[str, dict[str, Callable[..., Any]]] = {}
for cache_key in standalone_compile_artifacts.submodule_bytes:
......@@ -473,13 +477,13 @@ def reconstruct_serializable_fn_from_mega_artifact(
# spot check that cached submodules exist in the graph structure
graph_children = {name for name, _ in split_gm.named_children()}
missing = set(submod_names) - graph_children
missing = set(piecewise_submod_names) - graph_children
assert not missing, (
f"artifacts reference submodules not in graph: {missing}. "
f"graph has: {sorted(graph_children)}"
)
for i, submod_name in enumerate(submod_names):
for i, submod_name in enumerate(piecewise_submod_names):
assert submod_name in sym_shape_indices_map and submod_name in returns_tuple_map
sym_shape_indices = sym_shape_indices_map[submod_name]
......@@ -490,7 +494,7 @@ def reconstruct_serializable_fn_from_mega_artifact(
graph=None, # not needed for cached artifacts
vllm_config=vllm_config,
piecewise_compile_index=i,
total_piecewise_compiles=len(submod_names),
total_piecewise_compiles=len(piecewise_submod_names),
sym_shape_indices=sym_shape_indices,
vllm_backend=vllm_backend,
returns_tuple=returns_tuple,
......@@ -498,7 +502,7 @@ def reconstruct_serializable_fn_from_mega_artifact(
)
is_first = i == 0
is_last = i == len(submod_names) - 1
is_last = i == len(piecewise_submod_names) - 1
wrapped_backend = wrap_with_cudagraph_if_needed(
piecewise_backend,
vllm_config,
......@@ -513,6 +517,21 @@ def reconstruct_serializable_fn_from_mega_artifact(
submod_name,
)
# Use codegen'd execution code if available, fall back to split_gm
execution_code = state.get("execution_code")
submod_names = state.get("submod_names")
if execution_code is not None and submod_names is not None:
from vllm.compilation.codegen import compile_execution_fn
submod_callables = {
name: getattr(split_gm, name) for name, _ in split_gm.named_children()
}
runtime_callable = compile_execution_fn(
execution_code, submod_callables, submod_names
)
else:
runtime_callable = split_gm
if compilation_config.cudagraph_copy_inputs:
sym_tensor_indices = state["sym_tensor_indices"]
input_buffers = [
......@@ -521,9 +540,11 @@ def reconstruct_serializable_fn_from_mega_artifact(
)
for idx in sym_tensor_indices
]
optimized_call = make_copy_and_call(sym_tensor_indices, input_buffers, split_gm)
optimized_call = make_copy_and_call(
sym_tensor_indices, input_buffers, runtime_callable
)
else:
optimized_call = split_gm
optimized_call = runtime_callable
fn = VllmSerializableFunction(
**state,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Code generation for split_gm stitching graph execution.
Generates a plain Python function that replaces the FX GraphModule's
interpreter-based execution of the stitching graph, eliminating
nn.Module.__call__ overhead and __getattr__ dispatch.
"""
import operator
from collections.abc import Callable
from functools import partial
from typing import Any
import torch.fx
from torch._dynamo.utils import dynamo_timed
from torch._logging import trace_structured
@dynamo_timed("vllm.generate_execution_code")
def generate_execution_code(
split_gm: torch.fx.GraphModule,
) -> tuple[str, list[str]]:
"""Generate Python source code from a split_gm's stitching graph.
Walks split_gm.graph.nodes and produces a function that calls
submodules via a __vllm_submods__ list, avoiding FX GraphModule overhead
and dict lookup cost.
Args:
split_gm: The split graph module produced by split_graph().
Returns:
A tuple of (code, submod_names) where code is the Python source
and submod_names is the ordered list of submodule target names
corresponding to list indices used in the generated code.
"""
lines: list[str] = []
param_names: list[str] = []
submod_names: list[str] = []
submod_index: dict[str, int] = {}
# Build node ordering for liveness analysis.
nodes = list(split_gm.graph.nodes)
node_order = {node: i for i, node in enumerate(nodes)}
# For each value-producing node, find the position of its last consumer.
# If the last consumer is the output node, skip (return handles cleanup).
# Otherwise, schedule a del after that consumer to free memory early.
del_after: dict[int, list[str]] = {} # position -> names to delete
for node in nodes:
if node.op == "output":
continue
users = list(node.users.keys())
if not users:
continue
last_user = max(users, key=lambda u: node_order[u])
if last_user.op == "output":
continue
del_after.setdefault(node_order[last_user], []).append(node.name)
for i, node in enumerate(nodes):
if node.op == "placeholder":
param_names.append(node.name)
elif node.op == "call_module":
target = node.target
if target not in submod_index:
submod_index[target] = len(submod_names)
submod_names.append(target)
idx = submod_index[target]
args_str = ", ".join(_node_ref(a) for a in node.args)
kwargs_str = ", ".join(
f"{k}={_node_ref(v)}" for k, v in node.kwargs.items()
)
all_args = ", ".join(filter(None, [args_str, kwargs_str]))
lines.append(f" {node.name} = __vllm_submods__[{idx}]({all_args})")
elif node.op == "call_function" and node.target is operator.getitem:
source = _node_ref(node.args[0])
index = node.args[1]
assert isinstance(index, int)
lines.append(f" {node.name} = {source}[{index}]")
elif node.op == "output":
assert len(node.args) == 1
ret = _node_ref(node.args[0])
lines.append(f" return {ret}")
else:
raise RuntimeError(f"Unsupported node from codegen: {node.format_node()}")
# Emit del for variables whose last use was this node.
if i in del_after:
names = sorted(del_after[i])
lines.append(f" del {', '.join(names)}")
assert len(param_names) > 0
params = ", ".join(param_names)
header = f"def execution_fn({params}, *, __vllm_submods__):"
return "import torch\n" + "\n".join([header] + lines) + "\n", submod_names
@dynamo_timed("vllm.compile_execution_fn")
def compile_execution_fn(
code: str,
submod_callables: dict[str, Callable[..., Any]],
submod_names: list[str],
) -> Callable[..., Any]:
"""Compile execution code and bind submodule callables.
Args:
code: Python source from generate_execution_code().
submod_callables: Mapping of submodule names to their callables.
submod_names: Ordered list of submodule names matching the indices
used in the generated code.
Returns:
A callable that executes the stitching logic.
"""
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "vllm_execution_code",
"encoding": "string",
},
payload_fn=lambda: code,
)
namespace: dict[str, Any] = {}
exec(code, namespace) # noqa: S102
fn = namespace["execution_fn"]
# Use .forward() directly to avoid nn.Module.__call__ overhead.
submods_list = [
c.forward if isinstance(c, torch.fx.GraphModule) else c
for c in (submod_callables[name] for name in submod_names)
]
return partial(fn, __vllm_submods__=submods_list)
def _node_ref(arg: Any) -> str:
"""Convert an FX node argument to a source code reference recursively."""
if isinstance(arg, torch.fx.Node):
return arg.name
if isinstance(arg, list):
return f"[{', '.join(_node_ref(x) for x in arg)}]"
if isinstance(arg, tuple):
items = ", ".join(_node_ref(x) for x in arg)
return f"({items},)" if len(arg) == 1 else f"({items})"
if isinstance(arg, dict):
return (
"{"
+ ", ".join(f"{_node_ref(k)}: {_node_ref(v)}" for k, v in arg.items())
+ "}"
)
return repr(arg)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment