Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b158df28
Unverified
Commit
b158df28
authored
Nov 07, 2025
by
Boyuan Feng
Committed by
GitHub
Nov 08, 2025
Browse files
remove resolve_op_overloads and use splitting_ops directly (#28081)
Signed-off-by:
Boyuan Feng
<
boyuan@meta.com
>
parent
1aaecda0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
89 additions
and
65 deletions
+89
-65
tests/compile/test_config.py
tests/compile/test_config.py
+63
-19
vllm/compilation/backends.py
vllm/compilation/backends.py
+5
-10
vllm/compilation/partition_rules.py
vllm/compilation/partition_rules.py
+21
-36
No files found.
tests/compile/test_config.py
View file @
b158df28
...
@@ -214,28 +214,72 @@ def test_splitting_ops_dynamic():
...
@@ -214,28 +214,72 @@ def test_splitting_ops_dynamic():
assert
config
.
compilation_config
.
cudagraph_mode
==
CUDAGraphMode
.
PIECEWISE
assert
config
.
compilation_config
.
cudagraph_mode
==
CUDAGraphMode
.
PIECEWISE
def
test_
resolve_operator_overload
():
def
test_
should_split
():
import
torch
import
torch
from
vllm.compilation.partition_rules
import
resolve_defined_ops
from
vllm.compilation.partition_rules
import
should_split
# Test valid operator names
graph
=
torch
.
fx
.
Graph
()
resolved
=
resolve_defined_ops
([
"aten::mm.default"
,
"aten::addmm.default"
])
node
=
torch
.
fx
.
Node
(
assert
len
(
resolved
)
==
2
graph
=
graph
,
assert
resolved
[
0
]
is
torch
.
ops
.
aten
.
mm
.
default
name
=
"dummy_node"
,
assert
resolved
[
1
]
is
torch
.
ops
.
aten
.
addmm
.
default
op
=
"call_function"
,
target
=
torch
.
ops
.
aten
.
add
.
default
,
# Test that invalid operators are skipped (not raising exceptions)
args
=
(),
resolved
=
resolve_defined_ops
(
kwargs
=
{},
[
)
"aten::mm.default"
,
"aten::nonexistent_op.default"
,
# This should be skipped
# supports OpOverloadPacket
"aten::addmm.default"
,
splitting_ops
=
[
"aten::add"
]
]
assert
should_split
(
node
,
splitting_ops
)
# supports OpOverload
splitting_ops
=
[
"aten::add.default"
]
assert
should_split
(
node
,
splitting_ops
)
# supports OpOverload
splitting_ops
=
[
"aten::add.Tensor"
]
assert
not
should_split
(
node
,
splitting_ops
)
@
torch
.
library
.
custom_op
(
"silly::attention"
,
mutates_args
=
[
"out"
],
)
)
assert
len
(
resolved
)
==
2
# Only 2 valid ops
def
attention
(
assert
resolved
[
0
]
is
torch
.
ops
.
aten
.
mm
.
default
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
out
:
torch
.
Tensor
assert
resolved
[
1
]
is
torch
.
ops
.
aten
.
addmm
.
default
)
->
None
:
out
.
copy_
(
q
+
k
+
v
)
q
,
k
,
v
,
out
=
[
torch
.
randn
(
1
)]
*
4
# supports custom ops as OpOverloadPacket
node
=
torch
.
fx
.
Node
(
graph
=
graph
,
name
=
"dummy_node"
,
op
=
"call_function"
,
target
=
torch
.
ops
.
silly
.
attention
,
args
=
(
q
,
k
,
v
,
out
),
kwargs
=
{},
)
splitting_ops
=
[
"silly::attention"
]
assert
should_split
(
node
,
splitting_ops
)
# supports custom ops as OpOverload
node
=
torch
.
fx
.
Node
(
graph
=
graph
,
name
=
"dummy_node"
,
op
=
"call_function"
,
target
=
torch
.
ops
.
silly
.
attention
.
default
,
args
=
(
q
,
k
,
v
,
out
),
kwargs
=
{},
)
splitting_ops
=
[
"silly::attention"
]
assert
should_split
(
node
,
splitting_ops
)
splitting_ops
=
[
"silly::attention.default"
]
assert
should_split
(
node
,
splitting_ops
)
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
...
...
vllm/compilation/backends.py
View file @
b158df28
...
@@ -19,7 +19,7 @@ import vllm.envs as envs
...
@@ -19,7 +19,7 @@ import vllm.envs as envs
from
vllm.compilation.inductor_pass
import
pass_context
from
vllm.compilation.inductor_pass
import
pass_context
from
vllm.compilation.partition_rules
import
(
from
vllm.compilation.partition_rules
import
(
inductor_partition_rule_context
,
inductor_partition_rule_context
,
resolve_defined_ops
,
should_split
,
)
)
from
vllm.config
import
CompilationConfig
,
CUDAGraphMode
,
VllmConfig
from
vllm.config
import
CompilationConfig
,
CUDAGraphMode
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -303,7 +303,7 @@ class SplitItem:
...
@@ -303,7 +303,7 @@ class SplitItem:
def
split_graph
(
def
split_graph
(
graph
:
fx
.
GraphModule
,
resolved
_ops
:
list
[
torch
.
_ops
.
OpOverload
]
graph
:
fx
.
GraphModule
,
splitting
_ops
:
list
[
str
]
)
->
tuple
[
fx
.
GraphModule
,
list
[
SplitItem
]]:
)
->
tuple
[
fx
.
GraphModule
,
list
[
SplitItem
]]:
# split graph by ops
# split graph by ops
subgraph_id
=
0
subgraph_id
=
0
...
@@ -312,12 +312,8 @@ def split_graph(
...
@@ -312,12 +312,8 @@ def split_graph(
for
node
in
graph
.
graph
.
nodes
:
for
node
in
graph
.
graph
.
nodes
:
if
node
.
op
in
(
"output"
,
"placeholder"
):
if
node
.
op
in
(
"output"
,
"placeholder"
):
continue
continue
# Match node.target against resolved_ops
# node.target can be OpOverloadPacket, need to check .default
if
should_split
(
node
,
splitting_ops
):
if
node
.
op
==
"call_function"
and
(
node
.
target
in
resolved_ops
or
(
hasattr
(
node
.
target
,
"default"
)
and
node
.
target
.
default
in
resolved_ops
)
):
subgraph_id
+=
1
subgraph_id
+=
1
node_to_subgraph_id
[
node
]
=
subgraph_id
node_to_subgraph_id
[
node
]
=
subgraph_id
split_op_graphs
.
append
(
subgraph_id
)
split_op_graphs
.
append
(
subgraph_id
)
...
@@ -653,8 +649,7 @@ class VllmBackend:
...
@@ -653,8 +649,7 @@ class VllmBackend:
else
:
else
:
fx_split_ops
=
self
.
compilation_config
.
splitting_ops
or
[]
fx_split_ops
=
self
.
compilation_config
.
splitting_ops
or
[]
resolved_split_ops
=
resolve_defined_ops
(
fx_split_ops
)
self
.
split_gm
,
self
.
piecewise_graphs
=
split_graph
(
graph
,
fx_split_ops
)
self
.
split_gm
,
self
.
piecewise_graphs
=
split_graph
(
graph
,
resolved_split_ops
)
from
torch._dynamo.utils
import
lazy_format_graph_code
from
torch._dynamo.utils
import
lazy_format_graph_code
...
...
vllm/compilation/partition_rules.py
View file @
b158df28
...
@@ -2,54 +2,39 @@
...
@@ -2,54 +2,39 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
import
contextlib
import
logging
import
torch
import
torch
from
torch._library.utils
import
lookup_op
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
resolve_defined_ops
(
op_names
:
list
[
str
])
->
list
[
"torch._ops.OpOverload"
]:
def
should_split
(
node
:
torch
.
fx
.
Node
,
splitting_ops
:
list
[
str
])
->
bool
:
"""Resolve operator names to OpOverload objects.
"""
Check if a node should be split for dynamo graph partition.
It operates on dynamo graph, so the node.target can be anything.
We need to check and split only on OpOverload and OpOverloadPacket.
"""
Skips operators that fail to resolve (e.g., operators not registered or
if
node
.
op
!=
"call_function"
:
model-specific operators not present in the current model).
return
False
Note: Users should inspect the operator graph before lowering and ensure
target
=
node
.
target
the specified operators are present in the final graph. Built-in PyTorch
operators (aten::*, torch::*) may be decomposed, fused, or transformed
during Inductor's compilation passes, so use them with caution.
Args
:
if
isinstance
(
target
,
torch
.
_ops
.
OpOverloadPacket
)
:
op_names: List of operator names in PyTorch format
# Example: "aten::add"
(e.g., "vllm::unified_attention")
return
target
.
_qualified_op_name
in
splitting_ops
Returns:
if
isinstance
(
target
,
torch
.
_ops
.
OpOverload
):
List of successfully resolved operator overloads
# Example: "aten::add"
"""
packet_name
=
target
.
name
()
resolved
=
[]
for
op_name
in
op_names
:
# Example: "aten::add.default"
try
:
op_overload_name
=
f
"
{
packet_name
}
.
{
target
.
_overloadname
}
"
resolved
.
append
(
lookup_op
(
op_name
))
return
op_overload_name
in
splitting_ops
or
packet_name
in
splitting_ops
except
Exception
:
# Skip operators that don't exist (e.g., model-specific ops)
return
False
# Do not warn for attention ops, warn for others
# (most likely manually specified)
from
vllm.config
import
CompilationConfig
logger
.
log
(
logging
.
DEBUG
if
op_name
in
CompilationConfig
.
_attention_ops
else
logging
.
WARNING
,
"Failed to resolve operator for CUDAGraph partition: %s"
,
op_name
,
)
continue
return
resolved
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment