Unverified Commit 29bdb883 authored by Binoy Dalal's avatar Binoy Dalal Committed by GitHub
Browse files

Vectorize RepetitionPenaltyLogitsProcessor to improve performance (#8598)

* refactored exisiting nested loops to vectorized implementation

* replaced explicit indexing with torch.where

* modifying score for previous input_ids only
parent 2594bd8b
......@@ -146,13 +146,13 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
self.penalty = penalty
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
for i in range(scores.shape[0]):
for previous_token in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
if scores[i, previous_token] < 0:
scores[i, previous_token] *= self.penalty
else:
scores[i, previous_token] /= self.penalty
ranges = torch.arange(scores.shape[0])
score = scores[ranges[:, None], input_ids]
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
scores[ranges[:, None], input_ids] = score
return scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment