Commit 48a05026 authored by prajjwal1, committed by Julien Chaumond

Removed deprecated use of the Variable API from the PPLM example.
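For context: torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4, where tensors gained requires_grad directly and the volatile flag was dropped in favor of torch.no_grad(). A minimal sketch of the migration pattern this commit applies (the numpy array below is an illustrative stand-in for one grad_accumulator entry, not code from the example):

    import numpy as np
    import torch

    p_ = np.zeros((1, 4), dtype=np.float32)  # illustrative stand-in for a grad_accumulator entry

    # Deprecated pattern (what to_var wrapped): make a tensor differentiable via Variable.
    #   from torch.autograd import Variable
    #   v = Variable(torch.from_numpy(p_), requires_grad=True)

    # Replacement pattern used in this commit: flag the tensor itself, then move it.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    v = torch.from_numpy(p_).requires_grad_(True).to(device=device)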

parent 12d0eb5f
@@ -31,7 +31,6 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import torch
 import torch.nn.functional as F
-from torch.autograd import Variable
 from tqdm import trange
 from pplm_classification_head import ClassificationHead
@@ -76,14 +75,6 @@ DISCRIMINATOR_MODELS_PARAMS = {
 }
-def to_var(x, requires_grad=False, volatile=False, device="cuda"):
-    if torch.cuda.is_available() and device == "cuda":
-        x = x.cuda()
-    elif device != "cuda":
-        x = x.to(device)
-    return Variable(x, requires_grad=requires_grad, volatile=volatile)
 def top_k_filter(logits, k, probs=False):
     """
     Masks everything but the k top entries as -infinity (1e10).
@@ -156,9 +147,7 @@ def perturb_past(
     new_accumulated_hidden = None
     for i in range(num_iterations):
         print("Iteration ", i + 1)
-        curr_perturbation = [
-            to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator
-        ]
+        curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
         # Compute hidden using perturbed past
         perturbed_past = list(map(add, past, curr_perturbation))
@@ -247,7 +236,7 @@ def perturb_past(
     past = new_past
     # apply the accumulated perturbations to the past
-    grad_accumulator = [to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator]
+    grad_accumulator = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
     pert_past = list(map(add, past, grad_accumulator))
     return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
@@ -266,7 +255,7 @@ def get_classifier(
     elif "path" in params:
         resolved_archive_file = params["path"]
     else:
-        raise ValueError("Either url or path have to be specified " "in the discriminator model parameters")
+        raise ValueError("Either url or path have to be specified in the discriminator model parameters")
     classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device))
     classifier.eval()
@@ -569,9 +558,9 @@ def generate_text_pplm(
 def set_generic_model_params(discrim_weights, discrim_meta):
     if discrim_weights is None:
-        raise ValueError("When using a generic discriminator, " "discrim_weights need to be specified")
+        raise ValueError("When using a generic discriminator, discrim_weights need to be specified")
     if discrim_meta is None:
-        raise ValueError("When using a generic discriminator, " "discrim_meta need to be specified")
+        raise ValueError("When using a generic discriminator, discrim_meta need to be specified")
     with open(discrim_meta, "r") as discrim_meta_file:
         meta = json.load(discrim_meta_file)
@@ -619,7 +608,7 @@ def run_pplm_example(
     if discrim is not None:
         pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"]
-        print("discrim = {}, pretrained_model set " "to discriminator's = {}".format(discrim, pretrained_model))
+        print("discrim = {}, pretrained_model set to discriminator's = {}".format(discrim, pretrained_model))
     # load pretrained model
     model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
@@ -706,7 +695,7 @@ def run_pplm_example(
             for word_id in pert_gen_tok_text.tolist()[0]:
                 if word_id in bow_word_ids:
                     pert_gen_text += "{}{}{}".format(
-                        colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL
+                        colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL,
                     )
                 else:
                     pert_gen_text += tokenizer.decode([word_id])
@@ -744,9 +733,11 @@ if __name__ == "__main__":
         "-B",
         type=str,
         default=None,
-        help="Bags of words used for PPLM-BoW. "
-        "Either a BOW id (see list in code) or a filepath. "
-        "Multiple BoWs separated by ;",
+        help=(
+            "Bags of words used for PPLM-BoW. "
+            "Either a BOW id (see list in code) or a filepath. "
+            "Multiple BoWs separated by ;"
+        ),
     )
     parser.add_argument(
         "--discrim",
@@ -756,9 +747,11 @@ if __name__ == "__main__":
         choices=("clickbait", "sentiment", "toxicity", "generic"),
         help="Discriminator to use",
     )
-    parser.add_argument("--discrim_weights", type=str, default=None, help="Weights for the generic discriminator")
-    parser.add_argument(
-        "--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator"
-    )
+    parser.add_argument(
+        "--discrim_weights", type=str, default=None, help="Weights for the generic discriminator",
+    )
+    parser.add_argument(
+        "--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator",
+    )
     parser.add_argument(
         "--class_label", type=int, default=-1, help="Class label used for the discriminator",
@@ -774,7 +767,7 @@ if __name__ == "__main__":
         "--window_length",
         type=int,
         default=0,
-        help="Length of past which is being optimized; " "0 corresponds to infinite window length",
+        help="Length of past which is being optimized; 0 corresponds to infinite window length",
     )
     parser.add_argument(
         "--horizon_length", type=int, default=1, help="Length of future to optimize over",
...
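A side note on the replacement pattern (general PyTorch autograd behavior, not something this diff changes): calling .to(device=...) after .requires_grad_(True) returns a new, non-leaf tensor whenever it actually copies (e.g., CPU to CUDA), and autograd only populates .grad on leaf tensors by default, so the gradient lands on the pre-copy tensor. A small sketch of the distinction:

    import torch

    x = torch.zeros(3).requires_grad_(True)  # leaf tensor on CPU
    y = x.to(device="cuda") if torch.cuda.is_available() else x  # .to() copies across devices

    y.sum().backward()
    print(y.is_leaf)  # False when the copy happened: y.grad stays None by default
    print(x.grad)     # tensor([1., 1., 1.]): the gradient accumulates on the leaf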