Unverified commit b158df28, authored by Boyuan Feng, committed by GitHub
Browse files

remove resolve_op_overloads and use splitting_ops directly (#28081)


Signed-off-by: Boyuan Feng <boyuan@meta.com>
parent 1aaecda0
......@@ -214,28 +214,72 @@ def test_splitting_ops_dynamic():
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
def test_resolve_operator_overload():
def test_should_split():
import torch
from vllm.compilation.partition_rules import resolve_defined_ops
# Test valid operator names
resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
assert len(resolved) == 2
assert resolved[0] is torch.ops.aten.mm.default
assert resolved[1] is torch.ops.aten.addmm.default
# Test that invalid operators are skipped (not raising exceptions)
resolved = resolve_defined_ops(
[
"aten::mm.default",
"aten::nonexistent_op.default", # This should be skipped
"aten::addmm.default",
]
from vllm.compilation.partition_rules import should_split
graph = torch.fx.Graph()
node = torch.fx.Node(
graph=graph,
name="dummy_node",
op="call_function",
target=torch.ops.aten.add.default,
args=(),
kwargs={},
)
# supports OpOverloadPacket
splitting_ops = ["aten::add"]
assert should_split(node, splitting_ops)
# supports OpOverload
splitting_ops = ["aten::add.default"]
assert should_split(node, splitting_ops)
# supports OpOverload
splitting_ops = ["aten::add.Tensor"]
assert not should_split(node, splitting_ops)
@torch.library.custom_op(
"silly::attention",
mutates_args=["out"],
)
assert len(resolved) == 2 # Only 2 valid ops
assert resolved[0] is torch.ops.aten.mm.default
assert resolved[1] is torch.ops.aten.addmm.default
def attention(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
) -> None:
out.copy_(q + k + v)
q, k, v, out = [torch.randn(1)] * 4
# supports custom ops as OpOverloadPacket
node = torch.fx.Node(
graph=graph,
name="dummy_node",
op="call_function",
target=torch.ops.silly.attention,
args=(q, k, v, out),
kwargs={},
)
splitting_ops = ["silly::attention"]
assert should_split(node, splitting_ops)
# supports custom ops as OpOverload
node = torch.fx.Node(
graph=graph,
name="dummy_node",
op="call_function",
target=torch.ops.silly.attention.default,
args=(q, k, v, out),
kwargs={},
)
splitting_ops = ["silly::attention"]
assert should_split(node, splitting_ops)
splitting_ops = ["silly::attention.default"]
assert should_split(node, splitting_ops)
@pytest.mark.skipif(
......
......@@ -19,7 +19,7 @@ import vllm.envs as envs
from vllm.compilation.inductor_pass import pass_context
from vllm.compilation.partition_rules import (
inductor_partition_rule_context,
resolve_defined_ops,
should_split,
)
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
......@@ -303,7 +303,7 @@ class SplitItem:
def split_graph(
graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
graph: fx.GraphModule, splitting_ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
# split graph by ops
subgraph_id = 0
......@@ -312,12 +312,8 @@ def split_graph(
for node in graph.graph.nodes:
if node.op in ("output", "placeholder"):
continue
# Match node.target against resolved_ops
# node.target can be OpOverloadPacket, need to check .default
if node.op == "call_function" and (
node.target in resolved_ops
or (hasattr(node.target, "default") and node.target.default in resolved_ops)
):
if should_split(node, splitting_ops):
subgraph_id += 1
node_to_subgraph_id[node] = subgraph_id
split_op_graphs.append(subgraph_id)
......@@ -653,8 +649,7 @@ class VllmBackend:
else:
fx_split_ops = self.compilation_config.splitting_ops or []
resolved_split_ops = resolve_defined_ops(fx_split_ops)
self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)
self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)
from torch._dynamo.utils import lazy_format_graph_code
......
......@@ -2,54 +2,39 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import logging
import torch
from torch._library.utils import lookup_op
from vllm.logger import init_logger
logger = init_logger(__name__)
def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
"""Resolve operator names to OpOverload objects.
def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
"""
Check if a node should be split for dynamo graph partition.
It operates on dynamo graph, so the node.target can be anything.
We need to check and split only on OpOverload and OpOverloadPacket.
"""
Skips operators that fail to resolve (e.g., operators not registered or
model-specific operators not present in the current model).
if node.op != "call_function":
return False
Note: Users should inspect the operator graph before lowering and ensure
the specified operators are present in the final graph. Built-in PyTorch
operators (aten::*, torch::*) may be decomposed, fused, or transformed
during Inductor's compilation passes, so use them with caution.
target = node.target
Args:
op_names: List of operator names in PyTorch format
(e.g., "vllm::unified_attention")
if isinstance(target, torch._ops.OpOverloadPacket):
# Example: "aten::add"
return target._qualified_op_name in splitting_ops
Returns:
List of successfully resolved operator overloads
"""
resolved = []
for op_name in op_names:
try:
resolved.append(lookup_op(op_name))
except Exception:
# Skip operators that don't exist (e.g., model-specific ops)
# Do not warn for attention ops, warn for others
# (most likely manually specified)
from vllm.config import CompilationConfig
logger.log(
logging.DEBUG
if op_name in CompilationConfig._attention_ops
else logging.WARNING,
"Failed to resolve operator for CUDAGraph partition: %s",
op_name,
)
continue
return resolved
if isinstance(target, torch._ops.OpOverload):
# Example: "aten::add"
packet_name = target.name()
# Example: "aten::add.default"
op_overload_name = f"{packet_name}.{target._overloadname}"
return op_overload_name in splitting_ops or packet_name in splitting_ops
return False
@contextlib.contextmanager
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment