Commit ecd15667 authored by leo-du's avatar leo-du Committed by Lysandre Debut
Browse files

fix repetition penalty

parent c5441946
...@@ -139,7 +139,7 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k= ...@@ -139,7 +139,7 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.) next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.)
# reptition penalty from CTRL (https://arxiv.org/abs/1909.05858) # reptition penalty from CTRL (https://arxiv.org/abs/1909.05858)
for _ in set(generated): for _ in set(generated.view(-1).tolist()):
next_token_logits[_] /= repetition_penalty next_token_logits[_] /= repetition_penalty
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment