Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ca5fb4bb
Unverified
Commit
ca5fb4bb
authored
Mar 10, 2026
by
Jiangyun Zhu
Committed by
GitHub
Mar 10, 2026
Browse files
[Bugfix] Avoid merging empty-only partitions into splitting-op subgraphs (#36595)
Signed-off-by:
zjy0516
<
riverclouds.zhu@qq.com
>
parent
cf88b237
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
132 additions
and
29 deletions
+132
-29
tests/compile/test_graph_partition.py
tests/compile/test_graph_partition.py
+105
-15
vllm/compilation/backends.py
vllm/compilation/backends.py
+27
-14
No files found.
tests/compile/test_graph_partition.py
View file @
ca5fb4bb
...
...
@@ -7,7 +7,7 @@ import pytest
import
torch
from
torch.fx.experimental.proxy_tensor
import
make_fx
from
vllm.compilation.backends
import
split_graph
from
vllm.compilation.backends
import
_is_empty_allocation_node
,
split_graph
from
vllm.compilation.passes.fx_utils
import
find_op_nodes
# This import automatically registers `torch.ops.silly.attention`
...
...
@@ -186,10 +186,25 @@ def test_consecutive_ops_in_split():
]
+
[
"output"
]
def
test_empty_only_partition_is_merged
():
def
_get_empty_nodes
(
split_item
):
return
[
node
for
node
in
split_item
.
graph
.
graph
.
nodes
if
_is_empty_allocation_node
(
node
)
]
def
_subgraphs_with_empty_nodes
(
split_items
,
*
,
is_splitting_graph
):
return
[
split_item
for
split_item
in
split_items
if
split_item
.
is_splitting_graph
==
is_splitting_graph
and
_get_empty_nodes
(
split_item
)
]
def
test_empty_only_partition_stays_separate_after_splitting_predecessor
():
"""
Test that an empty-allocation-only partition is merged into its previou
s
partition during Dynamo FX splitting
.
Empty-only subgraphs should not be merged when the only predecessor i
s
a splitting-op subgraph
.
"""
def
model_fn
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -204,9 +219,65 @@ def test_empty_only_partition_is_merged():
split_ops
=
[
"aten::sin"
,
"aten::cos.out"
]
split_gm
,
split_items
=
split_graph
(
gm
,
split_ops
)
# Without the merge, this graph is split into 3 partitions where the
# middle partition contains only aten::empty_like.
assert
len
(
split_items
)
==
2
,
"Empty-only partition should be merged"
# Graph partitioning for this pattern is:
# [sin], [empty_like], [cos.out].
assert
len
(
split_items
)
==
3
,
(
"Empty-only partition should not merge into splitting-op subgraph"
)
splitting_with_empty
=
_subgraphs_with_empty_nodes
(
split_items
,
is_splitting_graph
=
True
)
assert
len
(
splitting_with_empty
)
==
0
,
(
"Splitting-op subgraphs should not contain empty allocation nodes: "
f
"
{
[
item
.
submod_name
for
item
in
splitting_with_empty
]
}
"
)
output_original
=
gm
(
x
)
output_split
=
split_gm
(
x
)
assert
torch
.
allclose
(
output_original
,
output_split
),
"Output mismatch after split"
def
test_empty_only_partition_is_merged
():
"""
Empty-only subgraphs should still be merged when a non-splitting predecessor
exists. The merged empty node must remain outside splitting-op subgraphs.
"""
def
model_fn
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
base
=
x
+
1
y
=
torch
.
sin
(
base
)
out
=
torch
.
empty_like
(
base
)
torch
.
ops
.
aten
.
cos
.
out
(
base
,
out
=
out
)
return
out
+
y
x
=
torch
.
randn
(
4
,
3
)
gm
=
make_fx
(
model_fn
)(
x
)
split_gm
,
split_items
=
split_graph
(
gm
,
[
"aten::sin"
,
"aten::cos.out"
])
# Partitioning should be:
# [add, empty_like], [sin], [cos.out], [add].
assert
len
(
split_items
)
==
4
,
(
"Empty-only partition should be merged into non-splitting predecessor"
)
splitting_with_empty
=
_subgraphs_with_empty_nodes
(
split_items
,
is_splitting_graph
=
True
)
assert
len
(
splitting_with_empty
)
==
0
,
(
"Splitting-op subgraphs should not contain empty allocation nodes: "
f
"
{
[
item
.
submod_name
for
item
in
splitting_with_empty
]
}
"
)
non_splitting_with_empty
=
_subgraphs_with_empty_nodes
(
split_items
,
is_splitting_graph
=
False
)
assert
len
(
non_splitting_with_empty
)
==
1
,
(
"Exactly one non-splitting subgraph should contain the merged empty node"
)
assert
len
(
_get_empty_nodes
(
non_splitting_with_empty
[
0
]))
==
1
,
(
"Expected exactly one empty allocation node in merged subgraph"
)
output_original
=
gm
(
x
)
output_split
=
split_gm
(
x
)
...
...
@@ -220,18 +291,37 @@ def test_builtin_empty_only_partition_is_merged():
"""
def
model_fn
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
out1
=
torch
.
empty_like
(
x
)
torch
.
ops
.
silly
.
attention
(
x
,
x
,
x
,
out1
)
out2
=
torch
.
empty_like
(
x
)
torch
.
ops
.
silly
.
attention
(
out1
,
out1
,
out1
,
out2
)
return
out2
hidden
=
x
+
1
out1
=
torch
.
empty_like
(
hidden
)
torch
.
ops
.
silly
.
attention
(
hidden
,
hidden
,
hidden
,
out1
)
out2
=
torch
.
empty_like
(
hidden
)
torch
.
ops
.
silly
.
attention
(
out1
,
out1
,
hidden
,
out2
)
return
out2
+
hidden
gm
=
torch
.
fx
.
symbolic_trace
(
model_fn
)
split_gm
,
split_items
=
split_graph
(
gm
,
[
"silly::attention"
])
# Without the empty-only merge, this graph creates 4 partitions:
# [empty_like], [attention], [empty_like], [attention].
assert
len
(
split_items
)
==
3
,
"Builtin empty-only partition should be merged"
# Without empty-only merge, this graph would split into:
# [add, empty_like], [attention], [empty_like], [attention], [add].
assert
len
(
split_items
)
==
4
,
"Builtin empty-only partition should be merged"
splitting_with_empty
=
_subgraphs_with_empty_nodes
(
split_items
,
is_splitting_graph
=
True
)
assert
len
(
splitting_with_empty
)
==
0
,
(
"Splitting-op subgraphs should not contain empty allocation nodes: "
f
"
{
[
item
.
submod_name
for
item
in
splitting_with_empty
]
}
"
)
non_splitting_with_empty
=
_subgraphs_with_empty_nodes
(
split_items
,
is_splitting_graph
=
False
)
assert
len
(
non_splitting_with_empty
)
==
1
,
(
"Exactly one non-splitting subgraph should contain merged empty nodes"
)
assert
len
(
_get_empty_nodes
(
non_splitting_with_empty
[
0
]))
==
2
,
(
"Expected two builtin empty_like nodes in merged non-splitting subgraph"
)
x
=
torch
.
randn
(
2
,
3
,
device
=
"cuda"
)
output_original
=
gm
(
x
)
...
...
vllm/compilation/backends.py
View file @
ca5fb4bb
...
...
@@ -431,6 +431,7 @@ def _is_empty_allocation_node(node: fx.Node) -> bool:
def
_merge_empty_only_subgraphs
(
node_to_subgraph_id
:
dict
[
fx
.
Node
,
int
],
split_op_graphs
:
list
[
int
],
)
->
None
:
"""
Merge a partition that only contains an empty allocation op into the
...
...
@@ -439,23 +440,35 @@ def _merge_empty_only_subgraphs(
"""
nodes_by_subgraph_id
:
dict
[
int
,
list
[
fx
.
Node
]]
=
defaultdict
(
list
)
subgraph_id_order
:
list
[
int
]
=
[]
for
node
,
subgraph_id
in
node_to_subgraph_id
.
items
():
if
subgraph_id
not
in
nodes_by_subgraph_id
:
subgraph_id_order
.
append
(
subgraph_id
)
nodes_by_subgraph_id
[
subgraph_id
].
append
(
node
)
prev_subgraph_id
:
int
|
None
=
None
for
subgraph_id
in
subgraph_id_order
:
nodes
=
nodes_by_subgraph_id
[
subgraph_id
]
if
(
len
(
nodes
)
==
1
and
_is_empty_allocation_node
(
nodes
[
0
])
and
prev_subgraph_id
is
not
None
):
node_to_subgraph_id
[
nodes
[
0
]]
=
prev_subgraph_id
splitting_subgraphs
=
set
(
split_op_graphs
)
prev_non_splitting_subgraph_id
:
int
|
None
=
None
max_subgraph_id
=
max
(
node_to_subgraph_id
.
values
(),
default
=-
1
)
for
subgraph_id
in
range
(
max_subgraph_id
+
1
):
nodes
=
nodes_by_subgraph_id
.
get
(
subgraph_id
,
[])
if
not
nodes
:
continue
prev_subgraph_id
=
subgraph_id
is_non_splitting_subgraph
=
subgraph_id
not
in
splitting_subgraphs
is_empty_only_subgraph
=
len
(
nodes
)
==
1
and
_is_empty_allocation_node
(
nodes
[
0
])
merged
=
False
if
is_empty_only_subgraph
and
prev_non_splitting_subgraph_id
is
not
None
:
# Safety check: don't move allocation before any input producer.
empty_node
=
nodes
[
0
]
if
all
(
input_node
.
op
==
"placeholder"
or
node_to_subgraph_id
[
input_node
]
<=
prev_non_splitting_subgraph_id
for
input_node
in
empty_node
.
all_input_nodes
):
node_to_subgraph_id
[
empty_node
]
=
prev_non_splitting_subgraph_id
merged
=
True
if
not
merged
and
is_non_splitting_subgraph
:
prev_non_splitting_subgraph_id
=
subgraph_id
def
split_graph
(
...
...
@@ -496,7 +509,7 @@ def split_graph(
else
:
node_to_subgraph_id
[
node
]
=
subgraph_id
_merge_empty_only_subgraphs
(
node_to_subgraph_id
)
_merge_empty_only_subgraphs
(
node_to_subgraph_id
,
split_op_graphs
)
# `keep_original_order` is important!
# otherwise pytorch might reorder the nodes and
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment