Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2338daff
Unverified
Commit
2338daff
authored
Sep 24, 2025
by
Lucas Wilkinson
Committed by
GitHub
Sep 24, 2025
Browse files
[BugFix] Potential Fix for FA3 full-cudagraph IMA (#25490)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
2e19a848
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
11 deletions
+11
-11
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+11
-11
No files found.
vllm/v1/attention/backends/flash_attn.py
View file @
2338daff
...
...
@@ -194,10 +194,9 @@ class FlashAttentionMetadataBuilder(
self
.
use_full_cuda_graph
=
\
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
if
self
.
use_full_cuda_graph
and
self
.
aot_schedule
:
self
.
max_cudagraph_size
=
self
.
compilation_config
.
max_capture_size
if
self
.
use_full_cuda_graph
and
self
.
aot_schedule
:
if
self
.
max_cudagraph_size
>
992
:
# This condition derives from FA3's internal heuristic.
# TODO(woosuk): Support larger cudagraph sizes.
...
...
@@ -259,6 +258,15 @@ class FlashAttentionMetadataBuilder(
self
.
aot_schedule
=
False
aot_schedule
=
False
max_num_splits
=
0
# 0 means use FA3's heuristics, not CG compatible
if
self
.
use_full_cuda_graph
and
\
num_actual_tokens
<=
self
.
max_cudagraph_size
:
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits
=
self
.
max_num_splits
def
schedule
(
batch_size
,
cu_query_lens
,
max_query_len
,
seqlens
,
max_seq_len
,
causal
):
cache_dtype
=
self
.
cache_config
.
cache_dtype
...
...
@@ -281,7 +289,7 @@ class FlashAttentionMetadataBuilder(
page_size
=
self
.
block_size
,
causal
=
causal
,
window_size
=
self
.
aot_sliding_window
,
num_splits
=
self
.
max_num_splits
,
num_splits
=
max_num_splits
,
)
return
None
...
...
@@ -322,7 +330,6 @@ class FlashAttentionMetadataBuilder(
max_seq_len
=
max_seq_len
,
causal
=
causal
)
# For FA3 + full cudagraph
max_num_splits
=
0
if
self
.
use_full_cuda_graph
and
scheduler_metadata
is
not
None
:
n
=
scheduler_metadata
.
shape
[
0
]
self
.
scheduler_metadata
[:
n
]
=
scheduler_metadata
...
...
@@ -333,13 +340,6 @@ class FlashAttentionMetadataBuilder(
self
.
scheduler_metadata
[
n
:]
=
0
scheduler_metadata
=
self
.
scheduler_metadata
[:
n
]
if
num_actual_tokens
<=
self
.
max_cudagraph_size
:
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits
=
self
.
max_num_splits
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
num_actual_tokens
,
max_query_len
=
max_query_len
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment