Unverified Commit f150107e authored by Andrew Barnes's avatar Andrew Barnes Committed by GitHub
Browse files

[ROCm] Fix cu_seqlens_q off-by-one in AITER FA speculative decode path (#39120)


Signed-off-by: default avatarBortlesboat <bortstheboat@gmail.com>
parent d1135a50
......@@ -1181,7 +1181,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
)
descale_shape = (
attn_metadata.query_start_loc[:num_decodes].shape[0] - 1,
num_decodes,
key_cache.shape[2],
)
unified_attention(
......@@ -1189,7 +1189,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
k=key_cache,
v=value_cache,
out=output[:num_decode_tokens],
cu_seqlens_q=attn_metadata.query_start_loc[:num_decodes],
cu_seqlens_q=attn_metadata.query_start_loc[: num_decodes + 1],
max_seqlen_q=decode_max_query_len,
seqused_k=attn_metadata.seq_lens[:num_decodes],
max_seqlen_k=attn_metadata.max_seq_len,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment