Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b158df28
Unverified
Commit
b158df28
authored
Nov 07, 2025
by
Boyuan Feng
Committed by
GitHub
Nov 08, 2025
Browse files
remove resolve_op_overloads and use splitting_ops directly (#28081)
Signed-off-by:
Boyuan Feng
<
boyuan@meta.com
>
parent
1aaecda0
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
89 additions
and
65 deletions
+89
-65
tests/compile/test_config.py
tests/compile/test_config.py
+63
-19
vllm/compilation/backends.py
vllm/compilation/backends.py
+5
-10
vllm/compilation/partition_rules.py
vllm/compilation/partition_rules.py
+21
-36
No files found.
tests/compile/test_config.py
View file @
b158df28
...
...
@@ -214,28 +214,72 @@ def test_splitting_ops_dynamic():
assert
config
.
compilation_config
.
cudagraph_mode
==
CUDAGraphMode
.
PIECEWISE
def
test_
resolve_operator_overload
():
def
test_
should_split
():
import
torch
from
vllm.compilation.partition_rules
import
resolve_defined_ops
from
vllm.compilation.partition_rules
import
should_split
# Test valid operator names
resolved
=
resolve_defined_ops
([
"aten::mm.default"
,
"aten::addmm.default"
])
assert
len
(
resolved
)
==
2
assert
resolved
[
0
]
is
torch
.
ops
.
aten
.
mm
.
default
assert
resolved
[
1
]
is
torch
.
ops
.
aten
.
addmm
.
default
graph
=
torch
.
fx
.
Graph
()
node
=
torch
.
fx
.
Node
(
graph
=
graph
,
name
=
"dummy_node"
,
op
=
"call_function"
,
target
=
torch
.
ops
.
aten
.
add
.
default
,
args
=
(),
kwargs
=
{},
)
# Test that invalid operators are skipped (not raising exceptions)
resolved
=
resolve_defined_ops
(
[
"aten::mm.default"
,
"aten::nonexistent_op.default"
,
# This should be skipped
"aten::addmm.default"
,
]
# supports OpOverloadPacket
splitting_ops
=
[
"aten::add"
]
assert
should_split
(
node
,
splitting_ops
)
# supports OpOverload
splitting_ops
=
[
"aten::add.default"
]
assert
should_split
(
node
,
splitting_ops
)
# supports OpOverload
splitting_ops
=
[
"aten::add.Tensor"
]
assert
not
should_split
(
node
,
splitting_ops
)
@
torch
.
library
.
custom_op
(
"silly::attention"
,
mutates_args
=
[
"out"
],
)
assert
len
(
resolved
)
==
2
# Only 2 valid ops
assert
resolved
[
0
]
is
torch
.
ops
.
aten
.
mm
.
default
assert
resolved
[
1
]
is
torch
.
ops
.
aten
.
addmm
.
default
def
attention
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
out
.
copy_
(
q
+
k
+
v
)
q
,
k
,
v
,
out
=
[
torch
.
randn
(
1
)]
*
4
# supports custom ops as OpOverloadPacket
node
=
torch
.
fx
.
Node
(
graph
=
graph
,
name
=
"dummy_node"
,
op
=
"call_function"
,
target
=
torch
.
ops
.
silly
.
attention
,
args
=
(
q
,
k
,
v
,
out
),
kwargs
=
{},
)
splitting_ops
=
[
"silly::attention"
]
assert
should_split
(
node
,
splitting_ops
)
# supports custom ops as OpOverload
node
=
torch
.
fx
.
Node
(
graph
=
graph
,
name
=
"dummy_node"
,
op
=
"call_function"
,
target
=
torch
.
ops
.
silly
.
attention
.
default
,
args
=
(
q
,
k
,
v
,
out
),
kwargs
=
{},
)
splitting_ops
=
[
"silly::attention"
]
assert
should_split
(
node
,
splitting_ops
)
splitting_ops
=
[
"silly::attention.default"
]
assert
should_split
(
node
,
splitting_ops
)
@
pytest
.
mark
.
skipif
(
...
...
vllm/compilation/backends.py
View file @
b158df28
...
...
@@ -19,7 +19,7 @@ import vllm.envs as envs
from
vllm.compilation.inductor_pass
import
pass_context
from
vllm.compilation.partition_rules
import
(
inductor_partition_rule_context
,
resolve_defined_ops
,
should_split
,
)
from
vllm.config
import
CompilationConfig
,
CUDAGraphMode
,
VllmConfig
from
vllm.logger
import
init_logger
...
...
@@ -303,7 +303,7 @@ class SplitItem:
def
split_graph
(
graph
:
fx
.
GraphModule
,
resolved
_ops
:
list
[
torch
.
_ops
.
OpOverload
]
graph
:
fx
.
GraphModule
,
splitting
_ops
:
list
[
str
]
)
->
tuple
[
fx
.
GraphModule
,
list
[
SplitItem
]]:
# split graph by ops
subgraph_id
=
0
...
...
@@ -312,12 +312,8 @@ def split_graph(
for
node
in
graph
.
graph
.
nodes
:
if
node
.
op
in
(
"output"
,
"placeholder"
):
continue
# Match node.target against resolved_ops
# node.target can be OpOverloadPacket, need to check .default
if
node
.
op
==
"call_function"
and
(
node
.
target
in
resolved_ops
or
(
hasattr
(
node
.
target
,
"default"
)
and
node
.
target
.
default
in
resolved_ops
)
):
if
should_split
(
node
,
splitting_ops
):
subgraph_id
+=
1
node_to_subgraph_id
[
node
]
=
subgraph_id
split_op_graphs
.
append
(
subgraph_id
)
...
...
@@ -653,8 +649,7 @@ class VllmBackend:
else
:
fx_split_ops
=
self
.
compilation_config
.
splitting_ops
or
[]
resolved_split_ops
=
resolve_defined_ops
(
fx_split_ops
)
self
.
split_gm
,
self
.
piecewise_graphs
=
split_graph
(
graph
,
resolved_split_ops
)
self
.
split_gm
,
self
.
piecewise_graphs
=
split_graph
(
graph
,
fx_split_ops
)
from
torch._dynamo.utils
import
lazy_format_graph_code
...
...
vllm/compilation/partition_rules.py
View file @
b158df28
...
...
@@ -2,54 +2,39 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
import
logging
import
torch
from
torch._library.utils
import
lookup_op
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def
resolve_defined_ops
(
op_names
:
list
[
str
])
->
list
[
"torch._ops.OpOverload"
]:
"""Resolve operator names to OpOverload objects.
def
should_split
(
node
:
torch
.
fx
.
Node
,
splitting_ops
:
list
[
str
])
->
bool
:
"""
Check if a node should be split for dynamo graph partition.
It operates on dynamo graph, so the node.target can be anything.
We need to check and split only on OpOverload and OpOverloadPacket.
"""
Skips operators that fail to resolve (e.g., operators not registered or
model-specific operators not present in the current model).
if
node
.
op
!=
"call_function"
:
return
False
Note: Users should inspect the operator graph before lowering and ensure
the specified operators are present in the final graph. Built-in PyTorch
operators (aten::*, torch::*) may be decomposed, fused, or transformed
during Inductor's compilation passes, so use them with caution.
target
=
node
.
target
Args
:
op_names: List of operator names in PyTorch format
(e.g., "vllm::unified_attention")
if
isinstance
(
target
,
torch
.
_ops
.
OpOverloadPacket
)
:
# Example: "aten::add"
return
target
.
_qualified_op_name
in
splitting_ops
Returns:
List of successfully resolved operator overloads
"""
resolved
=
[]
for
op_name
in
op_names
:
try
:
resolved
.
append
(
lookup_op
(
op_name
))
except
Exception
:
# Skip operators that don't exist (e.g., model-specific ops)
# Do not warn for attention ops, warn for others
# (most likely manually specified)
from
vllm.config
import
CompilationConfig
logger
.
log
(
logging
.
DEBUG
if
op_name
in
CompilationConfig
.
_attention_ops
else
logging
.
WARNING
,
"Failed to resolve operator for CUDAGraph partition: %s"
,
op_name
,
)
continue
if
isinstance
(
target
,
torch
.
_ops
.
OpOverload
):
# Example: "aten::add"
packet_name
=
target
.
name
()
# Example: "aten::add.default"
op_overload_name
=
f
"
{
packet_name
}
.
{
target
.
_overloadname
}
"
return
op_overload_name
in
splitting_ops
or
packet_name
in
splitting_ops
return
resolved
return
False
@
contextlib
.
contextmanager
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment