Unverified Commit 0f9e7354 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[BugFix] Fix full-cuda-graph illegal memory access in FA3 (#20057)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent ba7ba35c
......@@ -158,12 +158,13 @@ class FlashAttentionMetadataBuilder(
self.aot_schedule = (get_flash_attn_version() == 3)
self.use_full_cuda_graph = compilation_config.full_cuda_graph
if self.use_full_cuda_graph and not self.aot_schedule:
raise ValueError("Full CUDA graph mode requires AOT scheduling, "
"which requires FlashAttention 3.")
self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
dtype=torch.int32,
device=self.runner.device)
if self.use_full_cuda_graph:
# NOTE(lucas): AOT scheduling not supported in full cuda graph mode
# yet. This is because the scheduler and kernel need to always use
# the same num_splits (which acts as an upper bound with the
# dynamic split scheduler) which is currently heuristically decided
# by the kernel launching code.
self.aot_schedule = False
# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
......@@ -299,18 +300,6 @@ class FlashAttentionMetadataBuilder(
max_seq_len=max_seq_len,
causal=True)
if self.use_full_cuda_graph:
assert scheduler_metadata is not None
n = scheduler_metadata.shape[0]
self.scheduler_metadata[:n].copy_(scheduler_metadata,
non_blocking=True)
# NOTE(woosuk): We should zero out the rest of the scheduler
# metadata to guarantee the correctness. Otherwise, some thread
# blocks may use the invalid scheduler metadata and overwrite the
# output buffer.
self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n]
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment