Commit 48a05026 authored by prajjwal1's avatar prajjwal1 Committed by Julien Chaumond
Browse files

removed deprecared use of Variable api from pplm example

parent 12d0eb5f
......@@ -31,7 +31,6 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import trange
from pplm_classification_head import ClassificationHead
......@@ -76,14 +75,6 @@ DISCRIMINATOR_MODELS_PARAMS = {
}
def to_var(x, requires_grad=False, volatile=False, device="cuda"):
if torch.cuda.is_available() and device == "cuda":
x = x.cuda()
elif device != "cuda":
x = x.to(device)
return Variable(x, requires_grad=requires_grad, volatile=volatile)
def top_k_filter(logits, k, probs=False):
"""
Masks everything but the k top entries as -infinity (1e10).
......@@ -156,9 +147,7 @@ def perturb_past(
new_accumulated_hidden = None
for i in range(num_iterations):
print("Iteration ", i + 1)
curr_perturbation = [
to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator
]
curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
# Compute hidden using perturbed past
perturbed_past = list(map(add, past, curr_perturbation))
......@@ -247,7 +236,7 @@ def perturb_past(
past = new_past
# apply the accumulated perturbations to the past
grad_accumulator = [to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator]
grad_accumulator = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
pert_past = list(map(add, past, grad_accumulator))
return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
......@@ -266,7 +255,7 @@ def get_classifier(
elif "path" in params:
resolved_archive_file = params["path"]
else:
raise ValueError("Either url or path have to be specified " "in the discriminator model parameters")
raise ValueError("Either url or path have to be specified in the discriminator model parameters")
classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device))
classifier.eval()
......@@ -569,9 +558,9 @@ def generate_text_pplm(
def set_generic_model_params(discrim_weights, discrim_meta):
if discrim_weights is None:
raise ValueError("When using a generic discriminator, " "discrim_weights need to be specified")
raise ValueError("When using a generic discriminator, discrim_weights need to be specified")
if discrim_meta is None:
raise ValueError("When using a generic discriminator, " "discrim_meta need to be specified")
raise ValueError("When using a generic discriminator, discrim_meta need to be specified")
with open(discrim_meta, "r") as discrim_meta_file:
meta = json.load(discrim_meta_file)
......@@ -619,7 +608,7 @@ def run_pplm_example(
if discrim is not None:
pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"]
print("discrim = {}, pretrained_model set " "to discriminator's = {}".format(discrim, pretrained_model))
print("discrim = {}, pretrained_model set to discriminator's = {}".format(discrim, pretrained_model))
# load pretrained model
model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
......@@ -706,7 +695,7 @@ def run_pplm_example(
for word_id in pert_gen_tok_text.tolist()[0]:
if word_id in bow_word_ids:
pert_gen_text += "{}{}{}".format(
colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL
colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL,
)
else:
pert_gen_text += tokenizer.decode([word_id])
......@@ -744,9 +733,11 @@ if __name__ == "__main__":
"-B",
type=str,
default=None,
help="Bags of words used for PPLM-BoW. "
"Either a BOW id (see list in code) or a filepath. "
"Multiple BoWs separated by ;",
help=(
"Bags of words used for PPLM-BoW. "
"Either a BOW id (see list in code) or a filepath. "
"Multiple BoWs separated by ;"
),
)
parser.add_argument(
"--discrim",
......@@ -756,9 +747,11 @@ if __name__ == "__main__":
choices=("clickbait", "sentiment", "toxicity", "generic"),
help="Discriminator to use",
)
parser.add_argument("--discrim_weights", type=str, default=None, help="Weights for the generic discriminator")
parser.add_argument(
"--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator"
"--discrim_weights", type=str, default=None, help="Weights for the generic discriminator",
)
parser.add_argument(
"--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator",
)
parser.add_argument(
"--class_label", type=int, default=-1, help="Class label used for the discriminator",
......@@ -774,7 +767,7 @@ if __name__ == "__main__":
"--window_length",
type=int,
default=0,
help="Length of past which is being optimized; " "0 corresponds to infinite window length",
help="Length of past which is being optimized; 0 corresponds to infinite window length",
)
parser.add_argument(
"--horizon_length", type=int, default=1, help="Length of future to optimize over",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment