Unverified commit 1a971808, authored by Mengtao (Martin) Yuan, committed by GitHub
Browse files

Fix CUDA graph decode capture crash in AITER FlashAttention (#36042)


Signed-off-by: Martin Yuan <myuan@meta.com>
Co-authored-by: Martin Yuan <myuan@meta.com>
parent 7eb524e6
......@@ -1152,11 +1152,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
 decode_max_query_len = attn_metadata.decode_metadata.max_query_len
 # Use unified_attention for speculative decoding (multi-token)
-# or when sliding window is enabled
-if self.sliding_window[0] != -1 or decode_max_query_len > 1:
+if decode_max_query_len > 1:
 assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), (
-"Shuffle KV cache layout is not supported with sliding "
-"window or speculative decoding (multi-token decode)."
+"Shuffle KV cache layout is not supported with "
+"speculative decoding (multi-token decode)."
 )
 from aiter.ops.triton.unified_attention import (
 unified_attention,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment