Unverified Commit e64b39ea authored by Andrew Barnes's avatar Andrew Barnes Committed by GitHub
Browse files

[ROCm] Align AiterFlashAttentionImpl attn_type check with backend (#39119)


Signed-off-by: default avatarBortlesboat <bortstheboat@gmail.com>
parent 2faad083
...@@ -844,9 +844,13 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -844,9 +844,13 @@ class AiterFlashAttentionImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0 assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]: if attn_type != AttentionType.DECODER:
raise NotImplementedError( raise NotImplementedError(
"Encoder self-attention is not implemented for AiterFlashAttentionImpl" "Only decoder self-attention is supported for "
"AiterFlashAttentionImpl. ENCODER_DECODER is not supported "
"because the prefill path uses cu_seqlens_k set to decoder "
"query_start_loc with causal=True, which is incorrect for "
"cross-attention."
) )
def extend_for_sliding_window( def extend_for_sliding_window(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment