Unverified Commit e42c634a authored by 盏一's avatar 盏一 Committed by GitHub
Browse files

[Core] simplify logits resort in _apply_top_k_top_p (#8619)

parent 9cc373f3
......@@ -433,12 +433,9 @@ def _apply_top_k_top_p(
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities.
src = torch.arange(logits_idx.shape[-1],
device=logits_idx.device).expand_as(logits_idx)
logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
logits = torch.empty_like(logits_sort).scatter_(dim=-1,
index=logits_idx,
src=src)
logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
src=logits_sort)
return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment