Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4ac7713e
Unverified
Commit
4ac7713e
authored
Jul 23, 2025
by
Yong Hoon Shin
Committed by
GitHub
Jul 23, 2025
Browse files
Add test case for compiling multiple graphs (#21044)
Signed-off-by:
Yong Hoon Shin
<
yhshin@meta.com
>
parent
8560a5b2
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
390 additions
and
1 deletion
+390
-1
tests/compile/piecewise/test_multiple_graphs.py
tests/compile/piecewise/test_multiple_graphs.py
+350
-0
vllm/compilation/compiler_interface.py
vllm/compilation/compiler_interface.py
+6
-0
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+34
-1
No files found.
tests/compile/piecewise/test_multiple_graphs.py
0 → 100644
View file @
4ac7713e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test (piecewise) compilation with a simple model where multiple submodules
are compiled and graph captured separately.
"""
import
torch
from
torch
import
nn
from
torch.library
import
Library
from
vllm.compilation.backends
import
set_model_tag
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
(
ignore_torch_compile
,
support_torch_compile
)
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
set_current_vllm_config
)
from
vllm.envs
import
VLLM_USE_V1
from
vllm.forward_context
import
set_forward_context
from
vllm.utils
import
direct_register_custom_op
# create a library to hold the custom op
silly_lib
=
Library
(
"silly"
,
"FRAGMENT"
)
# noqa
BATCH_SIZE
=
32
MLP_SIZE
=
128
HIDDEN_SIZE
=
1024
RANDOM_SEED
=
0
def
silly_attention
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
out
.
copy_
(
q
)
out
+=
k
out
+=
v
def
silly_attention_fake
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
return
direct_register_custom_op
(
op_name
=
"attention"
,
op_func
=
silly_attention
,
mutates_args
=
[
"out"
],
fake_impl
=
silly_attention_fake
,
target_lib
=
silly_lib
,
)
@
support_torch_compile
class
ParentModel
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
)
->
None
:
super
().
__init__
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
x
class
Attention
(
nn
.
Module
):
def
__init__
(
self
,
mlp_size
:
int
,
hidden_size
:
int
)
->
None
:
super
().
__init__
()
self
.
pre_attn
=
nn
.
Linear
(
mlp_size
,
hidden_size
,
bias
=
False
)
self
.
post_attn
=
nn
.
Linear
(
hidden_size
,
mlp_size
,
bias
=
False
)
self
.
rms_norm_weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
# Initialize to same weights for testing
nn
.
init
.
xavier_normal_
(
self
.
pre_attn
.
weight
.
data
,
generator
=
torch
.
Generator
().
manual_seed
(
RANDOM_SEED
),
gain
=
0.001
)
nn
.
init
.
xavier_normal_
(
self
.
post_attn
.
weight
.
data
,
generator
=
torch
.
Generator
().
manual_seed
(
RANDOM_SEED
),
gain
=
0.001
)
def
rms_norm_ref
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x_f32
=
x
.
float
()
return
(
x_f32
*
torch
.
rsqrt
(
torch
.
mean
(
x_f32
.
square
(),
dim
=-
1
,
keepdim
=
True
)
+
1e-6
)
*
self
.
rms_norm_weight
).
to
(
x
.
dtype
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
=
self
.
pre_attn
(
x
)
x
=
self
.
rms_norm_ref
(
x
)
attn_output
=
torch
.
empty_like
(
x
)
torch
.
ops
.
silly
.
attention
(
x
,
x
,
x
,
attn_output
)
x
=
attn_output
x
=
self
.
rms_norm_ref
(
x
)
x
=
self
.
post_attn
(
x
)
return
x
@
support_torch_compile
class
CompiledAttention
(
nn
.
Module
):
def
__init__
(
self
,
*
,
mlp_size
:
int
,
hidden_size
:
int
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
)
->
None
:
super
().
__init__
()
self
.
attn
=
Attention
(
mlp_size
,
hidden_size
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
attn
(
x
)
@
support_torch_compile
class
CompiledAttentionTwo
(
CompiledAttention
):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
attn
(
x
)
+
x
@
ignore_torch_compile
class
SimpleModelWithTwoGraphs
(
ParentModel
):
def
__init__
(
self
,
*
,
mlp_size
:
int
,
hidden_size
:
int
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
)
->
None
:
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
# Test will fail without set_model_tag here with error:
# "ValueError: too many values to unpack (expected 3)"
# This is because CompiledAttention and CompiledAttentionTwo
# have different implmentations but the same torch.compile
# cache dir will be used as default prefix is 'model_tag'
with
set_model_tag
(
"attn_one"
):
self
.
attn_one
=
CompiledAttention
(
mlp_size
=
mlp_size
,
hidden_size
=
hidden_size
,
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.attn_one"
,
)
with
set_model_tag
(
"attn_two"
):
self
.
attn_two
=
CompiledAttentionTwo
(
mlp_size
=
mlp_size
,
hidden_size
=
hidden_size
,
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.attn_two"
,
)
self
.
hidden_states
=
torch
.
zeros
((
BATCH_SIZE
,
MLP_SIZE
)).
cuda
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
bsz
=
x
.
shape
[
0
]
# CUDAGraph expects same tensor addresses for each run
self
.
hidden_states
[:
bsz
].
copy_
(
x
)
x
=
self
.
attn_one
(
self
.
hidden_states
[:
bsz
])
self
.
hidden_states
[:
bsz
].
copy_
(
x
)
x
=
self
.
attn_two
(
self
.
hidden_states
[:
bsz
])
return
x
def
test_ignore_torch_compile_decorator
():
assert
VLLM_USE_V1
# piecewise
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
use_cudagraph
=
True
,
splitting_ops
=
[
"silly.attention"
],
cudagraph_capture_sizes
=
[
1
,
2
],
))
@
support_torch_compile
class
A
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
)
->
None
:
super
().
__init__
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
=
x
+
x
attn_output
=
torch
.
empty_like
(
x
)
torch
.
ops
.
silly
.
attention
(
x
,
x
,
x
,
attn_output
)
x
=
attn_output
x
=
x
*
3
return
x
@
ignore_torch_compile
class
B
(
A
):
...
@
support_torch_compile
class
C
(
B
):
...
with
set_current_vllm_config
(
vllm_config
):
mod_A
=
A
(
vllm_config
=
vllm_config
,
prefix
=
''
).
eval
().
cuda
()
# A has support_torch_compile
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
num_piecewise_graphs_seen
=
3
,
num_piecewise_capturable_graphs_seen
=
2
,
num_backend_compilations
=
2
,
num_cudagraph_captured
=
4
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
),
set_forward_context
({},
vllm_config
=
vllm_config
):
# first run is for compile
mod_A
(
torch
.
randn
(
BATCH_SIZE
,
MLP_SIZE
).
cuda
())
# run cudagraph captured sizes
mod_A
(
torch
.
randn
(
2
,
MLP_SIZE
).
cuda
())
mod_A
(
torch
.
randn
(
1
,
MLP_SIZE
).
cuda
())
with
set_current_vllm_config
(
vllm_config
):
mod_B
=
B
(
vllm_config
=
vllm_config
,
prefix
=
''
).
eval
().
cuda
()
# B's ignore_torch_compile should override A's support_torch_compile
with
compilation_counter
.
expect
(
num_graphs_seen
=
0
,
num_piecewise_graphs_seen
=
0
,
num_piecewise_capturable_graphs_seen
=
0
,
num_backend_compilations
=
0
,
num_cudagraph_captured
=
0
,
),
set_forward_context
({},
vllm_config
=
vllm_config
):
mod_B
(
torch
.
randn
(
BATCH_SIZE
,
MLP_SIZE
).
cuda
())
mod_B
(
torch
.
randn
(
2
,
MLP_SIZE
).
cuda
())
mod_B
(
torch
.
randn
(
1
,
MLP_SIZE
).
cuda
())
with
set_current_vllm_config
(
vllm_config
):
mod_C
=
C
(
vllm_config
=
vllm_config
,
prefix
=
''
).
eval
().
cuda
()
# C's support_torch_compile should override B's ignore_torch_compile
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
num_piecewise_graphs_seen
=
3
,
num_piecewise_capturable_graphs_seen
=
2
,
num_backend_compilations
=
2
,
num_cudagraph_captured
=
4
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
),
set_forward_context
({},
vllm_config
=
vllm_config
):
mod_C
(
torch
.
randn
(
BATCH_SIZE
,
MLP_SIZE
).
cuda
())
mod_C
(
torch
.
randn
(
2
,
MLP_SIZE
).
cuda
())
mod_C
(
torch
.
randn
(
1
,
MLP_SIZE
).
cuda
())
@
torch
.
inference_mode
def
run_model
(
vllm_config
,
model
:
nn
.
Module
,
inputs
:
torch
.
Tensor
):
with
set_forward_context
({},
vllm_config
=
vllm_config
):
# First run is for compile
model
(
inputs
)
# Run CUDAGraph captured sizes
model
(
inputs
[:
2
])
model
(
inputs
[:
1
])
output
=
model
(
inputs
[:
2
])
output
=
output
.
cpu
()
return
output
.
cpu
()
def
test_multi_graph_piecewise_compile_outputs_equal
():
outputs
=
[]
# piecewise compile
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
use_cudagraph
=
True
,
splitting_ops
=
[
"silly.attention"
],
cudagraph_capture_sizes
=
[
1
,
2
],
))
with
set_current_vllm_config
(
vllm_config
):
model
=
SimpleModelWithTwoGraphs
(
mlp_size
=
MLP_SIZE
,
hidden_size
=
HIDDEN_SIZE
,
vllm_config
=
vllm_config
,
prefix
=
''
).
eval
().
cuda
()
# Pre-allocate memory for CUDAGraph which expects
# static tensor addresses
inputs
=
torch
.
randn
(
BATCH_SIZE
,
MLP_SIZE
).
cuda
()
with
compilation_counter
.
expect
(
num_graphs_seen
=
2
,
# two graphs for the model
num_piecewise_graphs_seen
=
6
,
# attn_one, attn_two each has 3 piecewise graphs
# (pre attn, post attn, silly_attention) each
num_piecewise_capturable_graphs_seen
=
4
,
# attn_one, attn_two has pre attn and post attn each, total=4
num_backend_compilations
=
4
,
# num_piecewise_capturable_graphs_seen
num_cudagraph_captured
=
8
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
outputs
.
append
(
run_model
(
vllm_config
,
model
,
inputs
))
# no compile or cudagraph
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
NO_COMPILATION
,
))
with
set_current_vllm_config
(
vllm_config
):
model
=
SimpleModelWithTwoGraphs
(
mlp_size
=
MLP_SIZE
,
hidden_size
=
HIDDEN_SIZE
,
vllm_config
=
vllm_config
,
prefix
=
''
).
eval
().
cuda
()
with
compilation_counter
.
expect
(
num_graphs_seen
=
0
,
num_piecewise_graphs_seen
=
0
,
num_piecewise_capturable_graphs_seen
=
0
,
num_backend_compilations
=
0
,
num_cudagraph_captured
=
0
,
):
outputs
.
append
(
run_model
(
vllm_config
,
model
,
inputs
))
# piecewise compile without CUDA graph
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
use_cudagraph
=
False
,
splitting_ops
=
[
"silly.attention"
],
))
with
set_current_vllm_config
(
vllm_config
):
model
=
SimpleModelWithTwoGraphs
(
mlp_size
=
MLP_SIZE
,
hidden_size
=
HIDDEN_SIZE
,
vllm_config
=
vllm_config
,
prefix
=
''
).
eval
().
cuda
()
with
compilation_counter
.
expect
(
num_graphs_seen
=
2
,
num_piecewise_graphs_seen
=
6
,
num_piecewise_capturable_graphs_seen
=
4
,
num_backend_compilations
=
4
,
num_cudagraph_captured
=
0
,
# no cudagraph captured
):
outputs
.
append
(
run_model
(
vllm_config
,
model
,
inputs
))
# Generally don't expect outputs with and without inductor
# to be bitwise equivalent
assert
torch
.
allclose
(
outputs
[
0
],
outputs
[
1
])
# Expect bitwise equivalence using inductor w/ and w/o cudagraph
assert
torch
.
equal
(
outputs
[
0
],
outputs
[
2
])
vllm/compilation/compiler_interface.py
View file @
4ac7713e
...
...
@@ -423,6 +423,12 @@ class InductorAdaptor(CompilerInterface):
if
is_torch_equal_or_newer
(
"2.6"
):
stack
.
enter_context
(
torch
.
_inductor
.
config
.
patch
(
fx_graph_remote_cache
=
False
))
# InductorAdaptor (unfortunately) requires AOTAutogradCache
# to be turned off to run. It will fail to acquire the hash_str
# and error if not.
# StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
stack
.
enter_context
(
torch
.
_functorch
.
config
.
patch
(
enable_autograd_cache
=
False
))
stack
.
enter_context
(
torch
.
_functorch
.
config
.
patch
(
enable_remote_autograd_cache
=
False
))
...
...
vllm/compilation/decorators.py
View file @
4ac7713e
...
...
@@ -20,9 +20,38 @@ from .monitor import start_monitoring_torch_compile
logger
=
init_logger
(
__name__
)
IGNORE_COMPILE_KEY
=
"_ignore_compile_vllm"
_T
=
TypeVar
(
"_T"
,
bound
=
type
[
nn
.
Module
])
def
ignore_torch_compile
(
cls
:
_T
)
->
_T
:
"""
A decorator to ignore support_torch_compile decorator
on the class. This is useful when a parent class has
a support_torch_compile decorator, but we don't want to
compile the class `cls` that inherits the parent class.
This only ignores compiling the forward of the class the
decorator is applied to.
If the parent has ignore_torch_compile but the child has
support_torch_compile, the child will still be compiled.
If the class has one or more submodules
that have support_torch_compile decorator applied, compile will
not be ignored for those submodules.
"""
setattr
(
cls
,
IGNORE_COMPILE_KEY
,
True
)
return
cls
def
_should_ignore_torch_compile
(
cls
)
->
bool
:
"""
Check if the class should be ignored for torch.compile.
"""
return
getattr
(
cls
,
IGNORE_COMPILE_KEY
,
False
)
@
overload
def
support_torch_compile
(
*
,
...
...
@@ -148,6 +177,8 @@ def _support_torch_compile(
old_init
=
cls
.
__init__
setattr
(
cls
,
IGNORE_COMPILE_KEY
,
False
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
):
old_init
(
self
,
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
self
.
vllm_config
=
vllm_config
...
...
@@ -156,9 +187,11 @@ def _support_torch_compile(
self
.
do_not_compile
=
\
vllm_config
.
compilation_config
.
level
in
[
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
DYNAMO_AS_IS
]
or
not
supports_dynamo
()
]
or
not
supports_dynamo
()
or
_should_ignore_torch_compile
(
self
.
__class__
)
if
self
.
do_not_compile
:
return
compilation_counter
.
num_models_seen
+=
1
TorchCompileWrapperWithCustomDispatcher
.
__init__
(
self
,
compilation_level
=
vllm_config
.
compilation_config
.
level
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment