Unverified Commit 9ea1f598 authored by Mengtao (Martin) Yuan's avatar Mengtao (Martin) Yuan Committed by GitHub
Browse files

Use paged_attention_v1 for sliding window decode in rocm_aiter_fa (#34378)


Signed-off-by: default avatarMartin Yuan <myuan@meta.com>
Co-authored-by: default avatarMartin Yuan <myuan@meta.com>
parent f120bd42
......@@ -1075,35 +1075,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), (
"Sliding window with shuffle layout is not supported yet."
)
from aiter.ops.triton.unified_attention import (
unified_attention,
)
descale_shape = (
attn_metadata.query_start_loc[:num_decodes].shape[0] - 1,
key_cache.shape[2],
)
unified_attention(
q=query[:num_decode_tokens],
k=key_cache,
v=value_cache,
out=output[:num_decode_tokens],
cu_seqlens_q=attn_metadata.query_start_loc[:num_decodes],
max_seqlen_q=1, # optimize this
seqused_k=attn_metadata.seq_lens[:num_decodes],
max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=attn_metadata.block_table[:num_decodes],
softcap=self.logits_soft_cap,
q_descale=None,
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return
assert attn_metadata.decode_metadata is not None
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
......@@ -1172,6 +1143,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
layer._v_scale,
None,
_PARTITION_SIZE_ROCM,
1,
self.sliding_window[0] + 1,
)
else:
raise NotImplementedError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment