Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
29057d3b
Unverified
Commit
29057d3b
authored
Apr 17, 2026
by
BadrBasowid
Committed by
GitHub
Apr 16, 2026
Browse files
[Compilation] Add Unit Tests for VllmFusionPatternMatcherPass (#39692)
Signed-off-by:
BadrBasowid
<
badr.basowid@gmail.com
>
parent
219bb5b8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
128 additions
and
0 deletions
+128
-0
tests/compile/passes/test_vllm_fusion_pattern_matcher_pass.py
...s/compile/passes/test_vllm_fusion_pattern_matcher_pass.py
+128
-0
No files found.
tests/compile/passes/test_vllm_fusion_pattern_matcher_pass.py
0 → 100644
View file @
29057d3b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
import
vllm.config
from
tests.compile.backend
import
TestBackend
from
vllm.platforms
import
current_platform
from
vllm.compilation.passes.vllm_inductor_pass
import
(
VllmFusionPatternMatcherPass
,
VllmPatternMatcherPass
,
VllmPatternReplacement
,
)
from
vllm.config
import
CompilationConfig
,
CompilationMode
,
VllmConfig
class
ReluToAbsPattern
(
VllmPatternReplacement
):
"""Replaces relu(x) with abs(x) — a minimal test fixture."""
@
property
def
pattern
(
self
):
def
_pattern
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
ops
.
aten
.
relu
.
default
(
x
)
return
_pattern
@
property
def
replacement
(
self
):
def
_replacement
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
ops
.
aten
.
abs
.
default
(
x
)
return
_replacement
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
return
[
self
.
empty_fp32
(
4
)]
class
ExpToSqrtPattern
(
VllmPatternReplacement
):
"""A second distinct pattern type — used to test uuid differentiation."""
@
property
def
pattern
(
self
):
def
_pattern
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
ops
.
aten
.
exp
.
default
(
x
)
return
_pattern
@
property
def
replacement
(
self
):
def
_replacement
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
ops
.
aten
.
sqrt
.
default
(
x
)
return
_replacement
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
return
[
self
.
empty_fp32
(
4
)]
class
ReluFusionPass
(
VllmFusionPatternMatcherPass
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
,
"test_relu_fusion"
)
self
.
register
(
ReluToAbsPattern
())
class
TwoPatternFusionPass
(
VllmFusionPatternMatcherPass
):
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
,
"test_two_pattern_fusion"
)
self
.
register
(
ReluToAbsPattern
())
self
.
register
(
ExpToSqrtPattern
())
@
pytest
.
fixture
def
vllm_config
():
return
VllmConfig
(
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
),
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda_alike
(),
reason
=
"Requires CUDA"
)
def
test_register_tracks_patterns
(
vllm_config
):
"""register() appends each VllmPatternReplacement to _pattern_replacements."""
with
vllm
.
config
.
set_current_vllm_config
(
vllm_config
):
single
=
ReluFusionPass
(
vllm_config
)
two
=
TwoPatternFusionPass
(
vllm_config
)
assert
len
(
single
.
_pattern_replacements
)
==
1
assert
len
(
two
.
_pattern_replacements
)
==
2
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda_alike
(),
reason
=
"Requires CUDA"
)
def
test_uuid_stable
(
vllm_config
):
"""Two instances of the same pass class produce identical uuids."""
with
vllm
.
config
.
set_current_vllm_config
(
vllm_config
):
p1
=
ReluFusionPass
(
vllm_config
)
p2
=
ReluFusionPass
(
vllm_config
)
p3
=
TwoPatternFusionPass
(
vllm_config
)
assert
p1
.
uuid
()
==
p2
.
uuid
()
assert
p1
.
uuid
()
!=
p3
.
uuid
()
assert
p2
.
uuid
()
!=
p3
.
uuid
()
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda_alike
(),
reason
=
"Requires CUDA"
)
@
pytest
.
mark
.
parametrize
(
"N"
,
[
1
,
2
,
4
])
def
test_matched_count_and_match_table
(
vllm_config
,
N
):
"""matched_count and match_table reflect the number of matched patterns."""
class
Model
(
torch
.
nn
.
Module
):
def
forward
(
self
,
*
inputs
):
# N independent relus
return
sum
(
torch
.
relu
(
x
)
for
x
in
inputs
)
with
vllm
.
config
.
set_current_vllm_config
(
vllm_config
):
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_dtype
(
torch
.
float32
)
fusion_pass
=
ReluFusionPass
(
vllm_config
)
backend
=
TestBackend
(
fusion_pass
)
model
=
torch
.
compile
(
Model
(),
backend
=
backend
)
inputs
=
[
torch
.
rand
(
8
)
for
_
in
range
(
N
)]
model
(
*
inputs
)
assert
fusion_pass
.
matched_count
==
N
assert
VllmPatternMatcherPass
.
match_table
[
"test_relu_fusion"
]
>=
N
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment