Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f0862eae
Unverified
Commit
f0862eae
authored
Oct 14, 2025
by
Boyuan Feng
Committed by
GitHub
Oct 15, 2025
Browse files
[Graph Partition] pass tests for decorator (#26831)
Signed-off-by:
Boyuan Feng
<
boyuan@meta.com
>
parent
8c851f6d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
74 additions
and
26 deletions
+74
-26
tests/compile/test_decorator.py
tests/compile/test_decorator.py
+74
-26
No files found.
tests/compile/test_decorator.py
View file @
f0862eae
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -14,6 +15,7 @@ from vllm.config import (
...
@@ -14,6 +15,7 @@ from vllm.config import (
set_current_vllm_config
,
set_current_vllm_config
,
)
)
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.utils
import
is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
# This import automatically registers `torch.ops.silly.attention`
from
.
import
silly_attention
# noqa: F401
from
.
import
silly_attention
# noqa: F401
...
@@ -65,19 +67,40 @@ def run_model(
...
@@ -65,19 +67,40 @@ def run_model(
return
output
.
cpu
()
return
output
.
cpu
()
def
test_ignore_torch_compile_decorator
():
@
pytest
.
mark
.
parametrize
(
"use_inductor_graph_partition"
,
[
True
,
False
])
# vllmcompile
def
test_ignore_torch_compile_decorator
(
use_inductor_graph_partition
,
monkeypatch
):
# disable compile cache so that we can count the number of compilations
# appropriately
monkeypatch
.
setenv
(
"VLLM_DISABLE_COMPILE_CACHE"
,
"1"
)
if
use_inductor_graph_partition
and
not
is_torch_equal_or_newer
(
"2.9.0.dev"
):
pytest
.
skip
(
"inductor graph partition is only available in PyTorch 2.9+"
)
# piecewise
vllm_config
=
VllmConfig
(
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
mode
=
CompilationMode
.
VLLM_COMPILE
,
use_cudagraph
=
True
,
use_cudagraph
=
True
,
splitting_ops
=
[
"silly::attention"
],
splitting_ops
=
[
"silly::attention"
],
cudagraph_capture_sizes
=
[
1
,
2
],
cudagraph_capture_sizes
=
[
1
,
2
],
use_inductor_graph_partition
=
False
,
# TODO test both?
use_inductor_graph_partition
=
use_inductor_graph_partition
,
)
)
)
)
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
expected_num_graphs_seen
=
1
expected_num_cudagraph_captured
=
(
4
# num_cudagraph_sizes * num cudagraphs to capture
)
if
use_inductor_graph_partition
:
expected_num_piecewise_graphs_seen
=
1
expected_num_piecewise_capturable_graphs_seen
=
1
expected_num_backend_compilations
=
1
else
:
expected_num_piecewise_graphs_seen
=
3
expected_num_piecewise_capturable_graphs_seen
=
2
expected_num_backend_compilations
=
2
@
support_torch_compile
@
support_torch_compile
class
A
(
nn
.
Module
):
class
A
(
nn
.
Module
):
def
__init__
(
def
__init__
(
...
@@ -104,12 +127,11 @@ def test_ignore_torch_compile_decorator():
...
@@ -104,12 +127,11 @@ def test_ignore_torch_compile_decorator():
# A has support_torch_compile
# A has support_torch_compile
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
num_graphs_seen
=
expected_num_graphs_seen
,
num_piecewise_graphs_seen
=
3
,
num_piecewise_graphs_seen
=
expected_num_piecewise_graphs_seen
,
num_piecewise_capturable_graphs_seen
=
2
,
num_piecewise_capturable_graphs_seen
=
expected_num_piecewise_capturable_graphs_seen
,
num_backend_compilations
=
2
,
num_backend_compilations
=
expected_num_backend_compilations
,
num_cudagraph_captured
=
4
,
num_cudagraph_captured
=
expected_num_cudagraph_captured
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
):
run_model
(
vllm_config
,
mod_A
,
cudagraph_runtime_mode
)
run_model
(
vllm_config
,
mod_A
,
cudagraph_runtime_mode
)
...
@@ -131,12 +153,11 @@ def test_ignore_torch_compile_decorator():
...
@@ -131,12 +153,11 @@ def test_ignore_torch_compile_decorator():
# C's support_torch_compile should override B's ignore_torch_compile
# C's support_torch_compile should override B's ignore_torch_compile
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
num_graphs_seen
=
expected_num_graphs_seen
,
num_piecewise_graphs_seen
=
3
,
num_piecewise_graphs_seen
=
expected_num_piecewise_graphs_seen
,
num_piecewise_capturable_graphs_seen
=
2
,
num_piecewise_capturable_graphs_seen
=
expected_num_piecewise_capturable_graphs_seen
,
num_backend_compilations
=
2
,
num_backend_compilations
=
expected_num_backend_compilations
,
num_cudagraph_captured
=
4
,
num_cudagraph_captured
=
expected_num_cudagraph_captured
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
):
run_model
(
vllm_config
,
mod_C
,
cudagraph_runtime_mode
)
run_model
(
vllm_config
,
mod_C
,
cudagraph_runtime_mode
)
...
@@ -179,7 +200,15 @@ class A(nn.Module):
...
@@ -179,7 +200,15 @@ class A(nn.Module):
return
x
return
x
def
test_conditional_compile_enable_if
():
@
pytest
.
mark
.
parametrize
(
"use_inductor_graph_partition"
,
[
True
,
False
])
def
test_conditional_compile_enable_if
(
use_inductor_graph_partition
,
monkeypatch
):
# disable compile cache so that we can count the number of compilations
# appropriately
monkeypatch
.
setenv
(
"VLLM_DISABLE_COMPILE_CACHE"
,
"1"
)
if
use_inductor_graph_partition
and
not
is_torch_equal_or_newer
(
"2.9.0.dev"
):
pytest
.
skip
(
"inductor graph partition is only available in PyTorch 2.9+"
)
vllm_config
=
VllmConfig
(
vllm_config
=
VllmConfig
(
cache_config
=
CacheConfig
(
cache_config
=
CacheConfig
(
kv_sharing_fast_prefill
=
True
,
kv_sharing_fast_prefill
=
True
,
...
@@ -189,7 +218,7 @@ def test_conditional_compile_enable_if():
...
@@ -189,7 +218,7 @@ def test_conditional_compile_enable_if():
use_cudagraph
=
True
,
use_cudagraph
=
True
,
splitting_ops
=
[
"silly::attention"
],
splitting_ops
=
[
"silly::attention"
],
cudagraph_capture_sizes
=
[
1
,
2
],
cudagraph_capture_sizes
=
[
1
,
2
],
use_inductor_graph_partition
=
False
,
# TODO test both
use_inductor_graph_partition
=
use_inductor_graph_partition
,
),
),
)
)
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
...
@@ -197,17 +226,26 @@ def test_conditional_compile_enable_if():
...
@@ -197,17 +226,26 @@ def test_conditional_compile_enable_if():
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
mod_A
=
A
(
vllm_config
=
vllm_config
,
prefix
=
""
).
eval
().
cuda
()
mod_A
=
A
(
vllm_config
=
vllm_config
,
prefix
=
""
).
eval
().
cuda
()
if
use_inductor_graph_partition
:
expected_num_piecewise_graphs_seen
=
2
expected_num_piecewise_capturable_graphs_seen
=
2
expected_num_backend_compilations
=
2
else
:
expected_num_piecewise_graphs_seen
=
6
expected_num_piecewise_capturable_graphs_seen
=
4
expected_num_backend_compilations
=
4
# A has support_torch_compile but enable_if fn returns False
# A has support_torch_compile but enable_if fn returns False
# enalbe_if will be True for B, so we expect mod1 and mod2
# enalbe_if will be True for B, so we expect mod1 and mod2
# to be compiled
# to be compiled
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
2
,
num_graphs_seen
=
2
,
num_piecewise_graphs_seen
=
6
,
num_piecewise_graphs_seen
=
expected_num_piecewise_graphs_seen
,
# 3 piecewise graphs per instance of B()
# 3 piecewise graphs per instance of B()
num_piecewise_capturable_graphs_seen
=
4
,
num_piecewise_capturable_graphs_seen
=
expected_num_piecewise_capturable_graphs_seen
,
num_backend_compilations
=
4
,
num_backend_compilations
=
expected_num_backend_compilations
,
num_cudagraph_captured
=
8
,
num_cudagraph_captured
=
8
,
# num_cudagraph_sizes * num
_piecewise_captur
able
_
graphs
_seen
# num_cudagraph_sizes * num
cudagraph
able
graphs
to capture
):
):
run_model
(
vllm_config
,
mod_A
,
cudagraph_runtime_mode
)
run_model
(
vllm_config
,
mod_A
,
cudagraph_runtime_mode
)
...
@@ -222,20 +260,30 @@ def test_conditional_compile_enable_if():
...
@@ -222,20 +260,30 @@ def test_conditional_compile_enable_if():
use_cudagraph
=
True
,
use_cudagraph
=
True
,
splitting_ops
=
[
"silly::attention"
],
splitting_ops
=
[
"silly::attention"
],
cudagraph_capture_sizes
=
[
1
,
2
],
cudagraph_capture_sizes
=
[
1
,
2
],
use_inductor_graph_partition
=
False
,
# TODO test both?
use_inductor_graph_partition
=
use_inductor_graph_partition
,
),
),
)
)
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
mod_A
=
A
(
vllm_config
=
vllm_config
,
prefix
=
""
).
eval
().
cuda
()
mod_A
=
A
(
vllm_config
=
vllm_config
,
prefix
=
""
).
eval
().
cuda
()
if
use_inductor_graph_partition
:
expected_num_piecewise_graphs_seen
=
1
expected_num_piecewise_capturable_graphs_seen
=
1
expected_num_backend_compilations
=
1
else
:
# 3 attn ops and 4 non-attn ops
expected_num_piecewise_graphs_seen
=
7
expected_num_piecewise_capturable_graphs_seen
=
4
expected_num_backend_compilations
=
4
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
num_graphs_seen
=
1
,
num_piecewise_graphs_seen
=
7
,
num_piecewise_graphs_seen
=
expected_num_piecewise_graphs_seen
,
# 3 attn ops and 4 non-attn ops
# 3 attn ops and 4 non-attn ops
num_piecewise_capturable_graphs_seen
=
4
,
num_piecewise_capturable_graphs_seen
=
expected_num_piecewise_capturable_graphs_seen
,
num_backend_compilations
=
4
,
num_backend_compilations
=
expected_num_backend_compilations
,
num_cudagraph_captured
=
8
,
num_cudagraph_captured
=
8
,
# num_cudagraph_sizes * num
_piecewise_captur
able
_
graphs
_seen
# num_cudagraph_sizes * num
cudagraph
able
graphs
to capture
):
):
run_model
(
vllm_config
,
mod_A
,
cudagraph_runtime_mode
)
run_model
(
vllm_config
,
mod_A
,
cudagraph_runtime_mode
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment