Unverified Commit e42c634a authored by 盏一's avatar 盏一 Committed by GitHub
Browse files

[Core] simplify logits resort in _apply_top_k_top_p (#8619)

parent 9cc373f3
...@@ -433,12 +433,9 @@ def _apply_top_k_top_p( ...@@ -433,12 +433,9 @@ def _apply_top_k_top_p(
logits_sort.masked_fill_(top_p_mask, -float("inf")) logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities. # Re-sort the probabilities.
src = torch.arange(logits_idx.shape[-1], logits = torch.empty_like(logits_sort).scatter_(dim=-1,
device=logits_idx.device).expand_as(logits_idx)
logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
index=logits_idx, index=logits_idx,
src=src) src=logits_sort)
logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
return logits return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment