Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dfd23820
Unverified
Commit
dfd23820
authored
Aug 20, 2025
by
Yong Hoon Shin
Committed by
GitHub
Aug 20, 2025
Browse files
[torch.compile] Support conditional torch.compile per module (#22269)
Signed-off-by:
Yong Hoon Shin
<
yhshin@meta.com
>
parent
3b11b26b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
308 additions
and
103 deletions
+308
-103
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-0
tests/compile/piecewise/test_multiple_graphs.py
tests/compile/piecewise/test_multiple_graphs.py
+36
-101
tests/compile/test_decorator.py
tests/compile/test_decorator.py
+251
-0
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+19
-2
No files found.
.buildkite/test-pipeline.yaml
View file @
dfd23820
...
@@ -328,6 +328,7 @@ steps:
...
@@ -328,6 +328,7 @@ steps:
-
pytest -v -s compile/test_sequence_parallelism.py
-
pytest -v -s compile/test_sequence_parallelism.py
-
pytest -v -s compile/test_async_tp.py
-
pytest -v -s compile/test_async_tp.py
-
pytest -v -s compile/test_fusion_all_reduce.py
-
pytest -v -s compile/test_fusion_all_reduce.py
-
pytest -v -s compile/test_decorator.py
-
label
:
PyTorch Fullgraph Smoke Test
# 9min
-
label
:
PyTorch Fullgraph Smoke Test
# 9min
mirror_hardwares
:
[
amdexperimental
]
mirror_hardwares
:
[
amdexperimental
]
...
@@ -341,6 +342,7 @@ steps:
...
@@ -341,6 +342,7 @@ steps:
-
pytest -v -s compile/piecewise/test_simple.py
-
pytest -v -s compile/piecewise/test_simple.py
-
pytest -v -s compile/piecewise/test_toy_llama.py
-
pytest -v -s compile/piecewise/test_toy_llama.py
-
pytest -v -s compile/piecewise/test_full_cudagraph.py
-
pytest -v -s compile/piecewise/test_full_cudagraph.py
-
pytest -v -s compile/piecewise/test_multiple_graphs.py
-
label
:
PyTorch Fullgraph Test
# 18min
-
label
:
PyTorch Fullgraph Test
# 18min
mirror_hardwares
:
[
amdexperimental
]
mirror_hardwares
:
[
amdexperimental
]
...
...
tests/compile/piecewise/test_multiple_graphs.py
View file @
dfd23820
...
@@ -12,10 +12,9 @@ from vllm.compilation.backends import set_model_tag
...
@@ -12,10 +12,9 @@ from vllm.compilation.backends import set_model_tag
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
(
ignore_torch_compile
,
from
vllm.compilation.decorators
import
(
ignore_torch_compile
,
support_torch_compile
)
support_torch_compile
)
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
CUDAGraphMode
,
set_current_vllm_config
)
VllmConfig
,
set_current_vllm_config
)
from
vllm.envs
import
VLLM_USE_V1
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
# create a library to hold the custom op
# create a library to hold the custom op
...
@@ -164,104 +163,34 @@ class SimpleModelWithTwoGraphs(ParentModel):
...
@@ -164,104 +163,34 @@ class SimpleModelWithTwoGraphs(ParentModel):
return
x
return
x
def
test_ignore_torch_compile_decorator
():
assert
VLLM_USE_V1
# piecewise
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
use_cudagraph
=
True
,
splitting_ops
=
[
"silly.attention"
],
cudagraph_capture_sizes
=
[
1
,
2
],
))
@
support_torch_compile
class
A
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
)
->
None
:
super
().
__init__
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
=
x
+
x
attn_output
=
torch
.
empty_like
(
x
)
torch
.
ops
.
silly
.
attention
(
x
,
x
,
x
,
attn_output
)
x
=
attn_output
x
=
x
*
3
return
x
@
ignore_torch_compile
class
B
(
A
):
...
@
support_torch_compile
class
C
(
B
):
...
with
set_current_vllm_config
(
vllm_config
):
mod_A
=
A
(
vllm_config
=
vllm_config
,
prefix
=
''
).
eval
().
cuda
()
# A has support_torch_compile
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
num_piecewise_graphs_seen
=
3
,
num_piecewise_capturable_graphs_seen
=
2
,
num_backend_compilations
=
2
,
num_cudagraph_captured
=
4
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
),
set_forward_context
({},
vllm_config
=
vllm_config
):
# first run is for compile
mod_A
(
torch
.
randn
(
BATCH_SIZE
,
MLP_SIZE
).
cuda
())
# run cudagraph captured sizes
mod_A
(
torch
.
randn
(
2
,
MLP_SIZE
).
cuda
())
mod_A
(
torch
.
randn
(
1
,
MLP_SIZE
).
cuda
())
with
set_current_vllm_config
(
vllm_config
):
mod_B
=
B
(
vllm_config
=
vllm_config
,
prefix
=
''
).
eval
().
cuda
()
# B's ignore_torch_compile should override A's support_torch_compile
with
compilation_counter
.
expect
(
num_graphs_seen
=
0
,
num_piecewise_graphs_seen
=
0
,
num_piecewise_capturable_graphs_seen
=
0
,
num_backend_compilations
=
0
,
num_cudagraph_captured
=
0
,
),
set_forward_context
({},
vllm_config
=
vllm_config
):
mod_B
(
torch
.
randn
(
BATCH_SIZE
,
MLP_SIZE
).
cuda
())
mod_B
(
torch
.
randn
(
2
,
MLP_SIZE
).
cuda
())
mod_B
(
torch
.
randn
(
1
,
MLP_SIZE
).
cuda
())
with
set_current_vllm_config
(
vllm_config
):
mod_C
=
C
(
vllm_config
=
vllm_config
,
prefix
=
''
).
eval
().
cuda
()
# C's support_torch_compile should override B's ignore_torch_compile
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
num_piecewise_graphs_seen
=
3
,
num_piecewise_capturable_graphs_seen
=
2
,
num_backend_compilations
=
2
,
num_cudagraph_captured
=
4
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
),
set_forward_context
({},
vllm_config
=
vllm_config
):
mod_C
(
torch
.
randn
(
BATCH_SIZE
,
MLP_SIZE
).
cuda
())
mod_C
(
torch
.
randn
(
2
,
MLP_SIZE
).
cuda
())
mod_C
(
torch
.
randn
(
1
,
MLP_SIZE
).
cuda
())
@
torch
.
inference_mode
@
torch
.
inference_mode
def
run_model
(
vllm_config
,
model
:
nn
.
Module
,
inputs
:
torch
.
Tensor
):
def
run_model
(
vllm_config
:
VllmConfig
,
model
:
nn
.
Module
,
inputs
:
torch
.
Tensor
,
cudagraph_runtime_mode
:
CUDAGraphMode
):
with
set_forward_context
({},
vllm_config
=
vllm_config
):
with
set_forward_context
({},
vllm_config
=
vllm_config
):
#
First run is for compile
#
warmup for the model with cudagraph_mode NONE
model
(
inputs
)
model
(
inputs
)
# Run CUDAGraph captured sizes
# simulate cudagraphs capturing
model
(
inputs
[:
2
])
with
set_forward_context
({},
model
(
inputs
[:
1
])
vllm_config
=
vllm_config
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
output
=
model
(
inputs
[:
2
])
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
2
,
)):
model
(
inputs
[:
2
])
with
set_forward_context
({},
vllm_config
=
vllm_config
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
1
,
)):
model
(
inputs
[:
1
])
# simulate cudagraphs replay
with
set_forward_context
({},
vllm_config
=
vllm_config
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
batch_descriptor
=
BatchDescriptor
(
num_tokens
=
2
,
)):
output
=
model
(
inputs
[:
2
])
output
=
output
.
cpu
()
output
=
output
.
cpu
()
return
output
.
cpu
()
return
output
.
cpu
()
...
@@ -277,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
...
@@ -277,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
splitting_ops
=
[
"silly.attention"
],
splitting_ops
=
[
"silly.attention"
],
cudagraph_capture_sizes
=
[
1
,
2
],
cudagraph_capture_sizes
=
[
1
,
2
],
))
))
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
model
=
SimpleModelWithTwoGraphs
(
mlp_size
=
MLP_SIZE
,
model
=
SimpleModelWithTwoGraphs
(
mlp_size
=
MLP_SIZE
,
...
@@ -299,11 +229,13 @@ def test_multi_graph_piecewise_compile_outputs_equal():
...
@@ -299,11 +229,13 @@ def test_multi_graph_piecewise_compile_outputs_equal():
num_cudagraph_captured
=
8
,
num_cudagraph_captured
=
8
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
):
outputs
.
append
(
run_model
(
vllm_config
,
model
,
inputs
))
outputs
.
append
(
run_model
(
vllm_config
,
model
,
inputs
,
cudagraph_runtime_mode
))
# no compile or cudagraph
# no compile or cudagraph
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
NO_COMPILATION
,
))
level
=
CompilationLevel
.
NO_COMPILATION
,
))
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
model
=
SimpleModelWithTwoGraphs
(
mlp_size
=
MLP_SIZE
,
model
=
SimpleModelWithTwoGraphs
(
mlp_size
=
MLP_SIZE
,
...
@@ -318,7 +250,8 @@ def test_multi_graph_piecewise_compile_outputs_equal():
...
@@ -318,7 +250,8 @@ def test_multi_graph_piecewise_compile_outputs_equal():
num_backend_compilations
=
0
,
num_backend_compilations
=
0
,
num_cudagraph_captured
=
0
,
num_cudagraph_captured
=
0
,
):
):
outputs
.
append
(
run_model
(
vllm_config
,
model
,
inputs
))
outputs
.
append
(
run_model
(
vllm_config
,
model
,
inputs
,
cudagraph_runtime_mode
))
# piecewise compile without CUDA graph
# piecewise compile without CUDA graph
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
...
@@ -326,6 +259,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
...
@@ -326,6 +259,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
use_cudagraph
=
False
,
use_cudagraph
=
False
,
splitting_ops
=
[
"silly.attention"
],
splitting_ops
=
[
"silly.attention"
],
))
))
cudagraph_runtime_mode
=
CUDAGraphMode
.
PIECEWISE
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
model
=
SimpleModelWithTwoGraphs
(
mlp_size
=
MLP_SIZE
,
model
=
SimpleModelWithTwoGraphs
(
mlp_size
=
MLP_SIZE
,
...
@@ -340,7 +274,8 @@ def test_multi_graph_piecewise_compile_outputs_equal():
...
@@ -340,7 +274,8 @@ def test_multi_graph_piecewise_compile_outputs_equal():
num_backend_compilations
=
4
,
num_backend_compilations
=
4
,
num_cudagraph_captured
=
0
,
# no cudagraph captured
num_cudagraph_captured
=
0
,
# no cudagraph captured
):
):
outputs
.
append
(
run_model
(
vllm_config
,
model
,
inputs
))
outputs
.
append
(
run_model
(
vllm_config
,
model
,
inputs
,
cudagraph_runtime_mode
))
# Generally don't expect outputs with and without inductor
# Generally don't expect outputs with and without inductor
# to be bitwise equivalent
# to be bitwise equivalent
...
...
tests/compile/test_decorator.py
0 → 100644
View file @
dfd23820
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
torch
import
nn
from
torch.library
import
Library
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
(
ignore_torch_compile
,
support_torch_compile
)
from
vllm.config
import
(
CacheConfig
,
CompilationConfig
,
CompilationLevel
,
CUDAGraphMode
,
VllmConfig
,
set_current_vllm_config
)
from
vllm.forward_context
import
BatchDescriptor
,
set_forward_context
from
vllm.utils
import
direct_register_custom_op
# Create a library fragment for the "silly" namespace; it holds the custom
# "attention" op that this file registers via direct_register_custom_op.
silly_lib = Library("silly", "FRAGMENT")  # noqa

# Shape of the warmup input built in run_model: BATCH_SIZE rows of MLP_SIZE
# features. MLP_SIZE is also the feature width of the smaller test inputs.
BATCH_SIZE = 32
MLP_SIZE = 128
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    """Overwrite ``out`` with ``q`` and then add ``k`` and ``v`` to it in place.

    Nothing is returned; the result is only observable through ``out``.
    ``q``, ``k`` and ``v`` are read but not modified.
    """
    # copy_ returns ``out`` itself, so the in-place additions can be chained.
    out.copy_(q).add_(k).add_(v)
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    """Fake counterpart of ``silly_attention``: accepts the same arguments and
    does nothing.

    It is passed as ``fake_impl`` when the op is registered in this file.
    """
    return None
# Register silly_attention in silly_lib under the op name "attention"; the
# models in this file call it as torch.ops.silly.attention. "out" is declared
# as the argument the op mutates, and silly_attention_fake is supplied as the
# fake implementation.
direct_register_custom_op(
    op_name="attention",
    op_func=silly_attention,
    mutates_args=["out"],
    fake_impl=silly_attention_fake,
    target_lib=silly_lib,
)
@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module,
              cudagraph_runtime_mode: CUDAGraphMode):
    """Drive ``model`` through a warmup pass and three small forward passes.

    The warmup pass uses a (BATCH_SIZE, MLP_SIZE) random CUDA input under a
    forward context that only carries ``vllm_config``. The next passes use
    inputs of 2, 1 and again 2 rows, each under a forward context that also
    carries ``cudagraph_runtime_mode`` and a ``BatchDescriptor`` whose
    ``num_tokens`` equals the row count.

    Args:
        vllm_config: config handed to every ``set_forward_context`` call.
        model: module to invoke; called with a single tensor argument.
        cudagraph_runtime_mode: mode passed to the forward context for the
            passes after the warmup.

    Returns:
        The output of the final 2-row pass, moved to the CPU.
    """
    # NOTE(review): the nesting of the inner contexts inside the outer one was
    # reconstructed from a flattened listing - confirm against upstream.
    with set_forward_context({}, vllm_config=vllm_config):
        # warmup for the model with cudagraph_mode NONE
        model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())

        # simulate cudagraphs capturing
        with set_forward_context({},
                                 vllm_config=vllm_config,
                                 cudagraph_runtime_mode=cudagraph_runtime_mode,
                                 batch_descriptor=BatchDescriptor(
                                     num_tokens=2, )):
            model(torch.randn(2, MLP_SIZE).cuda())
        with set_forward_context({},
                                 vllm_config=vllm_config,
                                 cudagraph_runtime_mode=cudagraph_runtime_mode,
                                 batch_descriptor=BatchDescriptor(
                                     num_tokens=1, )):
            model(torch.randn(1, MLP_SIZE).cuda())

        # simulate cudagraphs replay
        with set_forward_context({},
                                 vllm_config=vllm_config,
                                 cudagraph_runtime_mode=cudagraph_runtime_mode,
                                 batch_descriptor=BatchDescriptor(
                                     num_tokens=2, )):
            output = model(torch.randn(2, MLP_SIZE).cuda())

        # The second .cpu() below acts on a tensor that is already on the CPU.
        output = output.cpu()
        return output.cpu()
def test_ignore_torch_compile_decorator():
    """Check that the most-derived compile decorator decides whether a class
    is compiled.

    Three local classes are built: ``A`` (``support_torch_compile``), ``B(A)``
    (``ignore_torch_compile``) and ``C(B)`` (``support_torch_compile``). Each
    is instantiated and run through ``run_model`` while
    ``compilation_counter.expect`` asserts the compilation counters: ``A`` and
    ``C`` are expected to compile, ``B`` is expected to produce no compilation
    activity at all.
    """
    # NOTE(review): block nesting was reconstructed from a flattened listing -
    # confirm against upstream.
    # piecewise
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=True,
        splitting_ops=["silly.attention"],
        cudagraph_capture_sizes=[1, 2],
    ))
    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE

    # These local classes shadow the module-level A and B defined later in
    # this file for the duration of this test.
    @support_torch_compile
    class A(nn.Module):

        def __init__(self,
                     *,
                     vllm_config: VllmConfig,
                     prefix: str = '',
                     **kwargs) -> None:
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # One silly.attention call with one op before and one after it;
            # "silly.attention" is listed in splitting_ops above.
            x = x + x
            attn_output = torch.empty_like(x)
            torch.ops.silly.attention(x, x, x, attn_output)
            x = attn_output
            x = x * 3
            return x

    @ignore_torch_compile
    class B(A):
        ...

    @support_torch_compile
    class C(B):
        ...

    with set_current_vllm_config(vllm_config):
        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()

    # A has support_torch_compile
    with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=3,
            num_piecewise_capturable_graphs_seen=2,
            num_backend_compilations=2,
            num_cudagraph_captured=4,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        run_model(vllm_config, mod_A, cudagraph_runtime_mode)

    with set_current_vllm_config(vllm_config):
        mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()

    # B's ignore_torch_compile should override A's support_torch_compile
    with compilation_counter.expect(
            num_graphs_seen=0,
            num_piecewise_graphs_seen=0,
            num_piecewise_capturable_graphs_seen=0,
            num_backend_compilations=0,
            num_cudagraph_captured=0,
    ):
        run_model(vllm_config, mod_B, cudagraph_runtime_mode)

    with set_current_vllm_config(vllm_config):
        mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()

    # C's support_torch_compile should override B's ignore_torch_compile
    with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=3,
            num_piecewise_capturable_graphs_seen=2,
            num_backend_compilations=2,
            num_cudagraph_captured=4,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        run_model(vllm_config, mod_C, cudagraph_runtime_mode)
# torch.compile is enabled for this class only when
# vllm_config.cache_config.kv_sharing_fast_prefill=True
@support_torch_compile(
    enable_if=lambda cfg: cfg.cache_config.kv_sharing_fast_prefill)
class B(nn.Module):
    """Parameter-free module: doubles the input, runs it through
    ``silly.attention`` and doubles the result again."""

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        # The arguments are accepted but unused by this class's own code.
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``2 * attention(2x, 2x, 2x)`` using additions only."""
        doubled = x + x
        attn_out = torch.empty_like(doubled)
        # The custom op writes its result into attn_out.
        torch.ops.silly.attention(doubled, doubled, doubled, attn_out)
        return attn_out + attn_out
# torch.compile is enabled for this class only when
# vllm_config.cache_config.kv_sharing_fast_prefill=False
@support_torch_compile(
    enable_if=lambda cfg: not cfg.cache_config.kv_sharing_fast_prefill)
class A(nn.Module):
    """Wraps two ``B`` sub-modules with a ``silly.attention`` call between
    them."""

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        super().__init__()
        # Both sub-modules receive exactly the arguments this module was
        # given, including the unchanged prefix.
        child_kwargs = dict(vllm_config=vllm_config, prefix=prefix, **kwargs)
        self.mod1 = B(**child_kwargs)
        self.mod2 = B(**child_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``mod1``, then ``silly.attention``, then ``mod2``."""
        hidden = self.mod1(x)
        attn_out = torch.empty_like(hidden)
        # The custom op writes its result into attn_out.
        torch.ops.silly.attention(hidden, hidden, hidden, attn_out)
        return self.mod2(attn_out)
def test_conditional_compile_enable_if():
    """Check that ``enable_if`` switches compilation on and off per module.

    The module-level ``A`` holds two ``B`` sub-modules. ``B`` is enabled when
    ``kv_sharing_fast_prefill`` is true and ``A`` when it is false. The test
    builds ``A`` once with the flag set to True and once with it set to
    False, runs each through ``run_model`` and asserts the compilation
    counters expected for that case.
    """
    # NOTE(review): block nesting was reconstructed from a flattened listing -
    # confirm against upstream.
    vllm_config = VllmConfig(
        cache_config=CacheConfig(kv_sharing_fast_prefill=True, ),
        compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            use_cudagraph=True,
            splitting_ops=["silly.attention"],
            cudagraph_capture_sizes=[1, 2],
        ))
    cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE

    with set_current_vllm_config(vllm_config):
        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()

    # A has support_torch_compile but its enable_if fn returns False.
    # enable_if will be True for B, so we expect mod1 and mod2
    # to be compiled
    with compilation_counter.expect(
            num_graphs_seen=2,
            num_piecewise_graphs_seen=6,
            # 3 piecewise graphs per instance of B()
            num_piecewise_capturable_graphs_seen=4,
            num_backend_compilations=4,
            num_cudagraph_captured=8,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        run_model(vllm_config, mod_A, cudagraph_runtime_mode)

    # Set kv_sharing_fast_prefill=False
    # which will cause A to be compiled and B to not be compiled
    vllm_config = VllmConfig(
        cache_config=CacheConfig(kv_sharing_fast_prefill=False, ),
        compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            use_cudagraph=True,
            splitting_ops=["silly.attention"],
            cudagraph_capture_sizes=[1, 2],
        ))

    with set_current_vllm_config(vllm_config):
        mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()

    # cudagraph_runtime_mode is still CUDAGraphMode.PIECEWISE from above.
    with compilation_counter.expect(
            num_graphs_seen=1,
            num_piecewise_graphs_seen=7,
            # 3 attn ops and 4 non-attn ops
            num_piecewise_capturable_graphs_seen=4,
            num_backend_compilations=4,
            num_cudagraph_captured=8,
            # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        run_model(vllm_config, mod_A, cudagraph_runtime_mode)
vllm/compilation/decorators.py
View file @
dfd23820
...
@@ -52,6 +52,14 @@ def _should_ignore_torch_compile(cls) -> bool:
...
@@ -52,6 +52,14 @@ def _should_ignore_torch_compile(cls) -> bool:
return
getattr
(
cls
,
IGNORE_COMPILE_KEY
,
False
)
return
getattr
(
cls
,
IGNORE_COMPILE_KEY
,
False
)
@
overload
def
support_torch_compile
(
*
,
enable_if
:
Optional
[
Callable
[[
VllmConfig
],
bool
]]
=
None
,
)
->
Callable
[[
_T
],
_T
]:
...
@
overload
@
overload
def
support_torch_compile
(
def
support_torch_compile
(
*
,
*
,
...
@@ -69,6 +77,7 @@ def support_torch_compile(
...
@@ -69,6 +77,7 @@ def support_torch_compile(
cls
:
Optional
[
_T
]
=
None
,
cls
:
Optional
[
_T
]
=
None
,
*
,
*
,
dynamic_arg_dims
:
Optional
[
dict
[
str
,
Union
[
int
,
list
[
int
]]]]
=
None
,
dynamic_arg_dims
:
Optional
[
dict
[
str
,
Union
[
int
,
list
[
int
]]]]
=
None
,
enable_if
:
Optional
[
Callable
[[
VllmConfig
],
bool
]]
=
None
,
)
->
Union
[
Callable
[[
_T
],
_T
],
_T
]:
)
->
Union
[
Callable
[[
_T
],
_T
],
_T
]:
"""
"""
A decorator to add support for compiling the forward method of a class.
A decorator to add support for compiling the forward method of a class.
...
@@ -118,6 +127,11 @@ def support_torch_compile(
...
@@ -118,6 +127,11 @@ def support_torch_compile(
NOTE: if an argument is `None`, it should always be passed as `None` during
NOTE: if an argument is `None`, it should always be passed as `None` during
the lifetime of the model, otherwise, it cannot be captured as a single
the lifetime of the model, otherwise, it cannot be captured as a single
computation graph.
computation graph.
`enable_if` is a function that takes a `VllmConfig` object as input and
returns a boolean value indicating whether to compile the model or not.
This is useful if you want to compile the model only when certain
conditions are met.
"""
"""
def
cls_decorator_helper
(
cls
:
_T
)
->
_T
:
def
cls_decorator_helper
(
cls
:
_T
)
->
_T
:
...
@@ -149,7 +163,8 @@ def support_torch_compile(
...
@@ -149,7 +163,8 @@ def support_torch_compile(
if
k
not
in
sig
.
parameters
:
if
k
not
in
sig
.
parameters
:
raise
ValueError
(
raise
ValueError
(
f
"Argument
{
k
}
not found in the forward method of
{
cls
}
"
)
f
"Argument
{
k
}
not found in the forward method of
{
cls
}
"
)
return
_support_torch_compile
(
cls
,
inferred_dynamic_arg_dims
)
return
_support_torch_compile
(
cls
,
inferred_dynamic_arg_dims
,
enable_if
)
if
cls
is
not
None
:
if
cls
is
not
None
:
# use `support_torch_compile` as a decorator without arguments
# use `support_torch_compile` as a decorator without arguments
...
@@ -162,6 +177,7 @@ def support_torch_compile(
...
@@ -162,6 +177,7 @@ def support_torch_compile(
def
_support_torch_compile
(
def
_support_torch_compile
(
cls
:
_T
,
cls
:
_T
,
dynamic_arg_dims
:
dict
[
str
,
Union
[
int
,
list
[
int
]]],
dynamic_arg_dims
:
dict
[
str
,
Union
[
int
,
list
[
int
]]],
enable_if
:
Optional
[
Callable
[[
VllmConfig
],
bool
]]
=
None
,
)
->
_T
:
)
->
_T
:
"""
"""
A decorator to add support for compiling the forward method of a class.
A decorator to add support for compiling the forward method of a class.
...
@@ -182,13 +198,14 @@ def _support_torch_compile(
...
@@ -182,13 +198,14 @@ def _support_torch_compile(
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
):
old_init
(
self
,
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
old_init
(
self
,
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
enable_compile
=
enable_if
is
None
or
enable_if
(
vllm_config
)
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
# will handle the compilation, so we don't need to do anything here.
self
.
do_not_compile
=
\
self
.
do_not_compile
=
\
vllm_config
.
compilation_config
.
level
in
[
vllm_config
.
compilation_config
.
level
in
[
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
DYNAMO_AS_IS
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
DYNAMO_AS_IS
]
or
not
supports_dynamo
()
or
_should_ignore_torch_compile
(
]
or
not
supports_dynamo
()
or
_should_ignore_torch_compile
(
self
.
__class__
)
self
.
__class__
)
or
not
enable_compile
if
self
.
do_not_compile
:
if
self
.
do_not_compile
:
return
return
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment