Unverified Commit 608b5565 authored by Angela Yi's avatar Angela Yi Committed by GitHub
Browse files

[ez] Add structured torch.compile logs (#33213)


Signed-off-by: default avatarangelayi <yiangela7@gmail.com>
parent f0a1c845
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import patch
import pytest
import regex as re
import torch
from torch import nn
import tests.compile.silly_attention # noqa
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config.compilation import (
CompilationConfig,
CompilationMode,
CUDAGraphMode,
)
from vllm.config.scheduler import SchedulerConfig
from vllm.forward_context import set_forward_context
MLP_SIZE = 64
@support_torch_compile
class SimpleModel(nn.Module):
"""A simple model with a splitting op for piecewise compilation."""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + x
attn_output = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, attn_output)
x = attn_output * 2
return x
class TraceStructuredCapture:
"""Captures trace_structured calls for testing."""
def __init__(self):
self.calls: list[dict] = []
def __call__(self, event_type: str, metadata_fn=None, payload_fn=None, **kwargs):
"""Capture a trace_structured call."""
metadata = metadata_fn() if metadata_fn else {}
self.calls.append(
{
"event_type": event_type,
"metadata": metadata,
}
)
def get(self, event_type: str, name_pattern: str) -> list[dict]:
"""Get all calls with the given event type and name matching pattern.
Args:
event_type: The event type to filter by (e.g., "artifact", "graph_dump")
name_pattern: Regex pattern to match against the artifact name
"""
regex = re.compile(name_pattern)
return [
c
for c in self.calls
if c["event_type"] == event_type
and regex.fullmatch(c.get("metadata", {}).get("name", ""))
]
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_vllm_structured_logging_artifacts(use_fresh_inductor_cache):
"""Test that all expected vLLM artifacts are logged during compilation."""
torch.set_default_device("cuda")
capture = TraceStructuredCapture()
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
compile_sizes=[8],
splitting_ops=["silly::attention"],
),
scheduler_config=SchedulerConfig(
max_num_seqs=8,
max_model_len=8192,
is_encoder_decoder=False,
),
)
# Patch trace_structured to capture calls
with (
patch("vllm.compilation.backends.trace_structured", capture),
patch("vllm.compilation.piecewise_backend.trace_structured", capture),
set_current_vllm_config(vllm_config),
):
model = SimpleModel(vllm_config=vllm_config, prefix="test")
with set_forward_context({}, vllm_config=vllm_config):
model(torch.randn(8, MLP_SIZE))
config_artifacts = capture.get("artifact", "vllm_compilation_config")
assert len(config_artifacts) == 1, (
f"Expected 1 vllm_compilation_config, got {len(config_artifacts)}"
)
vllm_piecewise_split_graph = capture.get("graph_dump", "vllm_piecewise_split_graph")
assert len(vllm_piecewise_split_graph) == 1, (
"Expected 1 toplevel piecewise split graph, "
f"got {len(vllm_piecewise_split_graph)}"
)
compile_start_artifacts = capture.get("artifact", "vllm_piecewise_compile_start")
assert len(compile_start_artifacts) == 2, (
"Expected 2 vllm_piecewise_compile_start "
"(one for dynamic ranges, one for compile size), "
f"got {len(compile_start_artifacts)}"
)
submod_dumps = capture.get("graph_dump", r"vllm_submod_.*")
assert len(submod_dumps) == 2, (
"Expected 2 submods (one before attention, one after attention), "
f"got {len(submod_dumps)}"
)
...@@ -19,6 +19,7 @@ from typing import Any ...@@ -19,6 +19,7 @@ from typing import Any
import torch import torch
import torch.fx as fx import torch.fx as fx
from torch._dispatch.python import enable_python_dispatcher from torch._dispatch.python import enable_python_dispatcher
from torch._logging._internal import trace_structured
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.inductor_pass import pass_context from vllm.compilation.inductor_pass import pass_context
...@@ -529,6 +530,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc] ...@@ -529,6 +530,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
sym_shape_indices, sym_shape_indices,
self.vllm_backend, self.vllm_backend,
graph_returns_tuple(submod), graph_returns_tuple(submod),
submod_name=target,
) )
self.module.__dict__[target] = wrap_with_cudagraph_if_needed( self.module.__dict__[target] = wrap_with_cudagraph_if_needed(
...@@ -735,12 +737,61 @@ class VllmBackend: ...@@ -735,12 +737,61 @@ class VllmBackend:
) )
self.inductor_config[self.pass_key] = self.pass_manager self.inductor_config[self.pass_key] = self.pass_manager
def _log_compilation_config(self):
"""Log vLLM compilation config for TORCH_TRACE/tlparse."""
cc = self.compilation_config
pass_cfg = cc.pass_config
# Helper to convert lists to comma-separated strings for tlparse display
def list_to_str(lst: list | None) -> str:
if lst is None:
return ""
return ", ".join(str(x) for x in lst)
# Get enabled passes by introspecting dataclass fields
enabled_passes = [
f.name
for f in dataclasses.fields(pass_cfg)
if isinstance(getattr(pass_cfg, f.name), bool) and getattr(pass_cfg, f.name)
]
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "vllm_compilation_config",
"encoding": "json",
},
payload_fn=lambda: json.dumps(
{
"model": self.vllm_config.model_config.model,
"prefix": self.prefix,
"mode": str(cc.mode),
"backend": cc.backend,
"custom_ops": list_to_str(cc.custom_ops),
"splitting_ops": list_to_str(cc.splitting_ops),
"cudagraph_mode": str(cc.cudagraph_mode),
"compile_sizes": list_to_str(cc.compile_sizes),
"compile_ranges_split_points": list_to_str(
cc.compile_ranges_split_points
),
"use_inductor_graph_partition": cc.use_inductor_graph_partition,
"inductor_passes": list_to_str(list(cc.inductor_passes.keys())),
"enabled_passes": list_to_str(enabled_passes),
"dynamic_shapes_type": str(cc.dynamic_shapes_config.type),
"dynamic_shapes_evaluate_guards": cc.dynamic_shapes_config.evaluate_guards, # noqa: E501
}
),
)
def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any: def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
from .caching import ( from .caching import (
VllmSerializableFunction, VllmSerializableFunction,
) )
vllm_config = self.vllm_config vllm_config = self.vllm_config
self._log_compilation_config()
# Minimal hashing here with existing utilities, reused below. # Minimal hashing here with existing utilities, reused below.
env_factors = envs.compile_factors() env_factors = envs.compile_factors()
...@@ -892,6 +943,13 @@ class VllmBackend: ...@@ -892,6 +943,13 @@ class VllmBackend:
lazy_format_graph_code("before split", self.graph) lazy_format_graph_code("before split", self.graph)
lazy_format_graph_code("after split", self.split_gm) lazy_format_graph_code("after split", self.split_gm)
# Log the piecewise split graph for TORCH_TRACE/tlparse
trace_structured(
"graph_dump",
metadata_fn=lambda: {"name": "vllm_piecewise_split_graph"},
payload_fn=lambda: self.split_gm.print_readable(print_output=False),
)
compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs) compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
submod_names_to_compile = [ submod_names_to_compile = [
item.submod_name item.submod_name
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import dataclasses import dataclasses
import io import io
import json
import pickle import pickle
from collections.abc import Callable from collections.abc import Callable
from pickle import Pickler from pickle import Pickler
...@@ -11,6 +12,7 @@ from typing import Any ...@@ -11,6 +12,7 @@ from typing import Any
import torch._functorch.config import torch._functorch.config
import torch.fx as fx import torch.fx as fx
from torch._inductor.runtime.triton_heuristics import CachingAutotuner from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from torch._logging._internal import trace_structured
from vllm.compilation.backends import VllmBackend from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.compilation.monitor import end_monitoring_torch_compile
...@@ -39,6 +41,7 @@ class PiecewiseBackend: ...@@ -39,6 +41,7 @@ class PiecewiseBackend:
vllm_backend: VllmBackend, vllm_backend: VllmBackend,
returns_tuple: bool, returns_tuple: bool,
compiled_runnables: dict[str, Callable[..., Any]] | None = None, compiled_runnables: dict[str, Callable[..., Any]] | None = None,
submod_name: str = "",
): ):
""" """
The backend for piecewise compilation. The backend for piecewise compilation.
...@@ -70,6 +73,7 @@ class PiecewiseBackend: ...@@ -70,6 +73,7 @@ class PiecewiseBackend:
self.total_piecewise_compiles = total_piecewise_compiles self.total_piecewise_compiles = total_piecewise_compiles
self.vllm_backend = vllm_backend self.vllm_backend = vllm_backend
self.compiled_runnables = compiled_runnables self.compiled_runnables = compiled_runnables
self.submod_name = submod_name
self.is_first_graph = piecewise_compile_index == 0 self.is_first_graph = piecewise_compile_index == 0
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1 self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
...@@ -131,6 +135,9 @@ class PiecewiseBackend: ...@@ -131,6 +135,9 @@ class PiecewiseBackend:
compile_range=range, compile_range=range,
) )
# Track whether we've logged the graph for this subgraph (only log once)
self._graph_logged = False
# get the on_compilation_complete callback from context... # get the on_compilation_complete callback from context...
# PiecewiseBackend is created during the first call, # PiecewiseBackend is created during the first call,
# which is when the context is set (see compilation/decorators.py) # which is when the context is set (see compilation/decorators.py)
...@@ -221,6 +228,45 @@ class PiecewiseBackend: ...@@ -221,6 +228,45 @@ class PiecewiseBackend:
assert len(fake_example_inputs) == len(args) assert len(fake_example_inputs) == len(args)
return fake_example_inputs return fake_example_inputs
def _log_compile_start(self, compile_range: Range):
"""Log compilation event for TORCH_TRACE/tlparse."""
is_cudagraph_size = (
self.compile_sizes is not None and compile_range.start in self.compile_sizes
)
subgraph_index = self.piecewise_compile_index
submod_name = self.submod_name
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "vllm_piecewise_compile_start",
"encoding": "json",
},
payload_fn=lambda: json.dumps(
{
"piecewise_index": subgraph_index,
"submod_name": submod_name,
"total_piecewise_compiles": self.total_piecewise_compiles,
"compile_range_start": compile_range.start,
"compile_range_end": compile_range.end,
"is_single_size": compile_range.is_single_size(),
"is_cudagraph_capture_size": is_cudagraph_size,
}
),
)
# Log the subgraph graph dump only once per subgraph (not per size)
# to reduce log file size. The graph code is the same for all sizes.
if not self._graph_logged:
self._graph_logged = True
assert self.graph is not None
trace_structured(
"graph_dump",
metadata_fn=lambda: {
"name": f"vllm_{submod_name}",
},
payload_fn=lambda: self.graph.print_readable(print_output=False),
)
def _maybe_compile_for_range_entry( def _maybe_compile_for_range_entry(
self, range_entry: RangeEntry, args: tuple[Any, ...] self, range_entry: RangeEntry, args: tuple[Any, ...]
) -> Any: ) -> Any:
...@@ -230,6 +276,8 @@ class PiecewiseBackend: ...@@ -230,6 +276,8 @@ class PiecewiseBackend:
self.compiled_runnables[str(range_entry.compile_range)] self.compiled_runnables[str(range_entry.compile_range)]
) )
else: else:
self._log_compile_start(range_entry.compile_range)
# args are real arguments # args are real arguments
# fakify for range, real args for concrete size. # fakify for range, real args for concrete size.
# For concrete size, we clear the shape env in # For concrete size, we clear the shape env in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment