Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8acb4bad
Unverified
Commit
8acb4bad
authored
Jul 01, 2025
by
Woosuk Kwon
Committed by
GitHub
Jul 01, 2025
Browse files
[CUDA graphs] Enable full cuda graphs with FA3 AoT scheduling (#20301)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
314af861
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
7 deletions
+54
-7
cmake/external_projects/vllm_flash_attn.cmake
cmake/external_projects/vllm_flash_attn.cmake
+1
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+53
-6
No files found.
cmake/external_projects/vllm_flash_attn.cmake
View file @
8acb4bad
...
@@ -38,7 +38,7 @@ else()
...
@@ -38,7 +38,7 @@ else()
FetchContent_Declare
(
FetchContent_Declare
(
vllm-flash-attn
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG
5f3644181c7a15345ce20bfc65af117d3601b52
4
GIT_TAG
1c2624e53c078854e0637ee566c72fe2107e75f
4
GIT_PROGRESS TRUE
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
# Don't share the vllm-flash-attn build between build types
BINARY_DIR
${
CMAKE_BINARY_DIR
}
/vllm-flash-attn
BINARY_DIR
${
CMAKE_BINARY_DIR
}
/vllm-flash-attn
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
8acb4bad
...
@@ -36,6 +36,9 @@ if TYPE_CHECKING:
...
@@ -36,6 +36,9 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
=
16
class
FlashAttentionBackend
(
AttentionBackend
):
class
FlashAttentionBackend
(
AttentionBackend
):
...
@@ -114,6 +117,7 @@ class FlashAttentionMetadata:
...
@@ -114,6 +117,7 @@ class FlashAttentionMetadata:
# Optional aot scheduling
# Optional aot scheduling
scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
prefix_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
prefix_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
max_num_splits
:
int
=
0
# for local attention
# for local attention
@
dataclass
@
dataclass
...
@@ -158,15 +162,35 @@ class FlashAttentionMetadataBuilder(
...
@@ -158,15 +162,35 @@ class FlashAttentionMetadataBuilder(
self
.
kv_cache_spec
=
kv_cache_spec
self
.
kv_cache_spec
=
kv_cache_spec
self
.
block_table
=
block_table
self
.
block_table
=
block_table
self
.
max_num_splits
=
0
# No upper bound on the number of splits.
self
.
aot_schedule
=
(
get_flash_attn_version
()
==
3
)
self
.
aot_schedule
=
(
get_flash_attn_version
()
==
3
)
self
.
use_full_cuda_graph
=
compilation_config
.
full_cuda_graph
self
.
use_full_cuda_graph
=
compilation_config
.
full_cuda_graph
if
self
.
use_full_cuda_graph
:
if
self
.
use_full_cuda_graph
:
# NOTE(lucas): AOT scheduling not supported in full cuda graph mode
if
not
self
.
aot_schedule
:
# yet. This is because the scheduler and kernel need to always use
raise
ValueError
(
# the same num_splits (which acts as an upper bound with the
"AoT scheduling is required for full cuda graph."
)
# dynamic split scheduler) which is currently heuristically decided
capture_sizes
=
compilation_config
.
cudagraph_capture_sizes
# by the kernel launching code.
if
not
capture_sizes
:
self
.
aot_schedule
=
False
raise
ValueError
(
"cudagraph_capture_sizes should not be None when "
"full_cuda_graph is True."
)
self
.
max_cudagraph_size
=
max
(
capture_sizes
)
if
self
.
max_cudagraph_size
>
992
:
# This condition derives from FA3's internal heuristic.
# TODO(woosuk): Support larger cudagraph sizes.
raise
ValueError
(
"Capture size larger than 992 is not supported for "
"full cuda graph."
)
self
.
scheduler_metadata
=
torch
.
zeros
(
self
.
runner
.
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
,
)
# When using cuda graph, we need to set the upper bound of the
# number of splits so that large enough intermediate buffers are
# pre-allocated during capture.
self
.
max_num_splits
=
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
# Sliding window size to be used with the AOT scheduler will be
# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
# populated on first build() call.
...
@@ -226,6 +250,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -226,6 +250,7 @@ class FlashAttentionMetadataBuilder(
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_q
=
cu_query_lens
,
causal
=
causal
,
causal
=
causal
,
window_size
=
self
.
aot_sliding_window
,
window_size
=
self
.
aot_sliding_window
,
num_splits
=
self
.
max_num_splits
,
)
)
return
None
return
None
...
@@ -302,6 +327,26 @@ class FlashAttentionMetadataBuilder(
...
@@ -302,6 +327,26 @@ class FlashAttentionMetadataBuilder(
max_seq_len
=
max_seq_len
,
max_seq_len
=
max_seq_len
,
causal
=
True
)
causal
=
True
)
if
self
.
use_full_cuda_graph
:
assert
scheduler_metadata
is
not
None
n
=
scheduler_metadata
.
shape
[
0
]
self
.
scheduler_metadata
[:
n
]
=
scheduler_metadata
# NOTE(woosuk): We should zero out the rest of the scheduler
# metadata to guarantee the correctness. Otherwise, some thread
# blocks may use the invalid scheduler metadata and overwrite the
# output buffer.
self
.
scheduler_metadata
[
n
:]
=
0
scheduler_metadata
=
self
.
scheduler_metadata
[:
n
]
max_num_splits
=
0
if
(
self
.
use_full_cuda_graph
and
num_actual_tokens
<=
self
.
max_cudagraph_size
):
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits
=
self
.
max_num_splits
attn_metadata
=
FlashAttentionMetadata
(
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
num_actual_tokens
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
...
@@ -318,6 +363,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -318,6 +363,7 @@ class FlashAttentionMetadataBuilder(
suffix_kv_lens
=
suffix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
local_attn_metadata
=
local_attn_metadata
,
local_attn_metadata
=
local_attn_metadata
,
prefix_scheduler_metadata
=
prefix_scheduler_metadata
,
prefix_scheduler_metadata
=
prefix_scheduler_metadata
,
max_num_splits
=
max_num_splits
,
)
)
return
attn_metadata
return
attn_metadata
...
@@ -510,6 +556,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -510,6 +556,7 @@ class FlashAttentionImpl(AttentionImpl):
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
num_splits
=
attn_metadata
.
max_num_splits
,
)
)
return
output
return
output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment