"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "72f9235a491e7800b3a7686e4901729d371dabed"
Unverified Commit aeef4823 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #2303 from patrickvonplaten/fix_error_with_repetition_penalty

fix repetition penalty error in modeling_utils.py
parents 0412f3d9 18e5bdbe
...@@ -728,7 +728,11 @@ class PreTrainedModel(nn.Module): ...@@ -728,7 +728,11 @@ class PreTrainedModel(nn.Module):
if repetition_penalty != 1.0: if repetition_penalty != 1.0:
for i in range(batch_size): for i in range(batch_size):
for previous_tokens in set(input_ids[i].tolist()): for previous_tokens in set(input_ids[i].tolist()):
next_token_logits[i, previous_tokens] /= repetition_penalty # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if next_token_logits[i, previous_tokens] < 0:
next_token_logits[i, previous_tokens] *= repetition_penalty
else:
next_token_logits[i, previous_tokens] /= repetition_penalty
if do_sample: if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens) # Temperature (higher temperature => more likely to sample low probability tokens)
...@@ -807,7 +811,11 @@ class PreTrainedModel(nn.Module): ...@@ -807,7 +811,11 @@ class PreTrainedModel(nn.Module):
if repetition_penalty != 1.0: if repetition_penalty != 1.0:
for i in range(batch_size * num_beams): for i in range(batch_size * num_beams):
for previous_tokens in set(input_ids[i].tolist()): for previous_tokens in set(input_ids[i].tolist()):
scores[i, previous_tokens] /= repetition_penalty # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if scores[i, previous_tokens] < 0:
scores[i, previous_tokens] *= repetition_penalty
else:
scores[i, previous_tokens] /= repetition_penalty
if do_sample: if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens) # Temperature (higher temperature => more likely to sample low probability tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment