Unverified Commit 0c9f01a8 authored by LSinev, committed by GitHub

Speed up TopKLogitsWarper and TopPLogitsWarper (pytorch) (#9557)

* make TopKLogitsWarper faster

* make TopPLogitsWarper faster
parent 27d0e01d
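
A minimal micro-benchmark sketch of one way the speedup from the `masked_fill` change below could be checked; it is an illustration only, not part of the commit, and the tensor shape, top-k cutoff, and repetition count are assumptions:

```python
# Hypothetical micro-benchmark (not from this commit); shapes are assumptions
# chosen to resemble a (batch, vocab) logits tensor for a GPT-2-sized model.
import timeit
import torch

scores = torch.randn(8, 50257)
mask = scores < torch.topk(scores, 50)[0][..., -1, None]

def boolean_indexing():
    # old path: in-place assignment through a boolean mask (advanced indexing)
    s = scores.clone()
    s[mask] = -float("inf")
    return s

def fill():
    # new path: a single out-of-place masked_fill call
    return scores.masked_fill(mask, -float("inf"))

print("boolean indexing:", timeit.timeit(boolean_indexing, number=200))
print("masked_fill:     ", timeit.timeit(fill, number=200))
```

Boolean-mask assignment typically goes through advanced indexing (which has to locate the masked positions), while `masked_fill` runs as one elementwise kernel and also avoids mutating the caller's tensor in place.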
@@ -20,7 +20,6 @@ from typing import Callable, Iterable, List
 import numpy as np
 import torch
-from torch.nn import functional as F
 from .file_utils import add_start_docstrings
@@ -191,7 +190,7 @@ class TopPLogitsWarper(LogitsWarper):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         sorted_logits, sorted_indices = torch.sort(scores, descending=True)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

         # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
         sorted_indices_to_remove = cumulative_probs > self.top_p
@@ -204,7 +203,7 @@ class TopPLogitsWarper(LogitsWarper):
         # scatter sorted tensors to original indexing
         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-        scores[indices_to_remove] = self.filter_value
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
         return scores
@@ -233,7 +232,7 @@ class TopKLogitsWarper(LogitsWarper):
         top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1))  # Safety check
         # Remove all tokens with a probability less than the last token of the top-k
         indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
-        scores[indices_to_remove] = self.filter_value
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
         return scores
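
As a quick sanity check (a sketch, not part of the commit), the chained tensor methods compute the same values as the old functional form, and `masked_fill` returns a new tensor rather than writing into `scores`:

```python
import torch
from torch.nn import functional as F

logits = torch.tensor([[2.0, 0.5, -1.0, -3.0]])

# old and new cumulative-probability expressions agree numerically
old = torch.cumsum(F.softmax(logits, dim=-1), dim=-1)
new = logits.softmax(dim=-1).cumsum(dim=-1)
assert torch.allclose(old, new)

# masked_fill is out-of-place: the caller's logits are left untouched
mask = logits < 0
filled = logits.masked_fill(mask, -float("inf"))
assert not torch.isinf(logits).any()
assert torch.isinf(filled).sum() == 2
```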