Unverified commit f00c5539, authored by Animesh Jain and committed by GitHub
Browse files

[compile] Bug fix for _decompose_size_nodes (#38360)


Signed-off-by: Animesh Jain <anijain@umich.edu>
parent 21fab0a3
......@@ -9,7 +9,11 @@ import torch._dynamo
import torch.fx as fx
from torch.fx.experimental.proxy_tensor import make_fx
from vllm.compilation.backends import _is_empty_allocation_node, split_graph
from vllm.compilation.backends import (
_decompose_size_nodes,
_is_empty_allocation_node,
split_graph,
)
from vllm.compilation.passes.fx_utils import find_op_nodes
# This import automatically registers `torch.ops.silly.attention`
......@@ -622,3 +626,73 @@ def test_sym_size_metadata_propagated():
else:
example_inputs.append(int(ev))
standalone_compile(submod, example_inputs, dynamic_shapes="from_example_inputs")
def test_decompose_size_with_getitem_user():
    """
    Regression test: getitem users of size() survive _decompose_size_nodes.

    An ``x.shape[i]`` access can show up in a graph as the node pair:

        %size = call_method[target="size"](args = (%x,))
        %getitem = call_function[target=operator.getitem](args = (%size, 1))

    Previously every per-dim value was spliced into each user's args
    unconditionally, so the 2-arg getitem became a malformed 3-arg node:

        %getitem(args = (%sym_size_int, 5120, 1))  # TypeError at runtime

    The fixed pass recognises getitem users and substitutes dims[idx] for
    them directly. This test checks that after the pass no size() node is
    left and that no getitem node carries anything other than 2 args.
    """
    from torch._dynamo.source import LocalSource
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # Hand-build the graph so the size() + getitem pattern is guaranteed:
    #
    #   %x    = placeholder
    #   %size = x.size()
    #   %dim1 = getitem(%size, 1)      <-- the getitem user under test
    #   %relu = relu(%x)
    #   %view = view(%relu, [-1, %dim1])
    #   return %view
    g = fx.Graph()
    x = g.placeholder("x")
    size_call = g.call_method("size", args=(x,))
    dim1 = g.call_function(operator.getitem, args=(size_call, 1))
    activated = g.call_function(torch.ops.aten.relu.default, args=(x,))
    reshaped = g.call_function(
        torch.ops.aten.view.default,
        args=(activated, [-1, dim1]),
    )
    g.output(reshaped)

    # _decompose_size_nodes reads dims from the placeholder's example_value.
    # Give it a fake tensor whose dim 0 is dynamic (SymInt) and whose dim 1
    # is static (8).
    shape_env = ShapeEnv()
    batch_symbol = shape_env.create_symbol(4, LocalSource("batch_size"))
    sym_batch = shape_env.create_symintnode(batch_symbol, hint=4)
    with FakeTensorMode(shape_env=shape_env):
        fake_input = torch.empty_strided((sym_batch, 8), (8, 1))
    x.meta["example_value"] = fake_input

    gm = fx.GraphModule(torch.nn.Module(), g)

    # Without the fix this pass would emit a 3-arg getitem.
    _decompose_size_nodes(gm)

    # Every size() call must have been decomposed away.
    remaining_size_nodes = list(gm.graph.find_nodes(op="call_method", target="size"))
    assert not remaining_size_nodes, (
        f"size() nodes should be fully decomposed, found {len(remaining_size_nodes)}"
    )

    # Any getitem node still in the graph must be well-formed (exactly 2 args).
    surviving_getitems = [
        candidate
        for candidate in gm.graph.nodes
        if candidate.op == "call_function" and candidate.target is operator.getitem
    ]
    for node in surviving_getitems:
        assert len(node.args) == 2, (
            f"getitem node '{node.name}' has {len(node.args)} args "
            f"(expected 2): {node.args}"
        )
......@@ -516,9 +516,24 @@ def _decompose_size_nodes(graph: fx.GraphModule) -> None:
)
# Replace size node in each user's args.
# Dynamo always passes size as a direct arg: view(clone, size)
# → view(clone, d0, d1, ...)
for user in list(node.users):
if (
user.op == "call_function"
and user.target is operator.getitem
and len(user.args) == 2
and user.args[0] is node
):
# getitem(size, idx) → replace with dims[idx] directly.
idx = user.args[1]
assert isinstance(idx, int), (
f"Expected literal int index for getitem on size(), "
f"got {type(idx).__name__}: {idx}"
)
user.replace_all_uses_with(dims[idx])
graph.graph.erase_node(user)
else:
# User consumes the full size tuple (e.g. view(clone, size))
# → view(clone, d0, d1, ...)
new_args = []
for arg in user.args:
if arg is node:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment