Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ea2c148f
Unverified
Commit
ea2c148f
authored
Mar 19, 2026
by
Xiao
Committed by
GitHub
Mar 19, 2026
Browse files
[compile][graph_partition]Add tensor size handling (#36038)
Signed-off-by:
Xiao Fu
<
xiaofu@meta.com
>
parent
47b7af0d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
351 additions
and
0 deletions
+351
-0
tests/compile/test_graph_partition.py
tests/compile/test_graph_partition.py
+295
-0
vllm/compilation/backends.py
vllm/compilation/backends.py
+56
-0
No files found.
tests/compile/test_graph_partition.py
View file @
ea2c148f
...
...
@@ -5,6 +5,8 @@ import operator
import
pytest
import
torch
import
torch._dynamo
import
torch.fx
as
fx
from
torch.fx.experimental.proxy_tensor
import
make_fx
from
vllm.compilation.backends
import
_is_empty_allocation_node
,
split_graph
...
...
@@ -327,3 +329,296 @@ def test_builtin_empty_only_partition_is_merged():
output_original
=
gm
(
x
)
output_split
=
split_gm
(
x
)
assert
torch
.
allclose
(
output_original
,
output_split
),
"Output mismatch after split"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_sym_size_whole_shape_boundary():
    """
    A whole-shape ``x.size()`` that is consumed on the far side of a split
    point must still yield submodules that ``standalone_compile`` accepts.

    Graph captured by dynamo:
        shape = x.size()
        y = sigmoid(x)               # split point
        z = y.clone().view(shape)

    Naive partitioning:
        subgraph0(x) -> shape        # returns torch.Size — problematic
        subgraph1(x) -> y            # sigmoid
        subgraph2(y, shape) -> z     # view

    Two ways to keep torch.Size from crossing the boundary:

    Approach 1 — move sym_size to the consumer (memory implication: x is
    passed to subgraph2 only so that .size() can be called on it):
        subgraph0(x) ->              # empty
        subgraph1(x) -> y
        subgraph2(y, x) -> z         # computes shape locally from x

    Approach 2 — decompose the shape into individual int/SymInt values:
        subgraph0(x) -> s0, val      # individual scalars, not Size
        subgraph1(x) -> y
        subgraph2(y, s0, val) -> z   # rebuilds the view args from scalars
    """
    from torch._inductor import standalone_compile

    # Mutable holder written by the backend; stays None if it never runs.
    captured = {"gm": None}

    def record_graph(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        captured["gm"] = gm
        return gm

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = torch.ops.aten.sigmoid.default(x)
        x = x.clone().view(shape)
        return x

    x = torch.randn(4, 8)
    torch._dynamo.mark_dynamic(x, 0)
    torch.compile(model_fn, backend=record_graph)(x)

    partitioned, pieces = split_graph(captured["gm"], ["aten::sigmoid"])
    assert len(pieces) == 3

    first_submodule = partitioned.submod_0
    sample = torch.randn(4, 8)
    artifact = standalone_compile(
        first_submodule,
        [sample, 4],
        dynamic_shapes="from_example_inputs",
    )
    assert artifact is not None
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_symint_crosses_split_boundary():
    """
    SymInt placeholders produced by torch.compile + mark_dynamic must be able
    to cross split boundaries through split_module's ordinary threading.

    The SymInt values are passed from subgraph to subgraph by split_module
    and inductor handles them as-is, so no special replacement is needed.
    """
    # Mutable holder written by the backend; stays None if it never runs.
    captured = {"gm": None}

    def record_graph(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        captured["gm"] = gm
        return gm

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        hidden_size = x.shape[1]
        x = torch.ops.aten.sigmoid.default(x)
        x = x.clone().view(batch_size, hidden_size)
        x = torch.ops.aten.sigmoid.default(x)
        x = x.clone().view(batch_size, hidden_size)
        x = torch.ops.aten.sigmoid.default(x)
        x = x.clone().view(batch_size, hidden_size)
        return x

    x = torch.randn(4, 8)
    torch._dynamo.mark_dynamic(x, 0)
    torch.compile(model_fn, backend=record_graph)(x)

    traced = captured["gm"]
    assert traced is not None, "Graph should be captured by backend"

    # The captured graph must expose at least one SymInt placeholder.
    has_symint_placeholder = any(
        node.op == "placeholder"
        and isinstance(node.meta.get("example_value"), torch.SymInt)
        for node in traced.graph.nodes
    )
    assert has_symint_placeholder, (
        "Captured graph should have SymInt placeholders from mark_dynamic."
    )

    # Splitting must cope with those SymInt placeholders without raising.
    partitioned, pieces = split_graph(traced, ["aten::sigmoid"])

    # One splitting subgraph is expected per sigmoid call (three in total).
    num_splitting = len([piece for piece in pieces if piece.is_splitting_graph])
    assert num_splitting == 3, (
        f"Expected 3 splitting subgraphs (3 sigmoids), got {num_splitting}"
    )

    num_pieces = len(pieces)
    assert num_pieces >= 6, (
        f"Expected at least 6 total subgraphs, got {num_pieces}"
    )
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_shape_boundary_standalone_compile():
    """
    Reproduces the original production failure:

        AssertionError: out_spec mismatch
        TreeSpec(tuple, None, [*, *, TreeSpec(Size, None, [*, *]), *])
        vs
        TreeSpec(tuple, None, [*, *, *, *])

    When shape information crosses a split boundary, a subgraph ends up
    returning a torch.Size (e.g. torch.Size([s72, 2048])) among its outputs.
    aot_autograd / inductor expect every submodule output to be a flat tensor
    or scalar, not a torch.Size.

    With the fix, x.size() is decomposed into individual sym_size.int calls,
    so only scalar SymInts cross the boundary — never the torch.Size itself.
    """
    from torch._inductor import standalone_compile

    # Mutable holder written by the backend; stays None if it never runs.
    captured = {"gm": None}

    def record_graph(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        captured["gm"] = gm
        return gm

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = torch.ops.aten.sigmoid.default(x)
        x = x.clone().view(shape)
        return x

    x = torch.randn(4, 8)
    torch._dynamo.mark_dynamic(x, 0)
    torch.compile(model_fn, backend=record_graph)(x)

    partitioned, pieces = split_graph(captured["gm"], ["aten::sigmoid"])
    assert len(pieces) == 3

    # The consumer subgraph should take a placeholder only for the dynamic
    # dim (SymInt); the static dim (8) should be inlined as a literal rather
    # than threaded through as a placeholder.
    # Indexing the last item is valid since len == 3:
    # [producer, sigmoid, consumer].
    consumer = pieces[-1]

    dynamic_inputs = []
    static_inputs = []
    for n in consumer.graph.graph.nodes:
        if n.op != "placeholder":
            continue
        value = n.meta.get("example_value")
        if isinstance(value, torch.SymInt):
            dynamic_inputs.append(n)
        elif isinstance(value, int):
            static_inputs.append(n)

    assert len(dynamic_inputs) >= 1, (
        "Consumer should have a SymInt placeholder for the dynamic dim."
    )
    assert not static_inputs, (
        "Static dims should be inlined as literals, not threaded as placeholders."
    )

    standalone_compile(
        partitioned.submod_0,
        [torch.randn(4, 8), 4],
        dynamic_shapes="from_example_inputs",
    )
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_size_used_in_multiple_consumer_subgraphs():
    """
    A whole-shape x.size() that feeds several downstream subgraphs must not
    cause torch.Size to cross any split boundary.

    Model:
        shape = x.size()     # whole shape — must not cross as torch.Size
        z1 = sigmoid(x)      # split point 1
        y1 = y.view(shape)   # consumer 1 uses shape
        z2 = sigmoid(z1)     # split point 2
        y2 = y.view(shape)   # consumer 2 uses shape again

    Without the fix, torch.Size crosses the boundary as a submodule output,
    which aot_autograd / standalone_compile rejects.
    """
    # Mutable holder written by the backend; both stay None if it never runs.
    captured = {"gm": None, "inputs": None}

    def record_graph(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        captured["gm"] = gm
        captured["inputs"] = example_inputs
        return gm

    def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        z1 = torch.ops.aten.sigmoid.default(x)
        y1 = y.view(shape)
        z2 = torch.ops.aten.sigmoid.default(z1)
        y2 = y.view(shape)
        return z2 + y1 + y2

    x = torch.randn(4, 8)
    # y has the same shape as x so view(shape) doesn't specialize dim 0.
    y = torch.randn(4, 8)
    torch._dynamo.mark_dynamic(x, 0)
    torch._dynamo.mark_dynamic(y, 0)
    torch.compile(model_fn, backend=record_graph)(x, y)

    partitioned, pieces = split_graph(captured["gm"], ["aten::sigmoid"])

    num_splitting = sum(1 for piece in pieces if piece.is_splitting_graph)
    assert num_splitting == 2

    # Functional check — fails without the fix because torch.Size would cross
    # a split boundary as a submodule output.
    expected = model_fn(x, y)
    actual = partitioned(*captured["inputs"])
    if isinstance(actual, tuple):
        # Pick the first tensor among the outputs.
        actual = next(filter(lambda o: isinstance(o, torch.Tensor), actual))
    assert torch.allclose(expected, actual), "Output mismatch after split"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_sym_size_metadata_propagated():
    """
    The sym_size.int nodes created by the pre-pass must carry
    ``example_value`` metadata.

    If they did not, the matching placeholders in consumer subgraphs would
    have ``None`` metadata, breaking any code that builds example inputs from
    it (e.g. running standalone_compile on each submodule).
    """
    from torch._inductor import standalone_compile

    # Mutable holder written by the backend; stays None if it never runs.
    captured = {"gm": None}

    def record_graph(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        captured["gm"] = gm
        return gm

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = torch.ops.aten.sigmoid.default(x)
        x = x.clone().view(shape)
        return x

    x = torch.randn(4, 8)
    torch._dynamo.mark_dynamic(x, 0)
    torch.compile(model_fn, backend=record_graph)(x)

    partitioned, pieces = split_graph(captured["gm"], ["aten::sigmoid"])

    # Build every submodule's example inputs from placeholder metadata alone.
    # A placeholder whose example_value is None means the metadata was not
    # propagated to the sym_size.int nodes that were created.
    for item in pieces:
        submod = item.graph
        synthesized = []
        for n in submod.graph.nodes:
            if n.op == "placeholder":
                ev = n.meta.get("example_value")
                assert ev is not None, (
                    f"Placeholder '{n.name}' in {item.submod_name} has no "
                    "example_value metadata. sym_size.int nodes must propagate "
                    "metadata so consumer subgraphs can be introspected."
                )
                # Tensors are re-created with concrete dims; anything else is
                # treated as a scalar and concretized with int().
                concrete = (
                    torch.randn(*map(int, ev.shape))
                    if isinstance(ev, torch.Tensor)
                    else int(ev)
                )
                synthesized.append(concrete)
        standalone_compile(
            submod,
            synthesized,
            dynamic_shapes="from_example_inputs",
        )
vllm/compilation/backends.py
View file @
ea2c148f
...
...
@@ -473,9 +473,65 @@ def _merge_empty_only_subgraphs(
prev_non_splitting_subgraph_id
=
subgraph_id
def
_decompose_size_nodes
(
graph
:
fx
.
GraphModule
)
->
None
:
"""Decompose x.size() into per-dim sym_size.int calls.
torch.Size objects cannot cross split boundaries because aot_autograd
cannot handle them as submodule outputs. This replaces each size() call
with individual sym_size.int(x, dim) nodes:
- Dynamic dims (SymInt) → new sym_size.int node
- Static dims (plain int) → inlined as literal constant
"""
# Dynamo captures x.size()/x.shape as call_method target="size".
size_nodes
=
list
(
graph
.
graph
.
find_nodes
(
op
=
"call_method"
,
target
=
"size"
))
for
node
in
size_nodes
:
tensor_node
=
node
.
args
[
0
]
ev
=
tensor_node
.
meta
.
get
(
"example_value"
)
assert
ev
is
not
None
,
(
f
"Tensor node '
{
tensor_node
.
name
}
' has no example_value metadata. "
f
"Cannot decompose size node '
{
node
.
name
}
'."
)
# Build per-dim replacements: sym_size.int node or literal int.
dims
:
list
[
fx
.
Node
|
int
]
=
[]
with
graph
.
graph
.
inserting_after
(
tensor_node
):
for
i
in
range
(
ev
.
dim
()):
dim_val
=
ev
.
shape
[
i
]
if
isinstance
(
dim_val
,
torch
.
SymInt
):
dn
=
graph
.
graph
.
call_function
(
torch
.
ops
.
aten
.
sym_size
.
int
,
args
=
(
tensor_node
,
i
)
)
dn
.
meta
[
"example_value"
]
=
dim_val
dims
.
append
(
dn
)
elif
isinstance
(
dim_val
,
int
):
dims
.
append
(
dim_val
)
else
:
raise
AssertionError
(
f
"dim_val is either torch.SymInt or int, "
f
"got
{
type
(
dim_val
)
}
for dim
{
i
}
of "
f
"'
{
node
.
name
}
'"
)
# Replace size node in each user's args.
# Dynamo always passes size as a direct arg: view(clone, size)
# → view(clone, d0, d1, ...)
for
user
in
list
(
node
.
users
):
new_args
=
[]
for
arg
in
user
.
args
:
if
arg
is
node
:
new_args
.
extend
(
dims
)
else
:
new_args
.
append
(
arg
)
user
.
args
=
tuple
(
new_args
)
graph
.
graph
.
erase_node
(
node
)
def
split_graph
(
graph
:
fx
.
GraphModule
,
splitting_ops
:
list
[
str
]
)
->
tuple
[
fx
.
GraphModule
,
list
[
SplitItem
]]:
_decompose_size_nodes
(
graph
)
# split graph by ops
subgraph_id
=
0
node_to_subgraph_id
:
dict
[
fx
.
Node
,
int
]
=
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment