Unverified commit 98a242ff, authored by Zhengxu Chen, committed by GitHub
Browse files

[compile] Skip FX graph deserialization on loading, further reducing warm compile time. (#40151)


Signed-off-by: zhxchen17 <zhxchen17@fb.com>
parent e4ee48da
...@@ -23,6 +23,10 @@ from torch._logging._internal import trace_structured ...@@ -23,6 +23,10 @@ from torch._logging._internal import trace_structured
from torch.fx._lazy_graph_module import _use_lazy_graph_module from torch.fx._lazy_graph_module import _use_lazy_graph_module
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.codegen import (
compile_execution_fn,
generate_execution_code,
)
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import DynamicShapesType from vllm.config.compilation import DynamicShapesType
from vllm.config.utils import Range, hash_factors from vllm.config.utils import Range, hash_factors
...@@ -1244,11 +1248,6 @@ class VllmBackend: ...@@ -1244,11 +1248,6 @@ class VllmBackend:
original_split_gm if envs.VLLM_USE_MEGA_AOT_ARTIFACT else self.graph original_split_gm if envs.VLLM_USE_MEGA_AOT_ARTIFACT else self.graph
) )
from vllm.compilation.codegen import (
compile_execution_fn,
generate_execution_code,
)
execution_code, submod_names = generate_execution_code(self.split_gm) execution_code, submod_names = generate_execution_code(self.split_gm)
# Use getattr to get correct callables: __dict__ has PiecewiseBackend # Use getattr to get correct callables: __dict__ has PiecewiseBackend
# instances (from PiecewiseCompileInterpreter), _modules has originals. # instances (from PiecewiseCompileInterpreter), _modules has originals.
......
...@@ -16,6 +16,7 @@ from torch.fx._graph_pickler import GraphPickler, Options ...@@ -16,6 +16,7 @@ from torch.fx._graph_pickler import GraphPickler, Options
from torch.utils import _pytree as pytree from torch.utils import _pytree as pytree
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.codegen import compile_execution_fn
from vllm.compilation.compiler_interface import get_inductor_factors from vllm.compilation.compiler_interface import get_inductor_factors
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
...@@ -176,7 +177,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -176,7 +177,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
def __init__( def __init__(
self, self,
graph_module: torch.fx.GraphModule, graph_module: torch.fx.GraphModule | bytes,
example_inputs: Sequence[Any], example_inputs: Sequence[Any],
prefix: str, prefix: str,
optimized_call: Callable[..., Any], optimized_call: Callable[..., Any],
...@@ -187,7 +188,6 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -187,7 +188,6 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
execution_code: str | None = None, execution_code: str | None = None,
submod_names: list[str] | None = None, submod_names: list[str] | None = None,
) -> None: ) -> None:
assert isinstance(graph_module, torch.fx.GraphModule)
self.graph_module = graph_module self.graph_module = graph_module
self.example_inputs = example_inputs self.example_inputs = example_inputs
self.prefix = prefix self.prefix = prefix
...@@ -302,10 +302,6 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -302,10 +302,6 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
state = pickle.loads(data) state = pickle.loads(data)
fake_mode = FakeTensorMode(shape_env=ShapeEnv()) fake_mode = FakeTensorMode(shape_env=ShapeEnv())
state["graph_module"] = cls.deserialize_graph_module(
state["graph_module"], fake_mode
)
state["graph_module"].recompile()
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode) state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
standalone_compile_artifacts = state.pop("standalone_compile_artifacts", None) standalone_compile_artifacts = state.pop("standalone_compile_artifacts", None)
...@@ -331,6 +327,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -331,6 +327,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
vllm_config=get_current_vllm_config(), vllm_config=get_current_vllm_config(),
sym_shape_indices_map=sym_shape_indices_map, sym_shape_indices_map=sym_shape_indices_map,
returns_tuple_map=returns_tuple_map, returns_tuple_map=returns_tuple_map,
fake_mode=fake_mode,
) )
logger.info( logger.info(
...@@ -342,6 +339,11 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -342,6 +339,11 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
return fn return fn
state["graph_module"] = cls.deserialize_graph_module(
state["graph_module"], fake_mode
)
state["graph_module"].recompile()
# Fall back to standard VllmBackend. # Fall back to standard VllmBackend.
# Use a lazy closure: the backend needs traced_files for cache # Use a lazy closure: the backend needs traced_files for cache
# dir computation, but those are only populated after # dir computation, but those are only populated after
...@@ -410,6 +412,7 @@ def reconstruct_serializable_fn_from_mega_artifact( ...@@ -410,6 +412,7 @@ def reconstruct_serializable_fn_from_mega_artifact(
vllm_config: VllmConfig, vllm_config: VllmConfig,
sym_shape_indices_map: dict[str, list[int]], sym_shape_indices_map: dict[str, list[int]],
returns_tuple_map: dict[str, bool], returns_tuple_map: dict[str, bool],
fake_mode: FakeTensorMode,
) -> "VllmSerializableFunction": ) -> "VllmSerializableFunction":
"""Construct a VllmSerializableFunction from cached inductor artifacts. """Construct a VllmSerializableFunction from cached inductor artifacts.
...@@ -452,7 +455,6 @@ def reconstruct_serializable_fn_from_mega_artifact( ...@@ -452,7 +455,6 @@ def reconstruct_serializable_fn_from_mega_artifact(
prefix = state["prefix"] prefix = state["prefix"]
is_encoder = state.get("is_encoder", False) is_encoder = state.get("is_encoder", False)
split_gm = state["graph_module"]
compilation_config = vllm_config.compilation_config compilation_config = vllm_config.compilation_config
standalone_compile_artifacts.load_all() standalone_compile_artifacts.load_all()
...@@ -476,13 +478,16 @@ def reconstruct_serializable_fn_from_mega_artifact( ...@@ -476,13 +478,16 @@ def reconstruct_serializable_fn_from_mega_artifact(
) )
# spot check that cached submodules exist in the graph structure # spot check that cached submodules exist in the graph structure
graph_children = {name for name, _ in split_gm.named_children()} # if an old cache is used, this will fail but that's fine because
# we will just raise this error and re-generate the new cache.
graph_children = set(state["submod_names"])
missing = set(piecewise_submod_names) - graph_children missing = set(piecewise_submod_names) - graph_children
assert not missing, ( assert not missing, (
f"artifacts reference submodules not in graph: {missing}. " f"artifacts reference submodules not in graph: {missing}. "
f"graph has: {sorted(graph_children)}" f"graph has: {sorted(graph_children)}"
) )
submod_callables = {}
for i, submod_name in enumerate(piecewise_submod_names): for i, submod_name in enumerate(piecewise_submod_names):
assert submod_name in sym_shape_indices_map and submod_name in returns_tuple_map assert submod_name in sym_shape_indices_map and submod_name in returns_tuple_map
...@@ -511,7 +516,7 @@ def reconstruct_serializable_fn_from_mega_artifact( ...@@ -511,7 +516,7 @@ def reconstruct_serializable_fn_from_mega_artifact(
is_last, is_last,
) )
split_gm.__dict__[submod_name] = wrapped_backend submod_callables[submod_name] = wrapped_backend
logger.debug( logger.debug(
"Replaced submodule %s with piecewise backend from cache", "Replaced submodule %s with piecewise backend from cache",
submod_name, submod_name,
...@@ -521,16 +526,16 @@ def reconstruct_serializable_fn_from_mega_artifact( ...@@ -521,16 +526,16 @@ def reconstruct_serializable_fn_from_mega_artifact(
execution_code = state.get("execution_code") execution_code = state.get("execution_code")
submod_names = state.get("submod_names") submod_names = state.get("submod_names")
if execution_code is not None and submod_names is not None: if execution_code is not None and submod_names is not None:
from vllm.compilation.codegen import compile_execution_fn
submod_callables = {
name: getattr(split_gm, name) for name, _ in split_gm.named_children()
}
runtime_callable = compile_execution_fn( runtime_callable = compile_execution_fn(
execution_code, submod_callables, submod_names execution_code, submod_callables, submod_names
) )
else: else:
runtime_callable = split_gm logger.warning(
"No execution code found, falling back to graph module execution."
)
runtime_callable = GraphPickler.loads(
state["graph_module"], fake_mode=fake_mode
)
if compilation_config.cudagraph_copy_inputs: if compilation_config.cudagraph_copy_inputs:
sym_tensor_indices = state["sym_tensor_indices"] sym_tensor_indices = state["sym_tensor_indices"]
......
...@@ -15,26 +15,14 @@ from typing import Any ...@@ -15,26 +15,14 @@ from typing import Any
import torch.fx import torch.fx
from torch._dynamo.utils import dynamo_timed from torch._dynamo.utils import dynamo_timed
from torch._logging import trace_structured from torch._logging import trace_structured
from torch.fx.node import _get_qualified_name
@dynamo_timed("vllm.generate_execution_code") def generate_execution_code_with_name(
def generate_execution_code(
split_gm: torch.fx.GraphModule, split_gm: torch.fx.GraphModule,
fn_name: str,
with_submod: bool,
) -> tuple[str, list[str]]: ) -> tuple[str, list[str]]:
"""Generate Python source code from a split_gm's stitching graph.
Walks split_gm.graph.nodes and produces a function that calls
submodules via a __vllm_submods__ list, avoiding FX GraphModule overhead
and dict lookup cost.
Args:
split_gm: The split graph module produced by split_graph().
Returns:
A tuple of (code, submod_names) where code is the Python source
and submod_names is the ordered list of submodule target names
corresponding to list indices used in the generated code.
"""
lines: list[str] = [] lines: list[str] = []
param_names: list[str] = [] param_names: list[str] = []
submod_names: list[str] = [] submod_names: list[str] = []
...@@ -43,6 +31,7 @@ def generate_execution_code( ...@@ -43,6 +31,7 @@ def generate_execution_code(
# Build node ordering for liveness analysis. # Build node ordering for liveness analysis.
nodes = list(split_gm.graph.nodes) nodes = list(split_gm.graph.nodes)
node_order = {node: i for i, node in enumerate(nodes)} node_order = {node: i for i, node in enumerate(nodes)}
inlined_submods: list[str] = []
# For each value-producing node, find the position of its last consumer. # For each value-producing node, find the position of its last consumer.
# If the last consumer is the output node, skip (return handles cleanup). # If the last consumer is the output node, skip (return handles cleanup).
...@@ -65,6 +54,10 @@ def generate_execution_code( ...@@ -65,6 +54,10 @@ def generate_execution_code(
elif node.op == "call_module": elif node.op == "call_module":
target = node.target target = node.target
if not with_submod:
raise RuntimeError(
f"call_module is not allowed for codegen target {target}."
)
if target not in submod_index: if target not in submod_index:
submod_index[target] = len(submod_names) submod_index[target] = len(submod_names)
submod_names.append(target) submod_names.append(target)
...@@ -74,13 +67,32 @@ def generate_execution_code( ...@@ -74,13 +67,32 @@ def generate_execution_code(
f"{k}={_node_ref(v)}" for k, v in node.kwargs.items() f"{k}={_node_ref(v)}" for k, v in node.kwargs.items()
) )
all_args = ", ".join(filter(None, [args_str, kwargs_str])) all_args = ", ".join(filter(None, [args_str, kwargs_str]))
lines.append(f" {node.name} = __vllm_submods__[{idx}]({all_args})") submod = getattr(split_gm, target)
if isinstance(submod, torch.fx.GraphModule):
elif node.op == "call_function" and node.target is operator.getitem: callable_name = f"__vllm_inlined_submods__{idx}"
source = _node_ref(node.args[0]) inlined_code, _ = generate_execution_code_with_name(
index = node.args[1] submod, callable_name, with_submod=False
assert isinstance(index, int) )
lines.append(f" {node.name} = {source}[{index}]") inlined_submods.append(inlined_code)
else:
callable_name = f"__vllm_submods__[{idx}]"
lines.append(f" {node.name} = {callable_name}({all_args})")
elif node.op == "call_function":
if node.target is operator.getitem:
source = _node_ref(node.args[0])
index = node.args[1]
assert isinstance(index, int)
lines.append(f" {node.name} = {source}[{index}]")
else:
args_str = ", ".join(_node_ref(a) for a in node.args)
kwargs_str = ", ".join(
f"{k}={_node_ref(v)}" for k, v in node.kwargs.items()
)
all_args = ", ".join(filter(None, [args_str, kwargs_str]))
lines.append(
f" {node.name} = {_get_qualified_name(node.target)}({all_args})"
)
elif node.op == "output": elif node.op == "output":
assert len(node.args) == 1 assert len(node.args) == 1
...@@ -91,14 +103,44 @@ def generate_execution_code( ...@@ -91,14 +103,44 @@ def generate_execution_code(
raise RuntimeError(f"Unsupported node from codegen: {node.format_node()}") raise RuntimeError(f"Unsupported node from codegen: {node.format_node()}")
# Emit del for variables whose last use was this node. # Emit del for variables whose last use was this node.
if i in del_after: if i in del_after and i < len(nodes) - 2:
names = sorted(del_after[i]) names = sorted(del_after[i])
lines.append(f" del {', '.join(names)}") lines.append(f" del {', '.join(names)}")
assert len(param_names) > 0 assert len(param_names) > 0
params = ", ".join(param_names) params = ", ".join(param_names)
header = f"def execution_fn({params}, *, __vllm_submods__):" header = (
return "import torch\n" + "\n".join([header] + lines) + "\n", submod_names f"\ndef {fn_name}({params}{', *, __vllm_submods__' if with_submod else ''}):"
)
return "".join(inlined_submods) + "\n".join([header] + lines) + "\n", submod_names
@dynamo_timed("vllm.generate_execution_code")
def generate_execution_code(
    split_gm: torch.fx.GraphModule,
) -> tuple[str, list[str]]:
    """Emit Python source for the stitching graph of ``split_gm``.

    The graph nodes of ``split_gm`` are turned into a function named
    ``execution_fn`` that invokes submodules through a ``__vllm_submods__``
    list, which sidesteps FX GraphModule overhead and dict lookups.

    Submodules that are plain ``torch.fx.GraphModule`` instances are inlined
    into the generated source, so they do not have to be serialized in the
    artifact.

    Args:
        split_gm: The split graph module produced by split_graph().

    Returns:
        A ``(code, submod_names)`` tuple: ``code`` is the generated Python
        source (prefixed with the ``torch`` and ``operator`` imports it
        relies on) and ``submod_names`` is the ordered list of submodule
        target names matching the list indices used in that source.
    """
    # The shared generator does the real work; this entry point only fixes
    # the top-level function name and enables submodule calls.
    body, ordered_submods = generate_execution_code_with_name(
        split_gm, "execution_fn", with_submod=True
    )
    preamble = "import torch\nimport operator\n"
    return preamble + body, ordered_submods
@dynamo_timed("vllm.compile_execution_fn") @dynamo_timed("vllm.compile_execution_fn")
...@@ -129,11 +171,12 @@ def compile_execution_fn( ...@@ -129,11 +171,12 @@ def compile_execution_fn(
namespace: dict[str, Any] = {} namespace: dict[str, Any] = {}
exec(code, namespace) # noqa: S102 exec(code, namespace) # noqa: S102
fn = namespace["execution_fn"] fn = namespace["execution_fn"]
# Use .forward() directly to avoid nn.Module.__call__ overhead. # Using .get() is intentional here because only piecewise backend will
submods_list = [ # be stored in submod_callables. The other submodules are inlined and
c.forward if isinstance(c, torch.fx.GraphModule) else c # we don't need to bind them to the execution function. Instead, we
for c in (submod_callables[name] for name in submod_names) # should use None as placeholder to ensure the list indices are preserved
] # for better debuggability.
submods_list = [submod_callables.get(name) for name in submod_names]
return partial(fn, __vllm_submods__=submods_list) return partial(fn, __vllm_submods__=submods_list)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment