Commit 4f2164e4 authored by piero, committed by Julien Chaumond

First cleanup step, changing function names and passing parameters all the way through without using args. Identical output as before.
parent 821de121
@@ -109,20 +109,40 @@ def top_k_filter(logits, k, probs=False):
                        logits)
-def perturb_past(past, model, prev, args, classifier, good_index=None,
-                 stepsize=0.01, vocab_size=50257,
-                 original_probs=None, accumulated_hidden=None, true_past=None,
-                 grad_norms=None):
-    window_length = args.window_length
-    gm_scale, kl_scale = args.gm_scale, args.kl_scale
-    one_hot_vectors = []
-    for good_list in good_index:
-        good_list = list(filter(lambda x: len(x) <= 1, good_list))
-        good_list = torch.tensor(good_list).cuda()
-        num_good = good_list.shape[0]
-        one_hot_good = torch.zeros(num_good, vocab_size).cuda()
-        one_hot_good.scatter_(1, good_list, 1)
-        one_hot_vectors.append(one_hot_good)
+def perturb_past(
+        past,
+        model,
+        prev,
+        unpert_past=None,
+        unpert_logits=None,
+        accumulated_hidden=None,
+        grad_norms=None,
+        stepsize=0.01,
+        classifier=None,
+        label_class=None,
+        one_hot_bows_vectors=None,
+        loss_type=0,
+        num_iterations=3,
+        kl_scale=0.01,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+):
+#def perturb_past(past, model, prev, classifier, good_index=None,
+#                 stepsize=0.01, vocab_size=50257,
+#                 original_probs=None, accumulated_hidden=None, true_past=None,
+#                 grad_norms=None):
+#    one_hot_bows_vectors = []
+#    for good_list in good_index:
+#        good_list = list(filter(lambda x: len(x) <= 1, good_list))
+#        good_list = torch.tensor(good_list).cuda()
+#        num_good = good_list.shape[0]
+#        one_hot_good = torch.zeros(num_good, vocab_size).cuda()
+#        one_hot_good.scatter_(1, good_list, 1)
+#        one_hot_bows_vectors.append(one_hot_good)
     # Generate inital perturbed past
     past_perturb_orig = [
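The loop removed above reappears as the standalone helper build_bows_one_hot_vectors (see the hunk at new line 367 below). A minimal CPU sketch of what it computes, with hypothetical token ids:

import torch

def build_bows_one_hot_vectors_sketch(bow_indices, vocab_size=50257):
    # Keep only bag-of-words entries that encode to a single token, then
    # build one one-hot row per kept word via scatter_.
    one_hot_bows_vectors = []
    for good_list in bow_indices:
        good_list = [ids for ids in good_list if len(ids) <= 1]
        index = torch.tensor(good_list)                   # shape (num_words, 1)
        one_hot = torch.zeros(len(good_list), vocab_size)
        one_hot.scatter_(1, index, 1)                     # one 1 per row
        one_hot_bows_vectors.append(one_hot)
    return one_hot_bows_vectors

# e.g. a single bag holding two single-token words with made-up ids:
vecs = build_bows_one_hot_vectors_sketch([[[464], [3290]]])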
@@ -132,7 +152,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
     if accumulated_hidden is None:
         accumulated_hidden = 0
-    if args.decay:
+    if decay:
         decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / (window_length))[
             1:]
     else:
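For reference, the decay mask built here is just a linear ramp over the perturbation window. A quick sketch, assuming the file's SmallConst is a tiny epsilon (its exact value is not shown in this diff):

import torch

SMALL_CONST = 1e-15  # stand-in for the file's SmallConst
window_length = 5
decay_mask = torch.arange(0., 1.0 + SMALL_CONST, 1.0 / window_length)[1:]
# roughly tensor([0.2, 0.4, 0.6, 0.8, 1.0]): the most recent positions in the
# window receive the full gradient, older ones progressively less
print(decay_mask)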
@@ -160,7 +180,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
     window_mask = torch.ones_like(past[0]).cuda()
     loss_per_iter = []
-    for i in range(args.num_iterations):
+    for i in range(num_iterations):
         print("Iteration ", i + 1)
         past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
         past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
@@ -183,8 +203,8 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         probabs = F.softmax(logits, dim=-1)
         loss = 0.0
         loss_list = []
-        if args.loss_type == 1 or args.loss_type == 3:
-            for one_hot_good in one_hot_vectors:
+        if loss_type == 1 or loss_type == 3:
+            for one_hot_good in one_hot_bows_vectors:
                 good_logits = torch.mm(probabs, torch.t(one_hot_good))
                 loss_word = good_logits
                 loss_word = torch.sum(loss_word)
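The BoW loss measures how much probability mass the perturbed distribution puts on the bag's tokens; the lines elided below this hunk turn the summed mass into a negative-log loss. A self-contained sketch with made-up token ids:

import torch
import torch.nn.functional as F

vocab_size = 50257
probabs = F.softmax(torch.randn(1, vocab_size), dim=-1)  # perturbed next-token distribution

# one-hot rows for two hypothetical single-token bag words
index = torch.tensor([[464], [3290]])
one_hot_good = torch.zeros(2, vocab_size)
one_hot_good.scatter_(1, index, 1)

# mm picks out the probability of each bag word; the loss rewards mass on the bag
good_probs = torch.mm(probabs, torch.t(one_hot_good))
bow_loss = -torch.log(torch.sum(good_probs))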
@@ -194,10 +214,10 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
                 loss_list.append(loss_word)
             print(" pplm_bow_loss:", loss.data.cpu().numpy())
-        if args.loss_type == 2 or args.loss_type == 3:
+        if loss_type == 2 or loss_type == 3:
             ce_loss = torch.nn.CrossEntropyLoss()
-            new_true_past = true_past
-            for i in range(args.horizon_length):
+            new_true_past = unpert_past
+            for i in range(horizon_length):
                 future_probabs = F.softmax(logits, dim=-1)  # Get softmax
                 future_probabs = torch.unsqueeze(future_probabs, dim=1)
@@ -217,9 +237,9 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
                 future_hidden, dim=1)
             predicted_sentiment = classifier(new_accumulated_hidden / (
-                    current_length + 1 + args.horizon_length))
+                    current_length + 1 + horizon_length))
-            label = torch.tensor([args.label_class], device='cuda',
+            label = torch.tensor([label_class], device='cuda',
                                  dtype=torch.long)
             discrim_loss = ce_loss(predicted_sentiment, label)
             print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
@@ -228,7 +248,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         kl_loss = 0.0
         if kl_scale > 0.0:
-            p = (F.softmax(original_probs[:, -1, :], dim=-1))
+            p = (F.softmax(unpert_logits[:, -1, :], dim=-1))
             p = p + SmallConst * (p <= SmallConst).type(
                 torch.FloatTensor).cuda().detach()
             correction = SmallConst * (probabs <= SmallConst).type(
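The kl_scale branch pulls the perturbed next-token distribution back toward the unperturbed one; the tail of this hunk is cut off, but the quantity being built is a scaled KL divergence with SmallConst flooring to avoid log(0). A CPU sketch under that assumption:

import torch
import torch.nn.functional as F

SMALL_CONST = 1e-15  # stand-in for the file's SmallConst
kl_scale = 0.01

unpert_logits = torch.randn(1, 8, 50)            # (batch, seq, vocab) toy shapes
probabs = F.softmax(torch.randn(1, 50), dim=-1)  # perturbed distribution

p = F.softmax(unpert_logits[:, -1, :], dim=-1)
p = p + SMALL_CONST * (p <= SMALL_CONST).float().detach()
correction = SMALL_CONST * (probabs <= SMALL_CONST).float().detach()
corrected_probabs = probabs + correction
# KL(perturbed || unperturbed), scaled by kl_scale
kl_loss = kl_scale * (corrected_probabs * (corrected_probabs / p).log()).sum()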
@@ -244,7 +264,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
         loss.backward()
-        if grad_norms is not None and args.loss_type == 1:
+        if grad_norms is not None and loss_type == 1:
             grad_norms = [
                 torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                 for index, p_ in
@@ -255,7 +275,7 @@ def perturb_past(past, model, prev, args, classifier, good_index=None,
         grad = [
             -stepsize * (p_.grad * window_mask / grad_norms[
-                index] ** args.gamma).data.cpu().numpy()
+                index] ** gamma).data.cpu().numpy()
             for index, p_ in enumerate(past_perturb)]
         past_perturb_orig = list(map(add, grad, past_perturb_orig))
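After backward, each chunk of the past gets a step along its masked gradient, normalized by the running gradient norm raised to gamma. A toy sketch of just this update:

import torch

stepsize, gamma = 0.02, 1.5
past_perturb = [torch.randn(2, 5, requires_grad=True)]
loss = (past_perturb[0] ** 2).sum()
loss.backward()

window_mask = torch.ones_like(past_perturb[0])
# small epsilon added here only to keep the toy example safe from division by zero
grad_norms = [torch.norm(p_.grad * window_mask) + 1e-15 for p_ in past_perturb]

grad = [
    -stepsize * (p_.grad * window_mask / grad_norms[index] ** gamma).numpy()
    for index, p_ in enumerate(past_perturb)]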
@@ -347,10 +367,32 @@ def build_bows_one_hot_vectors(bow_indices):
     return one_hot_bows_vectors
-def latent_perturb(model, args, context=None, sample=True, device='cuda'):
+def full_text_generation(
+        model,
+        context=None,
+        num_samples=1,
+        device="cuda",
+        sample=True,
+        discrim=None,
+        label_class=None,
+        bag_of_words=None,
+        length=100,
+        grad_length=10000,
+        stepsize=0.02,
+        num_iterations=3,
+        temperature=1.0,
+        gm_scale=0.9,
+        kl_scale=0.01,
+        top_k=10,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+        **kwargs
+):
     classifier, class_id = get_classifier(
-        args.discrim,
-        args.label_class,
+        discrim,
+        label_class,
         device
     )
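With args gone from the signature, the entry point can now be driven directly with keyword arguments. A hypothetical call, assuming a model and a tokenized prefix have been prepared as in run_model() below:

unpert_gen, pert_gens, discrim_losses, losses_in_time, words = full_text_generation(
    model=model,
    context=tokenized_cond_text,
    device="cuda",
    bag_of_words="military",  # hypothetical bag-of-words id
    length=50,
    stepsize=0.03,
    gm_scale=0.95,
    kl_scale=0.01,
)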
@@ -422,49 +464,68 @@ def latent_perturb(model, args, context=None, sample=True, device='cuda'):
     # good_list = list(filter(lambda x: len(x) <= 1, good_list))
     # actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]
-    good_index = []
+    bow_indices = []
     actual_words = None
-    if args.bag_of_words:
-        good_index = get_bag_of_words_indices(args.bag_of_words.split(";"))
+    if bag_of_words:
+        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))
-        for good_list in good_index:
+        for good_list in bow_indices:
             good_list = list(filter(lambda x: len(x) <= 1, good_list))
             actual_words = [(TOKENIZER.decode(ww).strip(), ww) for ww in
                             good_list]
-    if args.bag_of_words and classifier:
+    if bag_of_words and classifier:
         print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
-        args.loss_type = PPLM_BOW_DISCRIM
+        loss_type = PPLM_BOW_DISCRIM
-    elif args.bag_of_words:
-        args.loss_type = PPLM_BOW
+    elif bag_of_words:
+        loss_type = PPLM_BOW
         print("Using PPLM-BoW")
     elif classifier is not None:
-        args.loss_type = PPLM_DISCRIM
+        loss_type = PPLM_DISCRIM
         print("Using PPLM-Discrim")
     else:
         raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)")
-    original, _, _ = sample_from_hidden(model=model, args=args, context=context,
+    original, _, _ = generate_text_pplm(
+        model=model,
+        context=context,
         device=device,
-                                        perturb=False, good_index=good_index,
-                                        classifier=classifier)
+        length=length,
+        perturb=False
+    )
     torch.cuda.empty_cache()
     perturbed_list = []
     discrim_loss_list = []
     loss_in_time_list = []
-    for i in range(args.num_samples):
-        perturbed, discrim_loss, loss_in_time = sample_from_hidden(model=model,
-                                                                   args=args,
+    for i in range(num_samples):
+        perturbed, discrim_loss, loss_in_time = generate_text_pplm(
+            model=model,
            context=context,
            device=device,
            sample=sample,
            perturb=True,
-            good_index=good_index,
-            classifier=classifier)
+            bow_indices=bow_indices,
+            classifier=classifier,
+            label_class=class_id,
+            loss_type=loss_type,
+            length=length,
+            grad_length=grad_length,
+            stepsize=stepsize,
+            num_iterations=num_iterations,
+            temperature=temperature,
+            gm_scale=gm_scale,
+            kl_scale=kl_scale,
+            top_k=top_k,
+            window_length=window_length,
+            horizon_length=horizon_length,
+            decay=decay,
+            gamma=gamma,
+        )
         perturbed_list.append(perturbed)
         if classifier is not None:
             discrim_loss_list.append(discrim_loss.data.cpu().numpy())
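The filter above keeps only bag words whose encoding is a single token, since the one-hot scheme scores one vocabulary entry per word. A sketch of that round trip, assuming the Hugging Face GPT2Tokenizer behind the script's TOKENIZER and hypothetical words:

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
words = ["soldier", "military", "antidisestablishmentarianism"]
bow = [tokenizer.encode(" " + word) for word in words]
single = [ids for ids in bow if len(ids) <= 1]  # multi-token words are dropped
actual_words = [(tokenizer.decode(ids).strip(), ids) for ids in single]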
@@ -475,15 +536,40 @@ def latent_perturb(model, args, context=None, sample=True, device='cuda'):
     return original, perturbed_list, discrim_loss_list, loss_in_time_list, actual_words
-def sample_from_hidden(model, args, classifier, context=None, past=None,
-                       device='cuda',
-                       sample=True, perturb=True, good_index=None):
+def generate_text_pplm(
+        model,
+        context=None,
+        past=None,
+        device="cuda",
+        sample=True,
+        perturb=True,
+        classifier=None,
+        label_class=None,
+        bow_indices=None,
+        loss_type=0,
+        length=100,
+        grad_length=10000,
+        stepsize=0.02,
+        num_iterations=3,
+        temperature=1.0,
+        gm_scale=0.9,
+        kl_scale=0.01,
+        top_k=10,
+        window_length=0,
+        horizon_length=1,
+        decay=False,
+        gamma=1.5,
+):
     output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(
         0) if context else None
+    # collect one hot vectors for bags of words
+    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
     grad_norms = None
     loss_in_time = []
-    for i in trange(args.length, ascii=True):
+    for i in trange(length, ascii=True):
         # Get past/probs for current output, except for last word
         # Note that GPT takes 2 inputs: past + current-token
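The model calls that follow rely on GPT-2 returning logits, the key/value past, and all hidden states as a tuple. A sketch of that pattern, assuming the transformers 2.x tuple-returning API this script appears to target (names and prompt are illustrative):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", output_hidden_states=True)
output = torch.tensor([TOKENIZER.encode("The potato")])

_, past, _ = model(output[:, :-1])  # cache keys/values for all but the last token
unpert_logits, unpert_past, unpert_all_hidden = model(output)
true_hidden = unpert_all_hidden[-1]  # last layer's hidden states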
@@ -497,7 +583,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
             # Piero modified model call
             _, past, _ = model(output[:, :-1])
-            original_probs, true_past, unpert_all_hidden = model(output)
+            unpert_logits, unpert_past, unpert_all_hidden = model(output)
             true_hidden = unpert_all_hidden[-1]
         else:
@@ -505,17 +591,17 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
             # true_hidden = model.hidden_states
             # Piero modified model call
-            original_probs, true_past, unpert_all_hidden = model(output)
+            unpert_logits, unpert_past, unpert_all_hidden = model(output)
             true_hidden = unpert_all_hidden[-1]
         # Modify the past if necessary
-        if i >= args.grad_length:
-            current_stepsize = args.stepsize * 0
+        if i >= grad_length:
+            current_stepsize = stepsize * 0
         else:
-            current_stepsize = args.stepsize
+            current_stepsize = stepsize
-        if not perturb or args.num_iterations == 0:
+        if not perturb or num_iterations == 0:
             perturbed_past = past
         else:
@@ -524,17 +610,26 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
             accumulated_hidden = true_hidden[:, :-1, :]
             accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
-            perturbed_past, _, grad_norms, loss_per_iter = perturb_past(past,
+            perturbed_past, _, grad_norms, loss_per_iter = perturb_past(
+                past,
                 model,
                 prev,
-                args,
-                good_index=good_index,
-                stepsize=current_stepsize,
-                original_probs=original_probs,
-                true_past=true_past,
+                unpert_past=unpert_past,
+                unpert_logits=unpert_logits,
                 accumulated_hidden=accumulated_hidden,
+                grad_norms=grad_norms,
+                stepsize=current_stepsize,
                 classifier=classifier,
-                grad_norms=grad_norms)
+                label_class=label_class,
+                one_hot_bows_vectors=one_hot_bows_vectors,
+                loss_type=loss_type,
+                num_iterations=num_iterations,
+                kl_scale=kl_scale,
+                window_length=window_length,
+                horizon_length=horizon_length,
+                decay=decay,
+                gamma=gamma,
+            )
         loss_in_time.append(loss_per_iter)
         # Piero modified model call
@@ -546,7 +641,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         if classifier is not None:
             ce_loss = torch.nn.CrossEntropyLoss()
             predicted_sentiment = classifier(torch.mean(true_hidden, dim=1))
-            label = torch.tensor([args.label_class], device='cuda',
+            label = torch.tensor([label_class], device='cuda',
                                  dtype=torch.long)
             true_discrim_loss = ce_loss(predicted_sentiment, label)
             print("true discrim loss", true_discrim_loss.data.cpu().numpy())
@@ -556,7 +651,7 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         # Piero modified model call
         # hidden = model.hidden_states  # update hidden
         # logits = model.forward_hidden(hidden)
-        logits = logits[:, -1, :] / args.temperature  # + SmallConst
+        logits = logits[:, -1, :] / temperature  # + SmallConst
         # logits = top_k_filter(logits, k=args.top_k)  # + SmallConst
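Temperature rescales the final-token logits before sampling: below 1.0 it sharpens the distribution, above 1.0 it flattens it. A two-line illustration:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.0]])
for temperature in (0.7, 1.0, 1.3):
    print(temperature, F.softmax(logits / temperature, dim=-1))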
@@ -566,22 +661,21 @@ def sample_from_hidden(model, args, classifier, context=None, past=None,
         if perturb:
             # original_probs = top_k_filter(original_probs[:, -1, :])  # + SmallConst
-            original_probs = F.softmax(original_probs[:, -1, :], dim=-1)
+            unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1)
             # likelywords = torch.topk(original_probs, k=10, dim=-1)
             # print(TOKENIZER.decode(likelywords[1].tolist()[0]))
-            gm_scale = args.gm_scale
             log_probs = ((log_probs ** gm_scale) * (
-                original_probs ** (1 - gm_scale)))  # + SmallConst
+                unpert_logits ** (1 - gm_scale)))  # + SmallConst
-            log_probs = top_k_filter(log_probs, k=args.top_k,
+            log_probs = top_k_filter(log_probs, k=top_k,
                                      probs=True)  # + SmallConst
             if torch.sum(log_probs) <= 1:
                 log_probs = log_probs / torch.sum(log_probs)
         else:
-            logits = top_k_filter(logits, k=args.top_k)  # + SmallConst
+            logits = top_k_filter(logits, k=top_k)  # + SmallConst
             log_probs = F.softmax(logits, dim=-1)
         if sample:
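This block fuses the perturbed and unperturbed distributions as a weighted geometric mean and then renormalizes; note that the new code reuses the name unpert_logits for what is, after the softmax, a probability distribution. A standalone sketch of the fusion:

import torch
import torch.nn.functional as F

gm_scale = 0.9
pert_probs = F.softmax(torch.randn(1, 50), dim=-1)    # perturbed distribution
unpert_probs = F.softmax(torch.randn(1, 50), dim=-1)  # unperturbed distribution

# gm_scale near 1.0 trusts the perturbed distribution, near 0.0 the original
fused = (pert_probs ** gm_scale) * (unpert_probs ** (1.0 - gm_scale))
if torch.sum(fused) <= 1:
    fused = fused / torch.sum(fused)  # renormalize; the geometric mean sums to <= 1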
@@ -673,16 +767,16 @@ def run_model():
     collect_gen = dict()
     current_index = 0
-    for out in seq:
+    for tokenized_cond_text in seq:
-        text = TOKENIZER.decode(out)
+        text = TOKENIZER.decode(tokenized_cond_text)
         print("=" * 40 + " Prefix of sentence " + "=" * 40)
         print(text)
         print("=" * 80)
-        out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = latent_perturb(
-            model=model, args=args, context=out,
-            device=device)
+        out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = full_text_generation(
+            model=model, context=tokenized_cond_text, device=device, **vars(args)
+        )
         text_whole = TOKENIZER.decode(out1.tolist()[0])
@@ -712,20 +806,20 @@ def run_model():
             import colorama
             text_whole = ''
-            for out in output_tokens:
-                if out in keyword_tokens:
+            for tokenized_cond_text in output_tokens:
+                if tokenized_cond_text in keyword_tokens:
                     text_whole += '%s%s%s' % (
-                        colorama.Fore.GREEN, TOKENIZER.decode([out]),
+                        colorama.Fore.GREEN, TOKENIZER.decode([tokenized_cond_text]),
                         colorama.Style.RESET_ALL)
                 else:
-                    text_whole += TOKENIZER.decode([out])
+                    text_whole += TOKENIZER.decode([tokenized_cond_text])
         else:
             text_whole = TOKENIZER.decode(out_perturb.tolist()[0])
         print(text_whole)
         print("=" * 80)
-        collect_gen[current_index] = [out, out_perturb, out1]
+        collect_gen[current_index] = [tokenized_cond_text, out_perturb, out1]
         current_index = current_index + 1
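Because full_text_generation now accepts explicit keyword arguments (plus **kwargs for anything extra), run_model() can forward the whole argparse namespace with **vars(args). A minimal sketch of the pattern with a hypothetical stub:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--stepsize", type=float, default=0.02)
parser.add_argument("--kl_scale", type=float, default=0.01)
args = parser.parse_args([])

def full_text_generation_stub(stepsize=0.02, kl_scale=0.01, **kwargs):
    # **kwargs swallows any namespace entries the function does not name
    return stepsize, kl_scale

# vars(args) turns the namespace into a dict; ** unpacks it into keyword arguments
print(full_text_generation_stub(**vars(args)))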