Unverified Commit 4d042ed8 authored by Santino Ramos's avatar Santino Ramos Committed by GitHub
Browse files

[Bugfix] Fix tensor shape mismatch in sparse attention with speculative decoding (#39542)


Signed-off-by: default avatarSantino Ramos <santinor@inferact.ai>
parent 10d9872d
......@@ -213,7 +213,7 @@ def sparse_attn_indexer(
)
torch.ops._C.persistent_topk(
logits,
decode_metadata.seq_lens,
seq_lens,
topk_indices,
topk_workspace,
topk_tokens,
......@@ -250,7 +250,7 @@ def sparse_attn_indexer(
topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
decode_lens,
)
topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
topk_indices_buffer[: topk_indices.shape[0], : topk_indices.shape[-1]] = (
topk_indices
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment