Unverified Commit 0c9f01a8 authored by LSinev's avatar LSinev Committed by GitHub
Browse files

Speed up TopKLogitsWarper and TopPLogitsWarper (pytorch) (#9557)

* make TopKLogitsWarper faster

* make TopPLogitsWarper faster
parent 27d0e01d
......@@ -20,7 +20,6 @@ from typing import Callable, Iterable, List
import numpy as np
import torch
from torch.nn import functional as F
from .file_utils import add_start_docstrings
......@@ -191,7 +190,7 @@ class TopPLogitsWarper(LogitsWarper):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > self.top_p
......@@ -204,7 +203,7 @@ class TopPLogitsWarper(LogitsWarper):
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores[indices_to_remove] = self.filter_value
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
......@@ -233,7 +232,7 @@ class TopKLogitsWarper(LogitsWarper):
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
scores[indices_to_remove] = self.filter_value
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment