"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "36ea8784a875bde21c88f84dfb99475b6e8187e8"
Commit 34a83faa authored by Piero Molino, committed by Julien Chaumond

Let's make PPLM great again

parent d5faa74c
#! /usr/bin/env python3
# coding=utf-8
# Copyright 2018 The Uber AI Team Authors.
#
@@ -37,10 +38,12 @@ from transformers import GPT2Tokenizer
from transformers.file_utils import cached_path
from transformers.modeling_gpt2 import GPT2LMHeadModel

PPLM_BOW = 1
PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3
SMALL_CONST = 1e-15
SmallConst = 1e-15

TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium")

BAG_OF_WORDS_ARCHIVE_MAP = {
@@ -65,7 +68,7 @@ DISCRIMINATOR_MODELS_PARAMS = {
"default_class": 1, "default_class": 1,
}, },
"sentiment": { "sentiment": {
"url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/sentiment_classifierhead.pt", "url": "http://s.yosinski.com/SST_classifier_head.pt",
"class_size": 5, "class_size": 5,
"embed_size": 1024, "embed_size": 1024,
"class_vocab": {"very_positive": 2, "very_negative": 3}, "class_vocab": {"very_positive": 2, "very_negative": 3},
@@ -81,6 +84,30 @@ DISCRIMINATOR_MODELS_PARAMS = {
}
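Aside (not part of this commit): each entry in DISCRIMINATOR_MODELS_PARAMS is consumed by get_classifier further down, roughly as in the sketch below; the CPU device here is only for illustration.

# Illustrative sketch of how a DISCRIMINATOR_MODELS_PARAMS entry is used
# (mirrors get_classifier below; device choice is arbitrary).
params = DISCRIMINATOR_MODELS_PARAMS["sentiment"]
classifier = ClassificationHead(
    class_size=params["class_size"], embed_size=params["embed_size"]
).to("cpu")
weights_file = cached_path(params["url"])   # download + cache the head weights
classifier.load_state_dict(torch.load(weights_file, map_location="cpu"))
classifier.eval()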
def to_var(x, requires_grad=False, volatile=False):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, requires_grad=requires_grad, volatile=volatile)


def top_k_filter(logits, k, probs=False):
    """
    Masks everything but the k top entries as -infinity (1e10).
    Used to mask logits such that e^-infinity -> 0 won't contribute to the
    sum of the denominator.
    """
    if k == 0:
        return logits
    else:
        values = torch.topk(logits, k)[0]
        batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
        if probs:
            return torch.where(logits < batch_mins,
                               torch.ones_like(logits) * 0.0, logits)
        return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10,
                           logits)
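Aside (not part of this commit): a minimal illustration of what top_k_filter does to a toy batch of logits.

# Illustrative only: keep the two largest logits per row, push the rest to
# -1e10 so they vanish after a softmax (assumes top_k_filter defined above).
import torch

toy_logits = torch.tensor([[2.0, 0.5, -1.0, 3.0],
                           [0.1, 0.2, 0.3, 0.4]])
filtered = top_k_filter(toy_logits, k=2)
print(torch.softmax(filtered, dim=-1))  # only two non-zero entries per row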
class ClassificationHead(torch.nn.Module):
    """ Classification Head for the transformer """

@@ -99,234 +126,175 @@ class ClassificationHead(torch.nn.Module):
        return logits
def perturb_past(past, model, prev, args, classifier, good_index=None,
                 stepsize=0.01, vocab_size=50257,
                 original_probs=None, accumulated_hidden=None, true_past=None,
                 grad_norms=None):
    window_length = args.window_length
    gm_scale, kl_scale = args.gm_scale, args.kl_scale

    one_hot_vectors = []
    for good_list in good_index:
        good_list = list(filter(lambda x: len(x) <= 1, good_list))
        good_list = torch.tensor(good_list).cuda()
        num_good = good_list.shape[0]
        one_hot_good = torch.zeros(num_good, vocab_size).cuda()
        one_hot_good.scatter_(1, good_list, 1)
        one_hot_vectors.append(one_hot_good)

    # Generate initial perturbed past
    past_perturb_orig = [
        (np.random.uniform(0.0, 0.0, p.shape).astype('float32'))
        for p in past]

    if accumulated_hidden is None:
        accumulated_hidden = 0

    if args.decay:
        decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / (window_length))[
                     1:]
    else:
        decay_mask = 1.0

    # Generate a mask if the gradient perturbation is restricted to a past window
    _, _, _, current_length, _ = past[0].shape

    if current_length > window_length and window_length > 0:
        ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple(
            [window_length]) + tuple(
            past[0].shape[-1:])

        zeros_key_val_shape = tuple(past[0].shape[:-2]) + tuple(
            [current_length - window_length]) + tuple(
            past[0].shape[-1:])

        ones_mask = torch.ones(ones_key_val_shape)
        ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
        ones_mask = ones_mask.permute(0, 1, 2, 4, 3)

        window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)),
                                dim=-2).cuda()
    else:
        window_mask = torch.ones_like(past[0]).cuda()

    loss_per_iter = []
    for i in range(args.num_iterations):
        print("Iteration ", i + 1)
        past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
        past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]

        perturbed_past = list(map(add, past, past_perturb))

        _, _, _, current_length, _ = past_perturb[0].shape

        # _, future_past = model(prev, past=perturbed_past)
        # hidden = model.hidden_states

        # Piero modified model call
        logits, _, all_hidden = model(prev, past=perturbed_past)
        hidden = all_hidden[-1]
        new_accumulated_hidden = accumulated_hidden + torch.sum(hidden,
                                                                dim=1).detach()
        # TODO: Check the layer-norm consistency of this with trained discriminator
        logits = logits[:, -1, :]
        probabs = F.softmax(logits, dim=-1)

        loss = 0.0
        loss_list = []
        if args.loss_type == 1 or args.loss_type == 3:
            for one_hot_good in one_hot_vectors:
                good_logits = torch.mm(probabs, torch.t(one_hot_good))
                loss_word = good_logits
                loss_word = torch.sum(loss_word)
                loss_word = -torch.log(loss_word)
                # loss_word = torch.sum(loss_word) / torch.sum(one_hot_good)
                loss += loss_word
                loss_list.append(loss_word)
            print(" pplm_bow_loss:", loss.data.cpu().numpy())

        if args.loss_type == 2 or args.loss_type == 3:
            ce_loss = torch.nn.CrossEntropyLoss()
            new_true_past = true_past
            for i in range(args.horizon_length):
                future_probabs = F.softmax(logits, dim=-1)  # Get softmax
                future_probabs = torch.unsqueeze(future_probabs, dim=1)

                # _, new_true_past = model(future_probabs, past=new_true_past)
                # future_hidden = model.hidden_states  # Get expected hidden states

                # Piero modified model call
                wte = model.resize_token_embeddings()
                inputs_embeds = torch.matmul(future_probabs, wte.weight.data)
                _, new_true_past, future_hidden = model(
                    past=new_true_past,
                    inputs_embeds=inputs_embeds
                )
                future_hidden = future_hidden[-1]

                new_accumulated_hidden = new_accumulated_hidden + torch.sum(
                    future_hidden, dim=1)

            predicted_sentiment = classifier(new_accumulated_hidden / (
                    current_length + 1 + args.horizon_length))

            label = torch.tensor([args.label_class], device='cuda',
                                 dtype=torch.long)
            discrim_loss = ce_loss(predicted_sentiment, label)
            print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
            loss += discrim_loss
            loss_list.append(discrim_loss)

        kl_loss = 0.0
        if kl_scale > 0.0:
            p = (F.softmax(original_probs[:, -1, :], dim=-1))
            p = p + SmallConst * (p <= SmallConst).type(
                torch.FloatTensor).cuda().detach()
            correction = SmallConst * (probabs <= SmallConst).type(
                torch.FloatTensor).cuda().detach()
            corrected_probabs = probabs + correction.detach()
            kl_loss = kl_scale * (
                (corrected_probabs * (corrected_probabs / p).log()).sum())
            print(' kl_loss', (kl_loss).data.cpu().numpy())
        loss += kl_loss  # + discrim_loss

        loss_per_iter.append(loss.data.cpu().numpy())
        print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())

        loss.backward()
        if grad_norms is not None and args.loss_type == 1:
            grad_norms = [
                torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                for index, p_ in
                enumerate(past_perturb)]
        else:
            grad_norms = [(torch.norm(p_.grad * window_mask) + SmallConst) for
                          index, p_ in enumerate(past_perturb)]

        grad = [
            -stepsize * (p_.grad * window_mask / grad_norms[
                index] ** args.gamma).data.cpu().numpy()
            for index, p_ in enumerate(past_perturb)]
        past_perturb_orig = list(map(add, grad, past_perturb_orig))

        for p_ in past_perturb:
            p_.grad.data.zero_()

        new_past = []
        for p in past:
            new_past.append(p.detach())
        past = new_past

    past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
    past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
    perturbed_past = list(map(add, past, past_perturb))

    return perturbed_past, new_accumulated_hidden, grad_norms, loss_per_iter
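Aside (not part of this commit): the update that perturb_past accumulates is a normalized gradient step on the perturbation, delta <- delta - stepsize * grad / ||grad||**gamma. A self-contained toy sketch of that single step, with made-up tensors standing in for the model's past:

# Toy sketch of the normalized gradient step above (illustrative values only).
import torch

stepsize, gamma = 0.02, 1.5
past = [torch.randn(2, 4)]                                  # stand-in "past"
delta = [torch.zeros_like(p, requires_grad=True) for p in past]

loss = ((past[0] + delta[0]) ** 2).sum()                    # dummy attribute loss
loss.backward()

grad_norms = [torch.norm(d.grad) + 1e-15 for d in delta]
step = [-stepsize * d.grad / (grad_norms[i] ** gamma) for i, d in enumerate(delta)]
delta = [d.detach() + s for d, s in zip(delta, step)]
perturbed_past = [p + d for p, d in zip(past, delta)]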
def get_classifier(
        name: Optional[str], label_class: Union[str, int],
        device: Union[str, torch.device]
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    if name is None:
        return None, None
@@ -337,7 +305,8 @@ def get_classifier(
        embed_size=params['embed_size']
    ).to(device)
    resolved_archive_file = cached_path(params["url"])
    classifier.load_state_dict(
        torch.load(resolved_archive_file, map_location=device))
    classifier.eval()

    if isinstance(label_class, str):
@@ -364,7 +333,8 @@ def get_classifier(
    return classifier, label_id
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
    List[List[int]]]:
    bow_indices = []
    for id_or_path in bag_of_words_ids_or_paths:
        if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
@@ -372,8 +342,10 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
        else:
            filepath = id_or_path
        with open(filepath, "r") as f:
            words = f.read().strip().split("\n")
        bow_indices.append(
            [TOKENIZER.encode(word.strip(), add_prefix_space=True) for word in
             words])
    return bow_indices
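Aside (not part of this commit): each bag-of-words file is read as one word per line and every word is encoded with a leading space; latent_perturb below then keeps only words that map to a single GPT-2 token before building one-hot bag vectors. Roughly:

# Illustrative only: encoding an in-memory word list the same way as above.
toy_words = ["science", "laboratory", "experiment"]
toy_indices = [TOKENIZER.encode(word, add_prefix_space=True) for word in toy_words]
# One list of token ids per word; multi-token words are dropped later via
# the len(x) <= 1 filter before the one-hot bag vectors are built.
single_token_ids = [ids for ids in toy_indices if len(ids) <= 1]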
@@ -392,327 +364,308 @@ def build_bows_one_hot_vectors(bow_indices):
    return one_hot_bows_vectors
def latent_perturb(model, args, context=None, sample=True, device='cuda'):
    classifier, class_id = get_classifier(
        args.discrim,
        args.label_class,
        device
    )

    # if args.discrim == 'clickbait':
    #     classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
    #     classifier.load_state_dict(torch.load("discrim_models/clickbait_classifierhead.pt"))
    #     classifier.eval()
    #     args.label_class = 1  # clickbaity
    #
    # elif args.discrim == 'sentiment':
    #     classifier = ClassificationHead(class_size=5, embed_size=1024).to(device)
    #     # classifier.load_state_dict(torch.load("discrim_models/sentiment_classifierhead.pt"))
    #     classifier.load_state_dict(torch.load("discrim_models/SST_classifier_head_epoch_16.pt"))
    #     classifier.eval()
    #     if args.label_class < 0:
    #         raise Exception('Wrong class for sentiment, use --label-class 2 for *very positive*, 3 for *very negative*')
    #     # args.label_class = 2  # very pos
    #     # args.label_class = 3  # very neg
    #
    # elif args.discrim == 'toxicity':
    #     classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
    #     classifier.load_state_dict(torch.load("discrim_models/toxicity_classifierhead.pt"))
    #     classifier.eval()
    #     args.label_class = 0  # not toxic
    #
    # elif args.discrim == 'generic':
    #     if args.discrim_weights is None:
    #         raise ValueError('When using a generic discriminator, '
    #                          'discrim_weights need to be specified')
    #     if args.discrim_meta is None:
    #         raise ValueError('When using a generic discriminator, '
    #                          'discrim_meta need to be specified')
    #
    #     with open(args.discrim_meta, 'r') as discrim_meta_file:
    #         meta = json.load(discrim_meta_file)
    #
    #     classifier = ClassificationHead(
    #         class_size=meta['class_size'],
    #         embed_size=meta['embed_size'],
    #         # todo add tokenizer from meta
    #     ).to(device)
    #     classifier.load_state_dict(torch.load(args.discrim_weights))
    #     classifier.eval()
    #     if args.label_class == -1:
    #         args.label_class = meta['default_class']
    #
    # else:
    #     classifier = None

    # Get tokens for the list of positive words
    def list_tokens(word_list):
        token_list = [TOKENIZER.encode(word, add_prefix_space=True) for word in
                      word_list]
        # token_list = []
        # for word in word_list:
        #     token_list.append(TOKENIZER.encode(" " + word))
        return token_list

    # good_index = []
    # if args.bag_of_words:
    #     bags_of_words = args.bag_of_words.split(";")
    #     for wordlist in bags_of_words:
    #         with open(wordlist, "r") as f:
    #             words = f.read().strip()
    #             words = words.split('\n')
    #         good_index.append(list_tokens(words))
    #
    # for good_list in good_index:
    #     good_list = list(filter(lambda x: len(x) <= 1, good_list))
    #     actual_words = [(TOKENIZER.decode(ww).strip(), ww) for ww in good_list]

    good_index = []
    actual_words = None
    if args.bag_of_words:
        good_index = get_bag_of_words_indices(args.bag_of_words.split(";"))
        for good_list in good_index:
            good_list = list(filter(lambda x: len(x) <= 1, good_list))
            actual_words = [(TOKENIZER.decode(ww).strip(), ww) for ww in
                            good_list]

    if args.bag_of_words and classifier:
        print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
        args.loss_type = PPLM_BOW_DISCRIM
    elif args.bag_of_words:
        args.loss_type = PPLM_BOW
        print("Using PPLM-BoW")
    elif classifier is not None:
        args.loss_type = PPLM_DISCRIM
        print("Using PPLM-Discrim")
    else:
        raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)")

    original, _, _ = sample_from_hidden(model=model, args=args, context=context,
                                        device=device,
                                        perturb=False, good_index=good_index,
                                        classifier=classifier)
    torch.cuda.empty_cache()

    perturbed_list = []
    discrim_loss_list = []
    loss_in_time_list = []
    for i in range(args.num_samples):
        perturbed, discrim_loss, loss_in_time = sample_from_hidden(model=model,
                                                                   args=args,
                                                                   context=context,
                                                                   device=device,
                                                                   perturb=True,
                                                                   good_index=good_index,
                                                                   classifier=classifier)
        perturbed_list.append(perturbed)
        if classifier is not None:
            discrim_loss_list.append(discrim_loss.data.cpu().numpy())
        loss_in_time_list.append(loss_in_time)

    torch.cuda.empty_cache()

    return original, perturbed_list, discrim_loss_list, loss_in_time_list, actual_words
def sample_from_hidden(model, args, classifier, context=None, past=None,
                       device='cuda',
                       sample=True, perturb=True, good_index=None):
    output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(
        0) if context else None

    grad_norms = None
    loss_in_time = []
    for i in trange(args.length, ascii=True):
        # Get past/probs for current output, except for last word
        # Note that GPT takes 2 inputs: past + current-token
        # Therefore, use everything from before current i/p token to generate relevant past
        if past is None and output is not None:
            prev = output[:, -1:]
            # _, past = model(output[:, :-1])
            # original_probs, true_past = model(output)
            # true_hidden = model.hidden_states
            # Piero modified model call
            _, past, _ = model(output[:, :-1])
            original_probs, true_past, unpert_all_hidden = model(output)
            true_hidden = unpert_all_hidden[-1]
        else:
            # original_probs, true_past = model(output)
            # true_hidden = model.hidden_states
            # Piero modified model call
            original_probs, true_past, unpert_all_hidden = model(output)
            true_hidden = unpert_all_hidden[-1]

        # Modify the past if necessary
        if i >= args.grad_length:
            current_stepsize = args.stepsize * 0
        else:
            current_stepsize = args.stepsize

        if not perturb or args.num_iterations == 0:
            perturbed_past = past
        else:
            # Piero modified model call
            # accumulated_hidden = model.hidden_states[:, :-1, :]
            accumulated_hidden = true_hidden[:, :-1, :]
            accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
            perturbed_past, _, grad_norms, loss_per_iter = perturb_past(past,
                                                                        model,
                                                                        prev,
                                                                        args,
                                                                        good_index=good_index,
                                                                        stepsize=current_stepsize,
                                                                        original_probs=original_probs,
                                                                        true_past=true_past,
                                                                        accumulated_hidden=accumulated_hidden,
                                                                        classifier=classifier,
                                                                        grad_norms=grad_norms)
            loss_in_time.append(loss_per_iter)

        # Piero modified model call
        logits, past, pert_all_hidden = model(prev, past=perturbed_past)
        # test_logits = F.softmax(test_logits[:, -1, :], dim=-1)
        # likelywords = torch.topk(test_logits, k=10, dim=-1)
        # print(TOKENIZER.decode(likelywords[1].tolist()[0]))

        if classifier is not None:
            ce_loss = torch.nn.CrossEntropyLoss()
            predicted_sentiment = classifier(torch.mean(true_hidden, dim=1))
            label = torch.tensor([args.label_class], device='cuda',
                                 dtype=torch.long)
            true_discrim_loss = ce_loss(predicted_sentiment, label)
            print("true discrim loss", true_discrim_loss.data.cpu().numpy())
        else:
            true_discrim_loss = 0

        # Piero modified model call
        # hidden = model.hidden_states  # update hidden
        # logits = model.forward_hidden(hidden)
        logits = logits[:, -1, :] / args.temperature  # + SmallConst

        # logits = top_k_filter(logits, k=args.top_k)  # + SmallConst
        log_probs = F.softmax(logits, dim=-1)

        # Fuse the modified model and original model
        if perturb:
            # original_probs = top_k_filter(original_probs[:, -1, :])  # + SmallConst
            original_probs = F.softmax(original_probs[:, -1, :], dim=-1)
            # likelywords = torch.topk(original_probs, k=10, dim=-1)
            # print(TOKENIZER.decode(likelywords[1].tolist()[0]))

            gm_scale = args.gm_scale
            log_probs = ((log_probs ** gm_scale) * (
                original_probs ** (1 - gm_scale)))  # + SmallConst

            log_probs = top_k_filter(log_probs, k=args.top_k,
                                     probs=True)  # + SmallConst

            if torch.sum(log_probs) <= 1:
                log_probs = log_probs / torch.sum(log_probs)
        else:
            logits = top_k_filter(logits, k=args.top_k)  # + SmallConst
            log_probs = F.softmax(logits, dim=-1)

        if sample:
            # likelywords = torch.topk(log_probs, k=args.top_k, dim=-1)
            # print(TOKENIZER.decode(likelywords[1].tolist()[0]))
            # print(likelywords[0].tolist())
            prev = torch.multinomial(log_probs, num_samples=1)
        else:
            _, prev = torch.topk(log_probs, k=1, dim=-1)
        # if perturb:
        #     prev = future
        output = prev if output is None else torch.cat((output, prev),
                                                       dim=1)  # update output
        print(TOKENIZER.decode(output.tolist()[0]))

    return output, true_discrim_loss, loss_in_time
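Aside (not part of this commit): the fusion step in sample_from_hidden combines the perturbed and unperturbed next-token distributions as a power-weighted geometric mean, p_fused proportional to p_pert**gm_scale * p_orig**(1 - gm_scale). A standalone sketch on toy distributions:

# Standalone sketch of the gm_scale fusion above (toy probabilities).
import torch

gm_scale = 0.9
p_pert = torch.tensor([[0.7, 0.2, 0.1]])   # perturbed model's next-token probs
p_orig = torch.tensor([[0.3, 0.4, 0.3]])   # original model's next-token probs

fused = (p_pert ** gm_scale) * (p_orig ** (1 - gm_scale))
fused = fused / fused.sum()                # renormalize before sampling
next_token = torch.multinomial(fused, num_samples=1)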
def run_model():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', '-M', type=str, default='gpt2-medium',
                        help='pretrained model name or path to local checkpoint')
    parser.add_argument('--bag-of-words', '-B', type=str, default=None,
                        help='Bags of words used for PPLM-BoW. Multiple BoWs separated by ;')
    parser.add_argument('--discrim', '-D', type=str, default=None,
                        choices=(
                            'clickbait', 'sentiment', 'toxicity', 'generic'),
                        help='Discriminator to use for loss-type 2')
    parser.add_argument('--discrim_weights', type=str, default=None,
                        help='Weights for the generic discriminator')
    parser.add_argument('--discrim_meta', type=str, default=None,
                        help='Meta information for the generic discriminator')
    parser.add_argument('--label_class', type=int, default=-1,
                        help='Class label used for the discriminator')
    parser.add_argument('--stepsize', type=float, default=0.02)
    parser.add_argument("--length", type=int, default=100)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=10)
    parser.add_argument("--gm_scale", type=float, default=0.9)
    parser.add_argument("--kl_scale", type=float, default=0.01)
    parser.add_argument('--nocuda', action='store_true', help='no cuda')
    parser.add_argument('--uncond', action='store_true',
                        help='Generate from end-of-text as prefix')
    parser.add_argument("--cond_text", type=str, default='The lake',
                        help='Prefix texts to condition on')
    parser.add_argument('--num_iterations', type=int, default=3)
    parser.add_argument('--grad_length', type=int, default=10000)
    parser.add_argument('--num_samples', type=int, default=1,
                        help='Number of samples to generate from the modified latents')
    parser.add_argument('--horizon_length', type=int, default=1,
                        help='Length of future to optimize over')
    # parser.add_argument('--force-token', action='store_true', help='no cuda')
    parser.add_argument('--window_length', type=int, default=0,
                        help='Length of past which is being optimized; 0 corresponds to infinite window length')
    parser.add_argument('--decay', action='store_true',
                        help='whether to decay or not')
    parser.add_argument('--gamma', type=float, default=1.5)
    parser.add_argument('--colorama', action='store_true', help='colorize keyword tokens')

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    device = 'cpu' if args.nocuda else 'cuda'

    model = GPT2LMHeadModel.from_pretrained(
        args.model_path,
        output_hidden_states=True
@@ -720,63 +673,82 @@ def run_model():
    model.to(device)
    model.eval()

    # Freeze GPT-2 weights
    for param in model.parameters():
        param.requires_grad = False
    pass

    if args.uncond:
        seq = [[50256, 50256]]
    else:
        raw_text = args.cond_text
        while not raw_text:
            print('Did you forget to add `--cond-text`? ')
            raw_text = input("Model prompt >>> ")
        seq = [[50256] + TOKENIZER.encode(raw_text)]

    collect_gen = dict()
    current_index = 0
    for out in seq:
        text = TOKENIZER.decode(out)
        print("=" * 40 + " Prefix of sentence " + "=" * 40)
        print(text)
        print("=" * 80)

        out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = latent_perturb(
            model=model, args=args, context=out,
            device=device)

        text_whole = TOKENIZER.decode(out1.tolist()[0])
        print("=" * 80)
        print("=" * 40 + " Whole sentence (Original)" + "=" * 40)
        print(text_whole)
        print("=" * 80)

        out_perturb_copy = out_perturb
        for out_perturb in out_perturb_copy:
            # try:
            #     print("=" * 40 + " Whole sentence (Perturbed)" + "=" * 40)
            #     text_whole = TOKENIZER.decode(out_perturb.tolist()[0])
            #     print(text_whole)
            #     print("=" * 80)
            # except:
            #     pass
            # collect_gen[current_index] = [out, out_perturb, out1]
            # Save the prefix, perturbed seq, original seq for each index
            print("=" * 40 + " Whole sentence (Perturbed)" + "=" * 40)

            keyword_tokens = [aa[-1][0] for aa in
                              actual_words] if actual_words else []

            output_tokens = out_perturb.tolist()[0]
            if args.colorama:
                import colorama

                text_whole = ''
                for out in output_tokens:
                    if out in keyword_tokens:
                        text_whole += '%s%s%s' % (
                            colorama.Fore.GREEN, TOKENIZER.decode([out]),
                            colorama.Style.RESET_ALL)
                    else:
                        text_whole += TOKENIZER.decode([out])
            else:
                text_whole = TOKENIZER.decode(out_perturb.tolist()[0])

            print(text_whole)
            print("=" * 80)
            collect_gen[current_index] = [out, out_perturb, out1]

            current_index = current_index + 1

    return


if __name__ == '__main__':
    run_model()
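Aside (not part of this commit): the flags above translate into an args namespace like the sketch below, where every value other than the overridden ones is the argparse default shown above; run_model() builds the same thing from the command line, e.g. with -B military --cond_text "The potato" --num_samples 3 (the "military" bag-of-words id is only an example).

# Hypothetical equivalent of parsing: -B military --cond_text "The potato" --num_samples 3
from argparse import Namespace

example_args = Namespace(
    model_path='gpt2-medium', bag_of_words='military', discrim=None,
    discrim_weights=None, discrim_meta=None, label_class=-1, stepsize=0.02,
    length=100, seed=0, temperature=1.0, top_k=10, gm_scale=0.9,
    kl_scale=0.01, nocuda=False, uncond=False, cond_text='The potato',
    num_iterations=3, grad_length=10000, num_samples=3, horizon_length=1,
    window_length=0, decay=False, gamma=1.5, colorama=False,
)
# latent_perturb(model, example_args, context=...) additionally sets
# example_args.loss_type based on which of bag_of_words / discrim is given.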