Commit 4f2164e4 authored by piero, committed by Julien Chaumond

First cleanup step, changing function names and passing parameters all the way through without using args. Identical output as before.
parent 821de121
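The pattern applied throughout the commit: each function now declares explicit keyword parameters with defaults instead of reading fields off an argparse `Namespace`, and the CLI layer forwards the parsed namespace with `**vars(args)` (visible in the `run_model` hunk below). A minimal sketch of the idea, with hypothetical function names:

```python
import argparse

# Before: the function reaches into an opaque namespace.
def generate_old(model, args):
    return f"steps={args.num_iterations}, lr={args.stepsize}"

# After: explicit keyword parameters with defaults; callable without argparse.
def generate_new(model, num_iterations=3, stepsize=0.02, **kwargs):
    return f"steps={num_iterations}, lr={stepsize}"

parser = argparse.ArgumentParser()
parser.add_argument("--num_iterations", type=int, default=3)
parser.add_argument("--stepsize", type=float, default=0.02)
args = parser.parse_args(["--num_iterations", "5"])

# The CLI entry point forwards the namespace; library callers pass kwargs directly.
print(generate_new(model=None, **vars(args)))   # steps=5, lr=0.02
print(generate_new(model=None, stepsize=0.1))   # steps=3, lr=0.1
```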
```diff
@@ -109,20 +109,40 @@ def top_k_filter(logits, k, probs=False):
                            logits)
-def perturb_past(past, model, prev, args, classifier, good_index=None,
-                 stepsize=0.01, vocab_size=50257,
-                 original_probs=None, accumulated_hidden=None, true_past=None,
-                 grad_norms=None):
-    window_length = args.window_length
-    gm_scale, kl_scale = args.gm_scale, args.kl_scale
-    one_hot_vectors = []
-    for good_list in good_index:
-        good_list = list(filter(lambda x: len(x) <= 1, good_list))
-        good_list = torch.tensor(good_list).cuda()
-        num_good = good_list.shape[0]
-        one_hot_good = torch.zeros(num_good, vocab_size).cuda()
-        one_hot_good.scatter_(1, good_list, 1)
-        one_hot_vectors.append(one_hot_good)
+def perturb_past(
+        past,
+        model,
+        prev,
+        unpert_past=None,
+        unpert_logits=None,
+        accumulated_hidden=None,
+        grad_norms=None,
+        stepsize=0.01,
+        classifier=None,
+        label_class=None,
+        one_hot_bows_vectors=None,
+        loss_type=0,
+        num_iterations=3,
+        kl_scale=0.01,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+):
+    # def perturb_past(past, model, prev, classifier, good_index=None,
+    #                  stepsize=0.01, vocab_size=50257,
+    #                  original_probs=None, accumulated_hidden=None, true_past=None,
+    #                  grad_norms=None):
+    # one_hot_bows_vectors = []
+    # for good_list in good_index:
+    #     good_list = list(filter(lambda x: len(x) <= 1, good_list))
+    #     good_list = torch.tensor(good_list).cuda()
+    #     num_good = good_list.shape[0]
+    #     one_hot_good = torch.zeros(num_good, vocab_size).cuda()
+    #     one_hot_good.scatter_(1, good_list, 1)
+    #     one_hot_bows_vectors.append(one_hot_good)
     # Generate inital perturbed past
     past_perturb_orig = [
```
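For context, the one-hot bag-of-words construction removed here (commented out above, and now living in `build_bows_one_hot_vectors`) works as follows; a CPU-only sketch with a toy vocabulary, assuming each bag entry is a single token id:

```python
import torch

vocab_size = 10                      # toy vocabulary (the real code uses 50257)
bow_indices = [[[3], [7], [9]]]      # one bag of words, each word = one token id

one_hot_bows_vectors = []
for good_list in bow_indices:
    # keep only single-token words, as the original filter does
    good_list = list(filter(lambda x: len(x) <= 1, good_list))
    good_list = torch.tensor(good_list)                # shape (num_good, 1)
    one_hot_good = torch.zeros(len(good_list), vocab_size)
    one_hot_good.scatter_(1, good_list, 1)             # row i: 1 at that word's token id
    one_hot_bows_vectors.append(one_hot_good)

print(one_hot_bows_vectors[0])  # rows are one-hot vectors for ids 3, 7, 9
```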
```diff
@@ -132,7 +152,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
     if accumulated_hidden is None:
         accumulated_hidden = 0
-    if args.decay:
+    if decay:
         decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / (window_length))[
                      1:]
     else:
```
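The decay mask, when enabled, is just a linear ramp over the perturbation window; how it is combined with the window mask sits in the elided lines. A quick check of what the `torch.arange` expression yields, assuming `SmallConst` is the module's small epsilon:

```python
import torch

SmallConst = 1e-10   # assumption: matches the module-level constant
window_length = 5

decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / window_length)[1:]
print(decay_mask)    # tensor([0.2, 0.4, 0.6, 0.8, 1.0]): a ramp from 1/w to 1.0
```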
```diff
@@ -160,7 +180,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
     window_mask = torch.ones_like(past[0]).cuda()
     loss_per_iter = []
-    for i in range(args.num_iterations):
+    for i in range(num_iterations):
         print("Iteration ", i + 1)
         past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
         past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
```
```diff
@@ -183,8 +203,8 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         probabs = F.softmax(logits, dim=-1)
         loss = 0.0
         loss_list = []
-        if args.loss_type == 1 or args.loss_type == 3:
-            for one_hot_good in one_hot_vectors:
+        if loss_type == 1 or loss_type == 3:
+            for one_hot_good in one_hot_bows_vectors:
                 good_logits = torch.mm(probabs, torch.t(one_hot_good))
                 loss_word = good_logits
                 loss_word = torch.sum(loss_word)
```
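The visible BoW lines project the next-token distribution onto the bag and sum the probability mass landing on bag words; taking `-log` of that sum happens in the elided lines, so that step is an assumption in this sketch:

```python
import torch

probabs = torch.softmax(torch.randn(1, 10), dim=-1)  # next-token dist, toy vocab
one_hot_good = torch.zeros(3, 10)
one_hot_good[0, 3] = one_hot_good[1, 7] = one_hot_good[2, 9] = 1.0  # bag = ids {3, 7, 9}

good_logits = torch.mm(probabs, torch.t(one_hot_good))  # (1, 3): prob of each bag word
bag_mass = torch.sum(good_logits)                       # total probability on the bag
loss_word = -torch.log(bag_mass)   # assumption: the elided lines take -log of the sum
print(bag_mass.item(), loss_word.item())
```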
```diff
@@ -194,10 +214,10 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
                 loss_list.append(loss_word)
             print(" pplm_bow_loss:", loss.data.cpu().numpy())
-        if args.loss_type == 2 or args.loss_type == 3:
+        if loss_type == 2 or loss_type == 3:
             ce_loss = torch.nn.CrossEntropyLoss()
-            new_true_past = true_past
-            for i in range(args.horizon_length):
+            new_true_past = unpert_past
+            for i in range(horizon_length):
                 future_probabs = F.softmax(logits, dim=-1)  # Get softmax
                 future_probabs = torch.unsqueeze(future_probabs, dim=1)
```
```diff
@@ -217,9 +237,9 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
                     future_hidden, dim=1)
             predicted_sentiment = classifier(new_accumulated_hidden / (
-                    current_length + 1 + args.horizon_length))
-            label = torch.tensor([args.label_class], device='cuda',
+                    current_length + 1 + horizon_length))
+            label = torch.tensor([label_class], device='cuda',
                                  dtype=torch.long)
             discrim_loss = ce_loss(predicted_sentiment, label)
             print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
```
```diff
@@ -228,7 +248,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         kl_loss = 0.0
         if kl_scale > 0.0:
-            p = (F.softmax(original_probs[:, -1, :], dim=-1))
+            p = (F.softmax(unpert_logits[:, -1, :], dim=-1))
             p = p + SmallConst * (p <= SmallConst).type(
                 torch.FloatTensor).cuda().detach()
             correction = SmallConst * (probabs <= SmallConst).type(
```
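The KL term penalizes drift of the perturbed next-token distribution from the unperturbed one; both sides get an epsilon floor before the log. The closing lines of the computation are elided above, so their completion here is an assumption consistent with the visible smoothing; CPU sketch:

```python
import torch

SmallConst = 1e-10
kl_scale = 0.01

probabs = torch.softmax(torch.randn(1, 10), dim=-1)      # perturbed next-token dist
unpert_logits = torch.randn(1, 1, 10)                    # unperturbed logits (B, T, V)

p = torch.softmax(unpert_logits[:, -1, :], dim=-1)
p = p + SmallConst * (p <= SmallConst).float().detach()  # floor tiny reference probs
correction = SmallConst * (probabs <= SmallConst).float().detach()
corrected_probs = probabs + correction
# assumed completion: scaled KL(corrected_probs || p)
kl_loss = kl_scale * (corrected_probs * (corrected_probs / p).log()).sum()
print(kl_loss.item())
```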
```diff
@@ -244,7 +264,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
         loss.backward()
-        if grad_norms is not None and args.loss_type == 1:
+        if grad_norms is not None and loss_type == 1:
             grad_norms = [
                 torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                 for index, p_ in
```
```diff
@@ -255,7 +275,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         grad = [
             -stepsize * (p_.grad * window_mask / grad_norms[
-                index] ** args.gamma).data.cpu().numpy()
+                index] ** gamma).data.cpu().numpy()
             for index, p_ in enumerate(past_perturb)]
         past_perturb_orig = list(map(add, grad, past_perturb_orig))
```
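The update applied to the past is masked, norm-scaled gradient descent, with `gamma` tempering the normalization and `grad_norms` tracked across iterations via `torch.max` (previous hunk). The arithmetic for a single tensor, sketched on CPU with a stand-in loss:

```python
import torch
from operator import add

stepsize, gamma = 0.02, 1.5
p_ = torch.randn(2, 3, requires_grad=True)
p_.sum().backward()                           # stand-in loss to populate p_.grad

window_mask = torch.ones_like(p_)             # which positions are allowed to move
grad_norm = torch.norm(p_.grad * window_mask) + 1e-10
step = -stepsize * (p_.grad * window_mask / grad_norm ** gamma).numpy()

past_perturb_orig = [torch.zeros(2, 3).numpy()]
past_perturb_orig = list(map(add, [step], past_perturb_orig))  # accumulate the step
print(past_perturb_orig[0])
```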
```diff
@@ -347,10 +367,32 @@ def build_bows_one_hot_vectors(bow_indices):
     return one_hot_bows_vectors
-def latent_perturb(model, args, context=None, sample=True, device='cuda'):
+def full_text_generation(
+        model,
+        context=None,
+        num_samples=1,
+        device="cuda",
+        sample=True,
+        discrim=None,
+        label_class=None,
+        bag_of_words=None,
+        length=100,
+        grad_length=10000,
+        stepsize=0.02,
+        num_iterations=3,
+        temperature=1.0,
+        gm_scale=0.9,
+        kl_scale=0.01,
+        top_k=10,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+        **kwargs
+):
     classifier, class_id = get_classifier(
-        args.discrim,
-        args.label_class,
+        discrim,
+        label_class,
         device
     )
```
```diff
@@ -422,49 +464,68 @@ def latent_perturb(model, args, context=None, sample=True, device='cuda'):
     # good_list = list(filter(lambda x: len(x) <= 1, good_list))
     # actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]
-    good_index = []
+    bow_indices = []
     actual_words = None
-    if args.bag_of_words:
-        good_index = get_bag_of_words_indices(args.bag_of_words.split(";"))
-        for good_list in good_index:
+    if bag_of_words:
+        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))
+        for good_list in bow_indices:
             good_list = list(filter(lambda x: len(x) <= 1, good_list))
             actual_words = [(TOKENIZER.decode(ww).strip(), ww) for ww in
                             good_list]
-    if args.bag_of_words and classifier:
+    if bag_of_words and classifier:
         print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
-        args.loss_type = PPLM_BOW_DISCRIM
-    elif args.bag_of_words:
-        args.loss_type = PPLM_BOW
+        loss_type = PPLM_BOW_DISCRIM
+    elif bag_of_words:
+        loss_type = PPLM_BOW
         print("Using PPLM-BoW")
     elif classifier is not None:
-        args.loss_type = PPLM_DISCRIM
+        loss_type = PPLM_DISCRIM
         print("Using PPLM-Discrim")
     else:
         raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)")
-    original, _, _ = sample_from_hidden(model=model, args=args, context=context,
-                                        device=device,
-                                        perturb=False, good_index=good_index,
-                                        classifier=classifier)
+    original, _, _ = generate_text_pplm(
+        model=model,
+        context=context,
+        device=device,
+        length=length,
+        perturb=False
+    )
     torch.cuda.empty_cache()
     perturbed_list = []
     discrim_loss_list = []
     loss_in_time_list = []
-    for i in range(args.num_samples):
-        perturbed, discrim_loss, loss_in_time = sample_from_hidden(model=model,
-                                                                   args=args,
-                                                                   context=context,
-                                                                   device=device,
-                                                                   perturb=True,
-                                                                   good_index=good_index,
-                                                                   classifier=classifier)
+    for i in range(num_samples):
+        perturbed, discrim_loss, loss_in_time = generate_text_pplm(
+            model=model,
+            context=context,
+            device=device,
+            sample=sample,
+            perturb=True,
+            bow_indices=bow_indices,
+            classifier=classifier,
+            label_class=class_id,
+            loss_type=loss_type,
+            length=length,
+            grad_length=grad_length,
+            stepsize=stepsize,
+            num_iterations=num_iterations,
+            temperature=temperature,
+            gm_scale=gm_scale,
+            kl_scale=kl_scale,
+            top_k=top_k,
+            window_length=window_length,
+            horizon_length=horizon_length,
+            decay=decay,
+            gamma=gamma,
+        )
         perturbed_list.append(perturbed)
         if classifier is not None:
             discrim_loss_list.append(discrim_loss.data.cpu().numpy())
```
```diff
@@ -475,15 +536,40 @@ def latent_perturb(model, args, context=None, sample=True, device='cuda'):
     return original, perturbed_list, discrim_loss_list, loss_in_time_list, actual_words
-def sample_from_hidden(model, args, classifier, context=None, past=None,
-                       device='cuda',
-                       sample=True, perturb=True, good_index=None):
+def generate_text_pplm(
+        model,
+        context=None,
+        past=None,
+        device="cuda",
+        sample=True,
+        perturb=True,
+        classifier=None,
+        label_class=None,
+        bow_indices=None,
+        loss_type=0,
+        length=100,
+        grad_length=10000,
+        stepsize=0.02,
+        num_iterations=3,
+        temperature=1.0,
+        gm_scale=0.9,
+        kl_scale=0.01,
+        top_k=10,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+):
     output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(
         0) if context else None
+    # collect one hot vectors for bags of words
+    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
     grad_norms = None
     loss_in_time = []
-    for i in trange(args.length, ascii=True):
+    for i in trange(length, ascii=True):
         # Get past/probs for current output, except for last word
         # Note that GPT takes 2 inputs: past + current-token
```
```diff
@@ -497,7 +583,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
             # Piero modified model call
             _, past, _ = model(output[:, :-1])
-            original_probs, true_past, unpert_all_hidden = model(output)
+            unpert_logits, unpert_past, unpert_all_hidden = model(output)
             true_hidden = unpert_all_hidden[-1]
         else:
```
```diff
@@ -505,17 +591,17 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
             # true_hidden = model.hidden_states
             # Piero modified model call
-            original_probs, true_past, unpert_all_hidden = model(output)
+            unpert_logits, unpert_past, unpert_all_hidden = model(output)
             true_hidden = unpert_all_hidden[-1]
         # Modify the past if necessary
-        if i >= args.grad_length:
-            current_stepsize = args.stepsize * 0
+        if i >= grad_length:
+            current_stepsize = stepsize * 0
         else:
-            current_stepsize = args.stepsize
+            current_stepsize = stepsize
-        if not perturb or args.num_iterations == 0:
+        if not perturb or num_iterations == 0:
             perturbed_past = past
         else:
```
```diff
@@ -524,17 +610,26 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
             accumulated_hidden = true_hidden[:, :-1, :]
             accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
-            perturbed_past, _, grad_norms, loss_per_iter = perturb_past(past,
-                                                                        model,
-                                                                        prev,
-                                                                        args,
-                                                                        good_index=good_index,
-                                                                        stepsize=current_stepsize,
-                                                                        original_probs=original_probs,
-                                                                        true_past=true_past,
-                                                                        accumulated_hidden=accumulated_hidden,
-                                                                        classifier=classifier,
-                                                                        grad_norms=grad_norms)
+            perturbed_past, _, grad_norms, loss_per_iter = perturb_past(
+                past,
+                model,
+                prev,
+                unpert_past=unpert_past,
+                unpert_logits=unpert_logits,
+                accumulated_hidden=accumulated_hidden,
+                grad_norms=grad_norms,
+                stepsize=current_stepsize,
+                classifier=classifier,
+                label_class=label_class,
+                one_hot_bows_vectors=one_hot_bows_vectors,
+                loss_type=loss_type,
+                num_iterations=num_iterations,
+                kl_scale=kl_scale,
+                window_length=window_length,
+                horizon_length=horizon_length,
+                decay=decay,
+                gamma=gamma,
+            )
             loss_in_time.append(loss_per_iter)
         # Piero modified model call
```
```diff
@@ -546,7 +641,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         if classifier is not None:
             ce_loss = torch.nn.CrossEntropyLoss()
             predicted_sentiment = classifier(torch.mean(true_hidden, dim=1))
-            label = torch.tensor([args.label_class], device='cuda',
+            label = torch.tensor([label_class], device='cuda',
                                  dtype=torch.long)
             true_discrim_loss = ce_loss(predicted_sentiment, label)
             print("true discrim loss", true_discrim_loss.data.cpu().numpy())
```
```diff
@@ -556,7 +651,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         # Piero modified model call
         # hidden = model.hidden_states  # update hidden
         # logits = model.forward_hidden(hidden)
-        logits = logits[:, -1, :] / args.temperature  # + SmallConst
+        logits = logits[:, -1, :] / temperature  # + SmallConst
         # logits = top_k_filter(logits, k=args.top_k)  # + SmallConst
```
```diff
@@ -566,22 +661,21 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         if perturb:
             # original_probs = top_k_filter(original_probs[:, -1, :]) #+ SmallConst
-            original_probs = F.softmax(original_probs[:, -1, :], dim=-1)
+            unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1)
             # likelywords = torch.topk(original_probs, k=10, dim=-1)
             # print(TOKENIZER.decode(likelywords[1].tolist()[0]))
-            gm_scale = args.gm_scale
             log_probs = ((log_probs ** gm_scale) * (
-                original_probs ** (1 - gm_scale)))  # + SmallConst
-            log_probs = top_k_filter(log_probs, k=args.top_k,
+                unpert_logits ** (1 - gm_scale)))  # + SmallConst
+            log_probs = top_k_filter(log_probs, k=top_k,
                                      probs=True)  # + SmallConst
             if torch.sum(log_probs) <= 1:
                 log_probs = log_probs / torch.sum(log_probs)
         else:
-            logits = top_k_filter(logits, k=args.top_k)  # + SmallConst
+            logits = top_k_filter(logits, k=top_k)  # + SmallConst
             log_probs = F.softmax(logits, dim=-1)
         if sample:
```
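The sampling fusion here is a geometric mean of the perturbed and unperturbed next-token distributions (both operands are probabilities at this point, despite the `log_probs`/`unpert_logits` names), followed by top-k filtering and renormalization. A sketch, using a plain top-k mask in place of the file's `top_k_filter`:

```python
import torch

gm_scale, top_k = 0.9, 3
pert_probs = torch.softmax(torch.randn(1, 10), dim=-1)    # from the perturbed past
unpert_probs = torch.softmax(torch.randn(1, 10), dim=-1)  # from the untouched past

fused = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))

# top-k filter in probability space: zero everything below the k-th largest value
kth = torch.topk(fused, k=top_k)[0][..., -1, None]
fused = torch.where(fused < kth, torch.zeros_like(fused), fused)

fused = fused / torch.sum(fused)               # renormalize, then sample
next_token = torch.multinomial(fused, num_samples=1)
print(next_token)
```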
```diff
@@ -673,16 +767,16 @@ def run_model():
     collect_gen = dict()
     current_index = 0
-    for out in seq:
-        text = TOKENIZER.decode(out)
+    for tokenized_cond_text in seq:
+        text = TOKENIZER.decode(tokenized_cond_text)
         print("=" * 40 + " Prefix of sentence " + "=" * 40)
         print(text)
         print("=" * 80)
-        out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = latent_perturb(
-            model=model, args=args, context=out,
-            device=device)
+        out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = full_text_generation(
+            model=model, context=tokenized_cond_text, device=device, **vars(args)
+        )
         text_whole = TOKENIZER.decode(out1.tolist()[0])
```
```diff
@@ -712,20 +806,20 @@ def run_model():
                 import colorama
                 text_whole = ''
-                for out in output_tokens:
-                    if out in keyword_tokens:
+                for tokenized_cond_text in output_tokens:
+                    if tokenized_cond_text in keyword_tokens:
                         text_whole += '%s%s%s' % (
-                            colorama.Fore.GREEN, TOKENIZER.decode([out]),
+                            colorama.Fore.GREEN, TOKENIZER.decode([tokenized_cond_text]),
                             colorama.Style.RESET_ALL)
                     else:
-                        text_whole += TOKENIZER.decode([out])
+                        text_whole += TOKENIZER.decode([tokenized_cond_text])
             else:
                 text_whole = TOKENIZER.decode(out_perturb.tolist()[0])
         print(text_whole)
         print("=" * 80)
-        collect_gen[current_index] = [out, out_perturb, out1]
+        collect_gen[current_index] = [tokenized_cond_text, out_perturb, out1]
         current_index = current_index + 1
```