Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
71b02b7a
Commit
71b02b7a
authored
Aug 08, 2025
by
zhuwenwen
Browse files
update fa full_cuda_graph support
parent
ffe9e7db
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
8 deletions
+9
-8
vllm/config.py
vllm/config.py
+1
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+7
-6
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-1
No files found.
vllm/config.py
View file @
71b02b7a
...
...
@@ -4135,7 +4135,7 @@ class CompilationConfig:
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False."""
full_cuda_graph
:
bool
=
Fals
e
full_cuda_graph
:
bool
=
Tru
e
"""whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
71b02b7a
...
...
@@ -151,7 +151,7 @@ def _get_sliding_window_configs(
class
FlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
full_cudagraph_supported
:
ClassVar
[
bool
]
=
get_flash_attn_version
()
==
3
full_cudagraph_supported
:
ClassVar
[
bool
]
=
get_flash_attn_version
()
==
3
or
current_platform
.
is_rocm
()
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
...
...
@@ -172,7 +172,8 @@ class FlashAttentionMetadataBuilder(
self
.
max_num_splits
=
0
# No upper bound on the number of splits.
self
.
aot_schedule
=
(
get_flash_attn_version
()
==
3
)
self
.
use_full_cuda_graph
=
self
.
compilation_config
.
full_cuda_graph
if
not
current_platform
.
is_rocm
()
and
self
.
use_full_cuda_graph
:
if
self
.
use_full_cuda_graph
:
if
not
current_platform
.
is_rocm
():
if
not
self
.
aot_schedule
:
raise
ValueError
(
"AoT scheduling is required for full cuda graph."
)
...
...
@@ -325,7 +326,7 @@ class FlashAttentionMetadataBuilder(
scheduler_metadata
=
self
.
scheduler_metadata
[:
n
]
max_num_splits
=
0
if
(
not
current_platform
.
is_rocm
()
and
self
.
use_full_cuda_graph
if
(
self
.
use_full_cuda_graph
and
num_actual_tokens
<=
self
.
max_cudagraph_size
):
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
71b02b7a
...
...
@@ -2548,7 +2548,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
device
,
)
if
(
not
current_platform
.
is_rocm
()
and
self
.
full_cuda_graph
if
(
self
.
full_cuda_graph
and
not
attn_metadata_builder_i
.
full_cudagraph_supported
):
raise
ValueError
(
f
"Full CUDAGraph not supported for "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment