Unverified Commit 29bdb883 authored by Binoy Dalal, committed by GitHub

Vectorize RepetitionPenaltyLogitsProcessor to improve performance (#8598)

* refactored existing nested loops into a vectorized implementation

* replaced explicit indexing with torch.where

* modifying score for previous input_ids only
parent 2594bd8b
@@ -146,13 +146,13 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
         self.penalty = penalty
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        for i in range(scores.shape[0]):
-            for previous_token in set(input_ids[i].tolist()):
-                # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
-                if scores[i, previous_token] < 0:
-                    scores[i, previous_token] *= self.penalty
-                else:
-                    scores[i, previous_token] /= self.penalty
+        ranges = torch.arange(scores.shape[0])
+        score = scores[ranges[:, None], input_ids]
+
+        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+
+        scores[ranges[:, None], input_ids] = score
         return scores
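
For context (not part of the commit), here is a minimal sketch checking that the vectorized form produces the same scores as the original nested loops. The penalty value and the batch, vocabulary, and sequence sizes below are arbitrary illustration values:

import torch

penalty = 1.2
batch_size, vocab_size, seq_len = 2, 50, 8  # arbitrary illustration values

scores = torch.randn(batch_size, vocab_size)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

# Reference: the pre-PR loop over each batch row and each previously seen token.
expected = scores.clone()
for i in range(expected.shape[0]):
    for previous_token in set(input_ids[i].tolist()):
        if expected[i, previous_token] < 0:
            expected[i, previous_token] *= penalty
        else:
            expected[i, previous_token] /= penalty

# Vectorized: gather the scores of all previous tokens at once, penalize them
# with torch.where, and write them back with advanced indexing. Duplicate token
# ids gather the same original score and so write back identical penalized
# values, matching the loop's apply-once-per-unique-token behavior.
result = scores.clone()
ranges = torch.arange(result.shape[0])
score = result[ranges[:, None], input_ids]
score = torch.where(score < 0, score * penalty, score / penalty)
result[ranges[:, None], input_ids] = score

assert torch.allclose(expected, result)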