"tests/vscode:/vscode.git/clone" did not exist on "8a3cd90af534c39425ebfdfd295eea0a4582d541"
Unverified Commit 4af9ed21 authored by zhao, zhenhui's avatar zhao, zhenhui Committed by GitHub
Browse files

[Bugfix](xpu): prevent “selected index k out of range” in TP decode path (#37259)


Signed-off-by: default avatarzhenzhao <zhenzhao@habana.ai>
parent 9c7cab5e
......@@ -426,7 +426,8 @@ class xpu_ops:
mask = positions <= index_end_pos
# mask: [B * N, L]
logits = logits.masked_fill(~mask, float("-inf"))
topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32) # [B * N, K]
real_topk = min(topk_tokens, logits.shape[-1])
topk_indices = logits.topk(real_topk, dim=-1)[1].to(torch.int32) # [B * N, K]
# ensure we don't set indices for the top k
# that is out of range(masked already)
# this will happen if context length is shorter than K
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment