Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
26b4fa45
Unverified
Commit
26b4fa45
authored
May 28, 2025
by
Richard Zou
Committed by
GitHub
May 29, 2025
Browse files
Add ability to use CUDAGraphs with use_inductor=False (#17345)
Signed-off-by:
rzou
<
zou3519@gmail.com
>
parent
515b413e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
51 additions
and
13 deletions
+51
-13
tests/compile/piecewise/test_simple.py
tests/compile/piecewise/test_simple.py
+10
-1
tests/compile/piecewise/test_toy_llama.py
tests/compile/piecewise/test_toy_llama.py
+33
-7
vllm/compilation/compiler_interface.py
vllm/compilation/compiler_interface.py
+4
-0
vllm/compilation/counter.py
vllm/compilation/counter.py
+4
-0
vllm/config.py
vllm/config.py
+0
-5
No files found.
tests/compile/piecewise/test_simple.py
View file @
26b4fa45
...
@@ -74,11 +74,12 @@ class SillyModel(nn.Module):
...
@@ -74,11 +74,12 @@ class SillyModel(nn.Module):
return
x
return
x
def
test_simple_piecewise_compile
():
def
_
test_simple_piecewise_compile
(
*
,
use_inductor
):
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
level
=
CompilationLevel
.
PIECEWISE
,
use_cudagraph
=
True
,
use_cudagraph
=
True
,
use_inductor
=
use_inductor
,
splitting_ops
=
[
"silly.attention"
],
splitting_ops
=
[
"silly.attention"
],
cudagraph_copy_inputs
=
True
,
cudagraph_copy_inputs
=
True
,
cudagraph_capture_sizes
=
[
1
,
2
],
cudagraph_capture_sizes
=
[
1
,
2
],
...
@@ -108,3 +109,11 @@ def test_simple_piecewise_compile():
...
@@ -108,3 +109,11 @@ def test_simple_piecewise_compile():
output
=
model
(
input
)
output
=
model
(
input
)
assert
global_counter
==
2
assert
global_counter
==
2
assert
torch
.
allclose
(
output
.
cpu
(),
torch
.
tensor
([
3.
,
1.
]))
assert
torch
.
allclose
(
output
.
cpu
(),
torch
.
tensor
([
3.
,
1.
]))
def
test_simple_piecewise_compile_inductor
():
_test_simple_piecewise_compile
(
use_inductor
=
True
)
def
test_simple_piecewise_compile_no_inductor
():
_test_simple_piecewise_compile
(
use_inductor
=
False
)
tests/compile/piecewise/test_toy_llama.py
View file @
26b4fa45
...
@@ -261,12 +261,14 @@ def tractable_computation(input_ids: torch.Tensor,
...
@@ -261,12 +261,14 @@ def tractable_computation(input_ids: torch.Tensor,
@
torch
.
inference_mode
@
torch
.
inference_mode
def
run_model
(
llama_config
,
def
run_model
(
llama_config
,
use_compile
:
bool
,
use_compile
:
bool
,
use_inductor
:
bool
,
split_attn
:
bool
=
False
)
->
torch
.
Tensor
:
split_attn
:
bool
=
False
)
->
torch
.
Tensor
:
if
use_compile
:
if
use_compile
:
compilation_config
=
CompilationConfig
(
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
level
=
CompilationLevel
.
PIECEWISE
,
use_cudagraph
=
True
,
use_cudagraph
=
True
,
use_inductor
=
use_inductor
,
cudagraph_capture_sizes
=
[
1
,
2
],
cudagraph_capture_sizes
=
[
1
,
2
],
)
)
if
split_attn
:
if
split_attn
:
...
@@ -304,7 +306,7 @@ def run_model(llama_config,
...
@@ -304,7 +306,7 @@ def run_model(llama_config,
return
output
.
cpu
()
return
output
.
cpu
()
def
test_toy_llama
():
def
_
test_toy_llama
(
*
,
use_inductor
):
# compare output with and without piecewise compilation
# compare output with and without piecewise compilation
llama_config
=
LlamaConfig
(
hidden_size
=
128
,
llama_config
=
LlamaConfig
(
hidden_size
=
128
,
...
@@ -326,8 +328,14 @@ def test_toy_llama():
...
@@ -326,8 +328,14 @@ def test_toy_llama():
num_backend_compilations
=
0
,
num_backend_compilations
=
0
,
num_cudagraph_caputured
=
0
,
num_cudagraph_caputured
=
0
,
):
):
outputs
.
append
(
run_model
(
llama_config
,
use_compile
=
False
))
outputs
.
append
(
run_model
(
tractable_config
,
use_compile
=
False
)
run_model
(
llama_config
,
use_inductor
=
False
,
use_compile
=
False
))
run_model
(
tractable_config
,
use_inductor
=
False
,
use_compile
=
False
)
if
use_inductor
:
kwargs
=
{
"num_inductor_compiles"
:
1
,
"num_eager_compiles"
:
0
}
else
:
kwargs
=
{
"num_eager_compiles"
:
1
,
"num_inductor_compiles"
:
0
}
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
# one graph for the model
num_graphs_seen
=
1
,
# one graph for the model
...
@@ -336,9 +344,13 @@ def test_toy_llama():
...
@@ -336,9 +344,13 @@ def test_toy_llama():
num_backend_compilations
=
1
,
# num_piecewise_capturable_graphs_seen
num_backend_compilations
=
1
,
# num_piecewise_capturable_graphs_seen
num_cudagraph_caputured
=
num_cudagraph_caputured
=
2
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
2
,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
**
kwargs
,
):
):
outputs
.
append
(
run_model
(
llama_config
,
use_compile
=
True
))
outputs
.
append
(
run_model
(
tractable_config
,
use_compile
=
True
)
run_model
(
llama_config
,
use_inductor
=
use_inductor
,
use_compile
=
True
))
run_model
(
tractable_config
,
use_inductor
=
use_inductor
,
use_compile
=
True
)
with
compilation_counter
.
expect
(
with
compilation_counter
.
expect
(
num_graphs_seen
=
1
,
# one graph for the model
num_graphs_seen
=
1
,
# one graph for the model
...
@@ -353,13 +365,27 @@ def test_toy_llama():
...
@@ -353,13 +365,27 @@ def test_toy_llama():
),
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
),
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
):
outputs
.
append
(
outputs
.
append
(
run_model
(
llama_config
,
use_compile
=
True
,
split_attn
=
True
))
run_model
(
llama_config
,
run_model
(
tractable_config
,
use_compile
=
True
,
split_attn
=
True
)
use_inductor
=
use_inductor
,
use_compile
=
True
,
split_attn
=
True
))
run_model
(
tractable_config
,
use_inductor
=
use_inductor
,
use_compile
=
True
,
split_attn
=
True
)
for
i
in
range
(
1
,
len
(
outputs
)):
for
i
in
range
(
1
,
len
(
outputs
)):
assert
torch
.
allclose
(
outputs
[
0
],
outputs
[
i
])
assert
torch
.
allclose
(
outputs
[
0
],
outputs
[
i
])
def
test_toy_llama_inductor
():
_test_toy_llama
(
use_inductor
=
True
)
def
test_toy_no_inductor
():
_test_toy_llama
(
use_inductor
=
False
)
@
torch
.
inference_mode
@
torch
.
inference_mode
def
benchmark
():
def
benchmark
():
from
triton.testing
import
do_bench
from
triton.testing
import
do_bench
...
...
vllm/compilation/compiler_interface.py
View file @
26b4fa45
...
@@ -12,6 +12,7 @@ import torch._inductor.compile_fx
...
@@ -12,6 +12,7 @@ import torch._inductor.compile_fx
import
torch.fx
as
fx
import
torch.fx
as
fx
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.compilation.counter
import
compilation_counter
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.utils
import
is_torch_equal_or_newer
from
vllm.utils
import
is_torch_equal_or_newer
...
@@ -175,6 +176,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
...
@@ -175,6 +176,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
runtime_shape
:
Optional
[
int
]
=
None
,
runtime_shape
:
Optional
[
int
]
=
None
,
key
:
Optional
[
str
]
=
None
,
key
:
Optional
[
str
]
=
None
,
)
->
tuple
[
Optional
[
Callable
],
Optional
[
Any
]]:
)
->
tuple
[
Optional
[
Callable
],
Optional
[
Any
]]:
compilation_counter
.
num_inductor_compiles
+=
1
current_config
=
{}
current_config
=
{}
if
compiler_config
is
not
None
:
if
compiler_config
is
not
None
:
current_config
.
update
(
compiler_config
)
current_config
.
update
(
compiler_config
)
...
@@ -262,6 +264,7 @@ class InductorAdaptor(CompilerInterface):
...
@@ -262,6 +264,7 @@ class InductorAdaptor(CompilerInterface):
runtime_shape
:
Optional
[
int
]
=
None
,
runtime_shape
:
Optional
[
int
]
=
None
,
key
:
Optional
[
str
]
=
None
,
key
:
Optional
[
str
]
=
None
,
)
->
tuple
[
Optional
[
Callable
],
Optional
[
Any
]]:
)
->
tuple
[
Optional
[
Callable
],
Optional
[
Any
]]:
compilation_counter
.
num_inductor_compiles
+=
1
from
torch._inductor.compile_fx
import
compile_fx
from
torch._inductor.compile_fx
import
compile_fx
current_config
=
{}
current_config
=
{}
if
compiler_config
is
not
None
:
if
compiler_config
is
not
None
:
...
@@ -528,6 +531,7 @@ class EagerAdaptor(CompilerInterface):
...
@@ -528,6 +531,7 @@ class EagerAdaptor(CompilerInterface):
runtime_shape
:
Optional
[
int
]
=
None
,
runtime_shape
:
Optional
[
int
]
=
None
,
key
:
Optional
[
str
]
=
None
,
key
:
Optional
[
str
]
=
None
,
)
->
tuple
[
Optional
[
Callable
],
Optional
[
Any
]]:
)
->
tuple
[
Optional
[
Callable
],
Optional
[
Any
]]:
compilation_counter
.
num_eager_compiles
+=
1
# we don't need to compile the graph, just return the graph itself.
# we don't need to compile the graph, just return the graph itself.
# It does not support caching, return None for the handle.
# It does not support caching, return None for the handle.
return
graph
,
None
return
graph
,
None
vllm/compilation/counter.py
View file @
26b4fa45
...
@@ -15,6 +15,10 @@ class CompilationCounter:
...
@@ -15,6 +15,10 @@ class CompilationCounter:
num_piecewise_capturable_graphs_seen
:
int
=
0
num_piecewise_capturable_graphs_seen
:
int
=
0
num_backend_compilations
:
int
=
0
num_backend_compilations
:
int
=
0
num_cudagraph_caputured
:
int
=
0
num_cudagraph_caputured
:
int
=
0
# InductorAdapter.compile calls
num_inductor_compiles
:
int
=
0
# EagerAdapter.compile calls
num_eager_compiles
:
int
=
0
def
clone
(
self
)
->
"CompilationCounter"
:
def
clone
(
self
)
->
"CompilationCounter"
:
return
copy
.
deepcopy
(
self
)
return
copy
.
deepcopy
(
self
)
...
...
vllm/config.py
View file @
26b4fa45
...
@@ -4315,15 +4315,10 @@ class VllmConfig:
...
@@ -4315,15 +4315,10 @@ class VllmConfig:
self
.
compilation_config
.
custom_ops
.
append
(
"+rms_norm"
)
self
.
compilation_config
.
custom_ops
.
append
(
"+rms_norm"
)
if
envs
.
VLLM_USE_V1
and
self
.
model_config
is
not
None
and
\
if
envs
.
VLLM_USE_V1
and
self
.
model_config
is
not
None
and
\
not
self
.
model_config
.
enforce_eager
:
not
self
.
model_config
.
enforce_eager
:
# NOTE(woosuk): Currently, we use inductor because the piecewise
# CUDA graphs do not work properly with the custom CUDA kernels.
# FIXME(woosuk): Disable inductor to reduce the compilation time
# and avoid any potential issues with the inductor.
# FIXME(rob): Add function to set all of these.
# FIXME(rob): Add function to set all of these.
if
not
self
.
compilation_config
.
custom_ops
:
if
not
self
.
compilation_config
.
custom_ops
:
self
.
compilation_config
.
custom_ops
=
[
"none"
]
self
.
compilation_config
.
custom_ops
=
[
"none"
]
self
.
compilation_config
.
use_cudagraph
=
True
self
.
compilation_config
.
use_cudagraph
=
True
self
.
compilation_config
.
use_inductor
=
True
self
.
compilation_config
.
cudagraph_num_of_warmups
=
1
self
.
compilation_config
.
cudagraph_num_of_warmups
=
1
self
.
compilation_config
.
pass_config
.
enable_fusion
=
False
self
.
compilation_config
.
pass_config
.
enable_fusion
=
False
self
.
compilation_config
.
pass_config
.
enable_noop
=
False
self
.
compilation_config
.
pass_config
.
enable_noop
=
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment