Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5bcc153d
Unverified
Commit
5bcc153d
authored
Sep 16, 2025
by
Jiangyun Zhu
Committed by
GitHub
Sep 15, 2025
Browse files
[Compile] Fix noop_elimination pass and add tests for noop_elimination (#24880)
Signed-off-by:
zjy0516
<
riverclouds.zhu@qq.com
>
parent
45bfa49c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
130 additions
and
23 deletions
+130
-23
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-0
tests/compile/backend.py
tests/compile/backend.py
+5
-1
tests/compile/test_noop_elimination.py
tests/compile/test_noop_elimination.py
+106
-0
vllm/compilation/noop_elimination.py
vllm/compilation/noop_elimination.py
+18
-22
No files found.
.buildkite/test-pipeline.yaml
View file @
5bcc153d
...
...
@@ -394,6 +394,7 @@ steps:
-
pytest -v -s compile/test_async_tp.py
-
pytest -v -s compile/test_fusion_all_reduce.py
-
pytest -v -s compile/test_decorator.py
-
pytest -v -s compile/test_noop_elimination.py
-
label
:
PyTorch Fullgraph Smoke Test
# 15min
timeout_in_minutes
:
30
...
...
tests/compile/backend.py
View file @
5bcc153d
...
...
@@ -65,3 +65,7 @@ class TestBackend:
num_post
=
len
(
list
(
find_op_nodes
(
op
,
self
.
graph_post_pass
)))
assert
num_pre
==
0
,
f
"Unexpected op
{
op
.
name
()
}
in pre-pass graph"
assert
num_post
>
0
,
f
"Op
{
op
.
name
()
}
not found in post-pass graph"
def op_count(self, op: OpOverload, before=False) -> int:
    """Return how many nodes for ``op`` appear in one of the stored graphs.

    Args:
        op: The operator overload to look for.
        before: When truthy, inspect ``self.graph_pre_pass``; otherwise
            inspect ``self.graph_post_pass``.

    Returns:
        The number of nodes yielded by ``find_op_nodes`` for ``op``.
    """
    if before:
        target_graph = self.graph_pre_pass
    else:
        target_graph = self.graph_post_pass
    matching_nodes = list(find_op_nodes(op, target_graph))
    return len(matching_nodes)
tests/compile/test_noop_elimination.py
0 → 100644
View file @
5bcc153d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
import
vllm
from
vllm.compilation.noop_elimination
import
NoOpEliminationPass
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
PassConfig
,
VllmConfig
)
from
.backend
import
TestBackend
@pytest.mark.parametrize("dtype",
                         [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("num_tokens", [256, 1024])
@pytest.mark.parametrize("hidden_size", [64, 4096])
def test_noop_elimination(dtype, num_tokens, hidden_size):
    """Check that NoOpEliminationPass removes no-op reshape/slice ops.

    The model is run eagerly and through ``torch.compile`` with a
    ``TestBackend`` wrapping the pass; the outputs must agree and the
    post-pass graph must contain the expected number of each op.
    """
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    torch.manual_seed(1)

    class Model(torch.nn.Module):

        def forward(self, x):
            # Chain of reshapes
            three_d = x.reshape(-1, 128, 32)
            flat = three_d.reshape(-1, 4096)
            # No-op reshape
            flat_again = flat.reshape(-1, 4096)
            # Final reshape that should remain
            final = flat_again.reshape(-1, 128, 32)
            # No-op slice
            full = final[0:final.shape[0]]
            # The pass should replace the result of this op with `full`.
            scattered = torch.slice_scatter(
                torch.ones_like(full),  # Dummy tensor to be scattered into
                full,  # Source tensor
                0,  # dim
                0,  # start
                full.shape[0],  # end
            )
            return scattered

    compilation_config = CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        pass_config=PassConfig(enable_noop=True),
    )
    vllm_config = VllmConfig(compilation_config=compilation_config)

    with vllm.config.set_current_vllm_config(vllm_config):
        backend = TestBackend(NoOpEliminationPass(vllm_config))
        model = Model()

        # First dimension dynamic
        x = torch.rand(num_tokens, hidden_size)
        torch._dynamo.mark_dynamic(x, 0)

        eager_out = model(x)
        compiled_model = torch.compile(model, backend=backend)
        compiled_out = compiled_model(x)

        tolerance = 2e-3
        torch.testing.assert_close(eager_out,
                                   compiled_out,
                                   atol=tolerance,
                                   rtol=tolerance)

        # The no-op reshape and slice should be eliminated.
        # The chain of reshapes should be fused into a single reshape.
        aten = torch.ops.aten
        assert backend.op_count(aten.reshape.default) == 1
        assert backend.op_count(aten.slice.Tensor) == 0
        assert backend.op_count(aten.slice_scatter.default) == 0
def test_non_noop_slice_preserved():
    """Ensure that a slice with end=-1 (dropping last row) is NOT eliminated.

    Regression test for a bug where end=-1 was treated like an inferred
    dimension (reshape semantics) leading to incorrect elimination.
    """
    torch.set_default_device("cuda")
    x = torch.randn(16, 16)

    class SliceModel(torch.nn.Module):

        def forward(self, x):
            target = x.clone()
            ones_src = torch.ones(15, 16)
            scattered = torch.slice_scatter(target,
                                            ones_src,
                                            dim=0,
                                            start=0,
                                            end=-1)
            # Slice that drops the last row; it must survive the pass.
            trimmed = x[0:-1, :]
            return trimmed, scattered

    compilation_config = CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        pass_config=PassConfig(enable_noop=True),
    )
    vllm_config = VllmConfig(compilation_config=compilation_config)

    with vllm.config.set_current_vllm_config(vllm_config):
        backend = TestBackend(NoOpEliminationPass(vllm_config))
        model = SliceModel()

        expected = model(x)
        compiled_model = torch.compile(model, backend=backend)
        actual = compiled_model(x)

        torch.testing.assert_close(expected, actual)

        # The slice should remain (not a no-op).
        aten = torch.ops.aten
        assert backend.op_count(aten.slice.Tensor) == 1
        assert backend.op_count(aten.slice_scatter.default) == 1
vllm/compilation/noop_elimination.py
View file @
5bcc153d
...
...
@@ -62,9 +62,6 @@ class NoOpEliminationPass(VllmInductorPass):
scaled_mm: "f16[s0, 4096]" = ...
at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
out: "f16[s0, 4096]" = at[1]
TODO(luka): This is currently tested in test_fusion,
but separate tests could be good.
"""
def
__call__
(
self
,
graph
:
torch
.
fx
.
Graph
):
...
...
@@ -96,17 +93,19 @@ class NoOpEliminationPass(VllmInductorPass):
# Invalid reshape args, skip
continue
if
self
.
all_dims_equivalent
(
shape
,
input_shape
):
if
self
.
reshape_
all_dims_equivalent
(
shape
,
input_shape
):
node
.
replace_all_uses_with
(
input
)
graph
.
erase_node
(
node
)
count
+=
1
elif
is_func
(
node
,
torch
.
ops
.
aten
.
slice
.
Tensor
):
# python slicing semantics are different from reshape
# Don't treat -1 as inferred dimension
input
,
dim_index
,
start
,
end
=
node
.
args
[:
4
]
input_shape
=
input
.
meta
[
"val"
].
shape
i_dim
=
input_shape
[
dim_index
]
output_shape
=
node
.
meta
[
"val"
].
shape
if
start
==
0
and
self
.
dims_equivalent
(
end
,
i_dim
)
:
if
output_shape
==
input_shape
:
node
.
replace_all_uses_with
(
input
)
graph
.
erase_node
(
node
)
count
+=
1
...
...
@@ -116,14 +115,7 @@ class NoOpEliminationPass(VllmInductorPass):
base_shape
=
base
.
meta
[
"val"
].
shape
view_shape
=
view
.
meta
[
"val"
].
shape
view_dim
=
view_shape
[
dim_index
]
# Check that view fully covers base and the full view is used
# (if the view fully covered the base after slicing but was not
# fully used, we could replace slice_scatter with a simple slice
# but that's a niche case).
if
(
base_shape
==
view_shape
and
start
==
0
and
self
.
dims_equivalent
(
end
,
view_dim
)):
if
base_shape
==
view_shape
:
node
.
replace_all_uses_with
(
view
)
graph
.
erase_node
(
node
)
count
+=
1
...
...
@@ -132,12 +124,8 @@ class NoOpEliminationPass(VllmInductorPass):
self
.
dump_graph
(
graph
,
"after_noop_elimination"
)
self
.
end_and_log
()
def
all_dims_equivalent
(
self
,
dims
:
Iterable
[
Union
[
int
,
torch
.
fx
.
Node
]],
i_dims
:
Iterable
[
Union
[
int
,
SymInt
]]):
return
all
(
self
.
dims_equivalent
(
s
,
i_s
)
for
s
,
i_s
in
zip
(
dims
,
i_dims
))
def
dims_equivalent
(
self
,
dim
:
Union
[
int
,
torch
.
fx
.
Node
],
# ---------------------- Reshape helpers ----------------------
def
reshape_dims_equivalent
(
self
,
dim
:
Union
[
int
,
torch
.
fx
.
Node
],
i_dim
:
Union
[
int
,
SymInt
])
->
bool
:
"""
This function checks if two dimensions are equivalent.
...
...
@@ -156,10 +144,18 @@ class NoOpEliminationPass(VllmInductorPass):
In case 3, the reshape dimension is a torch.fx.Node,
and its value is a SymInt. That value is equal to the
input dimension.
"""
# Case 1 and 2
if
dim
==
i_dim
or
dim
==
-
1
:
return
True
# Case 3
return
isinstance
(
dim
,
torch
.
fx
.
Node
)
and
dim
.
meta
[
"val"
]
==
i_dim
def reshape_all_dims_equivalent(
    self,
    dims: Iterable[Union[int, torch.fx.Node]],
    i_dims: Iterable[Union[int, SymInt]],
) -> bool:
    """Check pairwise whether every reshape dim matches its input dim.

    ``dims`` and ``i_dims`` are walked in lockstep with ``zip`` (so the
    comparison stops at the shorter of the two) and each pair is handed
    to ``self.reshape_dims_equivalent``.

    Returns:
        ``False`` as soon as one pair is reported as not equivalent,
        ``True`` otherwise (including when there are no pairs).
    """
    for requested_dim, input_dim in zip(dims, i_dims):
        if not self.reshape_dims_equivalent(requested_dim, input_dim):
            return False
    return True
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment