Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f00c5539
Unverified
Commit
f00c5539
authored
Apr 12, 2026
by
Animesh Jain
Committed by
GitHub
Apr 12, 2026
Browse files
[compile] Bug fix for _decompose_size_nodes (#38360)
Signed-off-by:
Animesh Jain
<
anijain@umich.edu
>
parent
21fab0a3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
99 additions
and
10 deletions
+99
-10
tests/compile/test_graph_partition.py
tests/compile/test_graph_partition.py
+75
-1
vllm/compilation/backends.py
vllm/compilation/backends.py
+24
-9
No files found.
tests/compile/test_graph_partition.py
View file @
f00c5539
...
...
@@ -9,7 +9,11 @@ import torch._dynamo
import
torch.fx
as
fx
from
torch.fx.experimental.proxy_tensor
import
make_fx
from
vllm.compilation.backends
import
_is_empty_allocation_node
,
split_graph
from
vllm.compilation.backends
import
(
_decompose_size_nodes
,
_is_empty_allocation_node
,
split_graph
,
)
from
vllm.compilation.passes.fx_utils
import
find_op_nodes
# This import automatically registers `torch.ops.silly.attention`
...
...
@@ -622,3 +626,73 @@ def test_sym_size_metadata_propagated():
else
:
example_inputs
.
append
(
int
(
ev
))
standalone_compile
(
submod
,
example_inputs
,
dynamic_shapes
=
"from_example_inputs"
)
def
test_decompose_size_with_getitem_user
():
"""
Regression test: _decompose_size_nodes must handle getitem users of size()
correctly.
When a graph contains x.shape[i], it can appear as:
%size = call_method[target="size"](args = (%x,))
%getitem = call_function[target=operator.getitem](args = (%size, 1))
The old code spliced *all* per-dim values into every user's args
unconditionally, turning the 2-arg getitem into a malformed 3-arg node:
%getitem(args = (%sym_size_int, 5120, 1)) # TypeError at runtime
The fix detects getitem users and replaces them with dims[idx] directly.
"""
# Build a graph manually to guarantee the size() + getitem pattern.
#
# Graph:
# %x = placeholder
# %size = x.size()
# %dim1 = getitem(%size, 1) <-- the getitem branch we're testing
# %relu = relu(%x)
# %view = view(%relu, -1, %dim1)
# return %view
graph
=
fx
.
Graph
()
x
=
graph
.
placeholder
(
"x"
)
size_node
=
graph
.
call_method
(
"size"
,
args
=
(
x
,))
getitem_node
=
graph
.
call_function
(
operator
.
getitem
,
args
=
(
size_node
,
1
))
relu_node
=
graph
.
call_function
(
torch
.
ops
.
aten
.
relu
.
default
,
args
=
(
x
,))
view_node
=
graph
.
call_function
(
torch
.
ops
.
aten
.
view
.
default
,
args
=
(
relu_node
,
[
-
1
,
getitem_node
])
)
graph
.
output
(
view_node
)
# Attach example_value metadata so _decompose_size_nodes can inspect dims.
# dim 0 is dynamic (SymInt), dim 1 is static (8).
from
torch._dynamo.source
import
LocalSource
from
torch._subclasses.fake_tensor
import
FakeTensorMode
from
torch.fx.experimental.symbolic_shapes
import
ShapeEnv
shape_env
=
ShapeEnv
()
src
=
LocalSource
(
"batch_size"
)
sym_batch
=
shape_env
.
create_symintnode
(
shape_env
.
create_symbol
(
4
,
src
),
hint
=
4
)
fake_mode
=
FakeTensorMode
(
shape_env
=
shape_env
)
with
fake_mode
:
fake_x
=
torch
.
empty_strided
((
sym_batch
,
8
),
(
8
,
1
))
x
.
meta
[
"example_value"
]
=
fake_x
gm
=
fx
.
GraphModule
(
torch
.
nn
.
Module
(),
graph
)
# Run decomposition — this would produce a 3-arg getitem without the fix
_decompose_size_nodes
(
gm
)
# Verify no size() nodes remain
remaining_size_nodes
=
list
(
gm
.
graph
.
find_nodes
(
op
=
"call_method"
,
target
=
"size"
))
assert
len
(
remaining_size_nodes
)
==
0
,
(
f
"size() nodes should be fully decomposed, found
{
len
(
remaining_size_nodes
)
}
"
)
# Verify no malformed getitem nodes (3+ args)
for
node
in
gm
.
graph
.
nodes
:
if
node
.
op
==
"call_function"
and
node
.
target
is
operator
.
getitem
:
assert
len
(
node
.
args
)
==
2
,
(
f
"getitem node '
{
node
.
name
}
' has
{
len
(
node
.
args
)
}
args "
f
"(expected 2):
{
node
.
args
}
"
)
vllm/compilation/backends.py
View file @
f00c5539
...
...
@@ -516,9 +516,24 @@ def _decompose_size_nodes(graph: fx.GraphModule) -> None:
)
# Replace size node in each user's args.
# Dynamo always passes size as a direct arg: view(clone, size)
# → view(clone, d0, d1, ...)
for
user
in
list
(
node
.
users
):
if
(
user
.
op
==
"call_function"
and
user
.
target
is
operator
.
getitem
and
len
(
user
.
args
)
==
2
and
user
.
args
[
0
]
is
node
):
# getitem(size, idx) → replace with dims[idx] directly.
idx
=
user
.
args
[
1
]
assert
isinstance
(
idx
,
int
),
(
f
"Expected literal int index for getitem on size(), "
f
"got
{
type
(
idx
).
__name__
}
:
{
idx
}
"
)
user
.
replace_all_uses_with
(
dims
[
idx
])
graph
.
graph
.
erase_node
(
user
)
else
:
# User consumes the full size tuple (e.g. view(clone, size))
# → view(clone, d0, d1, ...)
new_args
=
[]
for
arg
in
user
.
args
:
if
arg
is
node
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment