Unverified commit a98173cc, authored by LSinev and committed by GitHub
Browse files

make RepetitionPenaltyLogitsProcessor faster (#9600)

parent a1ad16a4
@@ -155,13 +155,12 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
         self.penalty = penalty
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        ranges = torch.arange(scores.shape[0])
-        score = scores[ranges[:, None], input_ids]
+        score = torch.gather(scores, 1, input_ids)
         # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
         score = torch.where(score < 0, score * self.penalty, score / self.penalty)
-        scores[ranges[:, None], input_ids] = score
+        scores.scatter_(1, input_ids, score)
         return scores
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment