Commit a3085020 authored by IWillPull, committed by Julien Chaumond

Added repetition penalty to PPLM example (#2436)

* Added repetition penalty

* Default PPLM repetition_penalty to neutral

* Minor modifications to comply with reviewer's suggestions. (j -> token_idx)

* Formatted code with `make style`
parent e83d9f1c
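The commit threads a `repetition_penalty` argument through `full_text_generation`, `generate_text_pplm`, `run_pplm_example` and the argument parser, and applies it to the perturbed logits before sampling: every token that already appears in the generated output has its logit scaled so it becomes less likely to be chosen again (divide positive logits, multiply negative ones). A minimal standalone sketch of that step, using the same logic as the added loop; the helper name `apply_repetition_penalty` and the example tensors are only illustrative:

```python
import torch

def apply_repetition_penalty(logits, generated_ids, repetition_penalty=1.0):
    # logits: tensor of shape (1, vocab_size); generated_ids: 1-D tensor of token ids.
    # For a penalty > 1.0, repeated tokens always end up with a smaller logit:
    # negative logits are multiplied (more negative), positive ones are divided.
    for token_idx in set(generated_ids.tolist()):
        if logits[0, token_idx] < 0:
            logits[0, token_idx] *= repetition_penalty
        else:
            logits[0, token_idx] /= repetition_penalty
    return logits

# Example: token 42 has already been generated twice, so it gets penalized.
logits = torch.randn(1, 50257)
generated = torch.tensor([15, 42, 42, 7])
logits = apply_repetition_penalty(logits, generated, repetition_penalty=1.2)
```

With `repetition_penalty=1.0` (the default) the loop is a no-op, so existing behaviour of the example is unchanged.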
@@ -344,6 +344,7 @@ def full_text_generation(
gamma=1.5,
gm_scale=0.9,
kl_scale=0.01,
repetition_penalty=1.0,
**kwargs
):
classifier, class_id = get_classifier(discrim, class_label, device)
@@ -368,7 +369,14 @@ def full_text_generation(
raise Exception("Specify either a bag of words or a discriminator")
unpert_gen_tok_text, _, _ = generate_text_pplm(
model=model, tokenizer=tokenizer, context=context, device=device, length=length, sample=sample, perturb=False
model=model,
tokenizer=tokenizer,
context=context,
device=device,
length=length,
sample=sample,
perturb=False,
repetition_penalty=repetition_penalty,
)
if device == "cuda":
torch.cuda.empty_cache()
@@ -401,6 +409,7 @@ def full_text_generation(
gamma=gamma,
gm_scale=gm_scale,
kl_scale=kl_scale,
repetition_penalty=repetition_penalty,
)
pert_gen_tok_texts.append(pert_gen_tok_text)
if classifier is not None:
@@ -437,6 +446,7 @@ def generate_text_pplm(
gamma=1.5,
gm_scale=0.9,
kl_scale=0.01,
repetition_penalty=1.0,
):
output_so_far = None
if context:
@@ -508,6 +518,13 @@ def generate_text_pplm(
pert_logits, past, pert_all_hidden = model(last, past=pert_past)
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST
for token_idx in set(output_so_far[0].tolist()):
    if pert_logits[0, token_idx] < 0:
        pert_logits[0, token_idx] *= repetition_penalty
    else:
        pert_logits[0, token_idx] /= repetition_penalty
pert_probs = F.softmax(pert_logits, dim=-1)
if classifier is not None:
@@ -588,6 +605,7 @@ def run_pplm_example(
seed=0,
no_cuda=False,
colorama=False,
repetition_penalty=1.0,
):
# set Random seed
torch.manual_seed(seed)
@@ -655,6 +673,7 @@ def run_pplm_example(
gamma=gamma,
gm_scale=gm_scale,
kl_scale=kl_scale,
repetition_penalty=repetition_penalty,
)
# untokenize unperturbed text
@@ -767,6 +786,9 @@ if __name__ == "__main__":
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
parser.add_argument("--colorama", action="store_true", help="colors keywords")
parser.add_argument(
"--repetition_penalty", type=float, default=1.0, help="Penalize repetition. More than 1.0 -> less repetition",
)
args = parser.parse_args()
run_pplm_example(**vars(args))
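With the flag wired through argparse, the penalty can be set from the command line. A hypothetical invocation, where `-B`, `--cond_text` and `--length` are assumed to be the existing PPLM example options and are shown only for illustration: `python run_pplm.py -B military --cond_text "The potato" --length 50 --repetition_penalty 1.2`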