Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
262d263f
Unverified
Commit
262d263f
authored
Nov 13, 2025
by
Yanan Cao
Committed by
GitHub
Nov 13, 2025
Browse files
[Bugfix] Eliminate tuple inputs to submodules in graph partitioning (#28533)
Signed-off-by:
Yanan Cao
<
gmagogsfm@gmail.com
>
parent
968060c1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
140 additions
and
2 deletions
+140
-2
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-0
tests/compile/test_graph_partition.py
tests/compile/test_graph_partition.py
+124
-0
vllm/compilation/backends.py
vllm/compilation/backends.py
+15
-2
No files found.
.buildkite/test-pipeline.yaml
View file @
262d263f
...
...
@@ -445,6 +445,7 @@ steps:
-
vllm/
-
tests/compile
commands
:
-
pytest -v -s compile/test_graph_partition.py
-
pytest -v -s compile/test_config.py
-
pytest -v -s compile/test_pass_manager.py
-
pytest -v -s compile/test_fusion.py
...
...
tests/compile/test_graph_partition.py
0 → 100644
View file @
262d263f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
operator
import
pytest
import
torch
from
torch.fx.experimental.proxy_tensor
import
make_fx
from
vllm.compilation.backends
import
split_graph
def test_getitem_moved_to_producer_subgraph():
    """
    Check that getitem nodes are placed in the subgraph of the value they
    index, so that no submodule is handed a tuple as an input.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        # torch.split returns a tuple, which makes the traced graph contain
        # real getitem nodes; the split is meant to be the tuple producer.
        chunks = torch.split(x, x.shape[0] // 2, dim=0)

        # The ops below are meant to form the consumer side of the partition.
        activated = [torch.relu(chunks[idx]) for idx in (0, 1)]
        return torch.cat(activated, dim=0)

    def is_getitem(fx_node) -> bool:
        # True for call_function nodes whose target is operator.getitem.
        return fx_node.op == "call_function" and fx_node.target == operator.getitem

    gm = make_fx(model_fn)(torch.randn(4, 3))

    # Guard against a vacuous test: the traced graph must contain getitem.
    assert any(map(is_getitem, gm.graph.nodes)), (
        "Test setup failed: graph should contain getitem operations"
    )

    # Partition on the tuple producer, aten::split.
    split_gm, split_items = split_graph(gm, ["aten::split.Tensor"])
    assert len(split_items) == 2, "Graph should be split into 2 submodules"

    for split_item in split_items:
        # A getitem whose source is a placeholder means the submodule took
        # the whole tuple as an argument.
        offenders = [
            candidate
            for candidate in split_item.graph.graph.nodes
            if is_getitem(candidate) and candidate.args[0].op == "placeholder"
        ]
        assert not offenders, (
            f"Submodule {split_item.submod_name} has getitem operations on "
            f"placeholder nodes: {[n.name for n in offenders]}. "
            "This means tuple inputs were not properly eliminated."
        )

    # The partitioned module must compute the same result as the original.
    new_x = torch.randn(4, 3)
    expected = gm(new_x)
    actual = split_gm(new_x)
    assert torch.allclose(expected, actual), "Output mismatch"
def test_no_tuple_inputs_with_multiple_consumers():
    """
    Check that when the elements of one tuple are used in several of the
    partitioned submodules, the getitem nodes are still placed so that no
    submodule takes the tuple itself as an input.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        # torch.split returns a tuple, which makes the traced graph contain
        # real getitem nodes; the split is meant to be the tuple producer.
        chunks = torch.split(x, x.shape[0] // 2, dim=0)

        # First consumer of the tuple elements.
        relu_first = torch.relu(chunks[0])
        relu_second = torch.relu(chunks[1])

        # sigmoid is listed as a splitting op below, so it acts as an
        # artificial partition point between the two tuple consumers.
        gated_first = torch.sigmoid(relu_first)

        # Second consumer: reads the tuple elements again after the sigmoid.
        return torch.cat([chunks[0], chunks[1], gated_first, relu_second])

    def is_getitem(fx_node) -> bool:
        # True for call_function nodes whose target is operator.getitem.
        return fx_node.op == "call_function" and fx_node.target == operator.getitem

    gm = make_fx(model_fn)(torch.randn(4, 3))

    # Guard against a vacuous test: the traced graph must contain getitem.
    assert any(map(is_getitem, gm.graph.nodes)), (
        "Test setup failed: graph should contain getitem operations"
    )

    split_gm, split_items = split_graph(gm, ["aten::split.Tensor", "aten::sigmoid"])
    assert len(split_items) == 4, "Graph should be split into 4 submodules"

    for split_item in split_items:
        for node in split_item.graph.graph.nodes:
            if not is_getitem(node):
                continue
            source = node.args[0]
            if source.op != "placeholder":
                continue
            # getitem applied to a placeholder: the tuple crossed the
            # submodule boundary as an argument.
            pytest.fail(
                f"Submodule {split_item.submod_name} has getitem on "
                f"placeholder {source.name}, indicating it receives "
                "a tuple input"
            )

    # The partitioned module must compute the same result as the original.
    new_x = torch.randn(4, 3)
    expected = gm(new_x)
    actual = split_gm(new_x)
    assert torch.allclose(expected, actual), "Output mismatch after split"
vllm/compilation/backends.py
View file @
262d263f
...
...
@@ -4,6 +4,7 @@
import
ast
import
dataclasses
import
hashlib
import
operator
import
os
import
pprint
import
time
...
...
@@ -307,12 +308,24 @@ def split_graph(
)
->
tuple
[
fx
.
GraphModule
,
list
[
SplitItem
]]:
# split graph by ops
subgraph_id
=
0
node_to_subgraph_id
=
{}
split_op_graphs
=
[]
node_to_subgraph_id
:
dict
[
fx
.
Node
,
int
]
=
{}
split_op_graphs
:
list
[
int
]
=
[]
for
node
in
graph
.
graph
.
nodes
:
if
node
.
op
in
(
"output"
,
"placeholder"
):
continue
# Check if this is a getitem operation on a node from an earlier subgraph.
# If so, assign it to the same subgraph as its input to avoid passing an entire
# tuple as input to submodules, which violates the input requirements of
# standalone_compile and AOTAutograd.
if
node
.
op
==
"call_function"
and
node
.
target
==
operator
.
getitem
:
# Assign this getitem to the same subgraph as its input
input_node
=
node
.
args
[
0
]
if
input_node
.
op
!=
"placeholder"
:
assert
input_node
in
node_to_subgraph_id
node_to_subgraph_id
[
node
]
=
node_to_subgraph_id
[
input_node
]
continue
if
should_split
(
node
,
splitting_ops
):
subgraph_id
+=
1
node_to_subgraph_id
[
node
]
=
subgraph_id
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment