Unverified Commit 4af9ed21 authored by zhao, zhenhui's avatar zhao, zhenhui Committed by GitHub
Browse files

[Bugfix](xpu): prevent “selected index k out of range” in TP decode path (#37259)


Signed-off-by: default avatarzhenzhao <zhenzhao@habana.ai>
parent 9c7cab5e
...@@ -426,7 +426,8 @@ class xpu_ops: ...@@ -426,7 +426,8 @@ class xpu_ops:
mask = positions <= index_end_pos mask = positions <= index_end_pos
# mask: [B * N, L] # mask: [B * N, L]
logits = logits.masked_fill(~mask, float("-inf")) logits = logits.masked_fill(~mask, float("-inf"))
topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32) # [B * N, K] real_topk = min(topk_tokens, logits.shape[-1])
topk_indices = logits.topk(real_topk, dim=-1)[1].to(torch.int32) # [B * N, K]
# ensure we don't set indices for the top k # ensure we don't set indices for the top k
# that is out of range(masked already) # that is out of range(masked already)
# this will happen if context length is shorter than K # this will happen if context length is shorter than K
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment