Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0f9e7354
Unverified
Commit
0f9e7354
authored
Jun 25, 2025
by
Lucas Wilkinson
Committed by
GitHub
Jun 25, 2025
Browse files
[BugFix] Fix full-cuda-graph illegal memory access in FA3 (#20057)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
ba7ba35c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
18 deletions
+7
-18
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+7
-18
No files found.
vllm/v1/attention/backends/flash_attn.py
View file @
0f9e7354
...
@@ -158,12 +158,13 @@ class FlashAttentionMetadataBuilder(
...
@@ -158,12 +158,13 @@ class FlashAttentionMetadataBuilder(
self
.
aot_schedule
=
(
get_flash_attn_version
()
==
3
)
self
.
aot_schedule
=
(
get_flash_attn_version
()
==
3
)
self
.
use_full_cuda_graph
=
compilation_config
.
full_cuda_graph
self
.
use_full_cuda_graph
=
compilation_config
.
full_cuda_graph
if
self
.
use_full_cuda_graph
and
not
self
.
aot_schedule
:
if
self
.
use_full_cuda_graph
:
raise
ValueError
(
"Full CUDA graph mode requires AOT scheduling, "
# NOTE(lucas): AOT scheduling not supported in full cuda graph mode
"which requires FlashAttention 3."
)
# yet. This is because the scheduler and kernel need to always use
self
.
scheduler_metadata
=
torch
.
zeros
(
self
.
runner
.
max_num_reqs
+
1
,
# the same num_splits (which acts as an upper bound with the
dtype
=
torch
.
int32
,
# dynamic split scheduler) which is currently heuristically decided
device
=
self
.
runner
.
device
)
# by the kernel launching code.
self
.
aot_schedule
=
False
# Sliding window size to be used with the AOT scheduler will be
# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
# populated on first build() call.
...
@@ -299,18 +300,6 @@ class FlashAttentionMetadataBuilder(
...
@@ -299,18 +300,6 @@ class FlashAttentionMetadataBuilder(
max_seq_len
=
max_seq_len
,
max_seq_len
=
max_seq_len
,
causal
=
True
)
causal
=
True
)
if
self
.
use_full_cuda_graph
:
assert
scheduler_metadata
is
not
None
n
=
scheduler_metadata
.
shape
[
0
]
self
.
scheduler_metadata
[:
n
].
copy_
(
scheduler_metadata
,
non_blocking
=
True
)
# NOTE(woosuk): We should zero out the rest of the scheduler
# metadata to guarantee the correctness. Otherwise, some thread
# blocks may use the invalid scheduler metadata and overwrite the
# output buffer.
self
.
scheduler_metadata
[
n
:]
=
0
scheduler_metadata
=
self
.
scheduler_metadata
[:
n
]
attn_metadata
=
FlashAttentionMetadata
(
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
num_actual_tokens
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment