"...git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "502d5811505c806f95b92ae777388e9e6d3532fd"
Commit 9f693a0c authored by piero, committed by Julien Chaumond

Cleaned generate_text_pplm. Identical output as before.

parent 61a12f79
@@ -471,59 +471,49 @@ def generate_text_pplm(
     decay=False,
     gamma=1.5,
 ):
-    output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(
-        0) if context else None
+    output_so_far = (
+        torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
+        if context
+        else None
+    )

     # collect one hot vectors for bags of words
     one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)

     grad_norms = None
     unpert_discrim_loss = 0
     loss_in_time = []
     for i in trange(length, ascii=True):
-        # Get past/probs for current output, except for last word
-        # Note that GPT takes 2 inputs: past + current-token
-        # Therefore, use everything from before current i/p token to generate relevant past
-        if past is None and output is not None:
-            prev = output[:, -1:]
-            # _, past = model(output[:, :-1])
-            # original_probs, true_past = model(output)
-            # true_hidden = model.hidden_states
-            # Piero modified model call
-            _, past, _ = model(output[:, :-1])
-
-            unpert_logits, unpert_past, unpert_all_hidden = model(output)
-            true_hidden = unpert_all_hidden[-1]
-        else:
-            # original_probs, true_past = model(output)
-            # true_hidden = model.hidden_states
-            # Note that GPT takes 2 inputs: past + current_token
-            # Piero modified model call
-            unpert_logits, unpert_past, unpert_all_hidden = model(output)
-            true_hidden = unpert_all_hidden[-1]
+        # run model forward to obtain the unperturbed logits and hidden states
+        if past is None and output_so_far is not None:
+            last = output_so_far[:, -1:]
+            _, past, _ = model(output_so_far[:, :-1])
+
+        unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
+        unpert_last_hidden = unpert_all_hidden[-1]

         # check if we are above grad max length
         if i >= grad_length:
             current_stepsize = stepsize * 0
         else:
             current_stepsize = stepsize

         # modify the past if necessary
         if not perturb or num_iterations == 0:
-            perturbed_past = past
+            pert_past = past
         else:
-            # Piero modified model call
-            # accumulated_hidden = model.hidden_states[:, :-1, :]
-            accumulated_hidden = true_hidden[:, :-1, :]
+            accumulated_hidden = unpert_last_hidden[:, :-1, :]
             accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

-            perturbed_past, _, grad_norms, loss_per_iter = perturb_past(
+            pert_past, _, grad_norms, loss_this_iter = perturb_past(
                 past,
                 model,
-                prev,
+                last,
                 unpert_past=unpert_past,
                 unpert_logits=unpert_logits,
                 accumulated_hidden=accumulated_hidden,
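
Note: the refactored block above relies on GPT-2's incremental-decoding interface: the cache ("past") is built from every context token except the last, and only the last token is later fed through the (possibly perturbed) cache. A minimal sketch of that split, written against the current transformers API rather than the older tuple-returning forward() this diff targets; the prompt string and variable names are illustrative only:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

output_so_far = torch.tensor([tokenizer.encode("The potato")], dtype=torch.long)

with torch.no_grad():
    # build the cache from everything except the last token ...
    out = model(output_so_far[:, :-1], use_cache=True)
    past = out.past_key_values

    # ... then step with only the last token plus the cache; PPLM perturbs
    # `past` between these two calls
    last = output_so_far[:, -1:]
    next_logits = model(last, past_key_values=past).logits[:, -1, :]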
@@ -540,68 +530,59 @@ def generate_text_pplm(
                 decay=decay,
                 gamma=gamma,
             )
-            loss_in_time.append(loss_per_iter)
+            loss_in_time.append(loss_this_iter)

-        # Piero modified model call
-        logits, past, pert_all_hidden = model(prev, past=perturbed_past)
-        # test_logits = F.softmax(test_logits[:, -1, :], dim=-1)
-        # likelywords = torch.topk(test_logits, k=10, dim=-1)
-        # print(TOKENIZER.decode(likelywords[1].tolist()[0]))
+        pert_logits, past, pert_all_hidden = model(last, past=pert_past)
+        pert_logits = pert_logits[:, -1, :] / temperature  # + SMALL_CONST
+        pert_probs = F.softmax(pert_logits, dim=-1)

         if classifier is not None:
             ce_loss = torch.nn.CrossEntropyLoss()
-            predicted_sentiment = classifier(torch.mean(true_hidden, dim=1))
+            prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
             label = torch.tensor([label_class], device=device,
                                  dtype=torch.long)
-            true_discrim_loss = ce_loss(predicted_sentiment, label)
-            print("true discrim loss", true_discrim_loss.data.cpu().numpy())
+            unpert_discrim_loss = ce_loss(prediction, label)
+            print(
+                "unperturbed discrim loss",
+                unpert_discrim_loss.data.cpu().numpy()
+            )
         else:
-            true_discrim_loss = 0
-
-        # Piero modified model call
-        # hidden = model.hidden_states  # update hidden
-        # logits = model.forward_hidden(hidden)
-        logits = logits[:, -1, :] / temperature  # + SMALL_CONST
-        # logits = top_k_filter(logits, k=args.top_k)  # + SMALL_CONST
-        log_probs = F.softmax(logits, dim=-1)
+            unpert_discrim_loss = 0

         # Fuse the modified model and original model
         if perturb:
-            # original_probs = top_k_filter(original_probs[:, -1, :])  # + SMALL_CONST
-            unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1)
-            # likelywords = torch.topk(original_probs, k=10, dim=-1)
-            # print(TOKENIZER.decode(likelywords[1].tolist()[0]))
+            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)

-            log_probs = ((log_probs ** gm_scale) * (
-                unpert_logits ** (1 - gm_scale)))  # + SMALL_CONST
-            log_probs = top_k_filter(log_probs, k=top_k,
-                                     probs=True)  # + SMALL_CONST
+            pert_probs = ((pert_probs ** gm_scale) * (
+                unpert_probs ** (1 - gm_scale)))  # + SMALL_CONST
+            pert_probs = top_k_filter(pert_probs, k=top_k,
+                                      probs=True)  # + SMALL_CONST

-            if torch.sum(log_probs) <= 1:
-                log_probs = log_probs / torch.sum(log_probs)
+            # rescale
+            if torch.sum(pert_probs) <= 1:
+                pert_probs = pert_probs / torch.sum(pert_probs)
         else:
-            logits = top_k_filter(logits, k=top_k)  # + SMALL_CONST
-            log_probs = F.softmax(logits, dim=-1)
+            pert_logits = top_k_filter(pert_logits, k=top_k)  # + SMALL_CONST
+            pert_probs = F.softmax(pert_logits, dim=-1)

         # sample or greedy
         if sample:
-            # likelywords = torch.topk(log_probs, k=args.top_k, dim=-1)
-            # print(TOKENIZER.decode(likelywords[1].tolist()[0]))
-            # print(likelywords[0].tolist())
-            prev = torch.multinomial(log_probs, num_samples=1)
+            last = torch.multinomial(pert_probs, num_samples=1)
         else:
-            _, prev = torch.topk(log_probs, k=1, dim=-1)
-
-        # if perturb:
-        #     prev = future
-        output = prev if output is None else torch.cat((output, prev),
-                                                       dim=1)  # update output
-        print(TOKENIZER.decode(output.tolist()[0]))
-
-    return output, true_discrim_loss, loss_in_time
+            _, last = torch.topk(pert_probs, k=1, dim=-1)

+        # update context/output_so_far appending the new token
+        output_so_far = (
+            last if output_so_far is None
+            else torch.cat((output_so_far, last), dim=1)
+        )
+        print(TOKENIZER.decode(output_so_far.tolist()[0]))

+    return output_so_far, unpert_discrim_loss, loss_in_time


 def run_model():
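
Note: the `if perturb:` branch above is PPLM's post-norm geometric-mean fusion: sharpen with the perturbed distribution, anchor with the unperturbed one, filter to the top k, then renormalize. A self-contained sketch of that step; the in-file top_k_filter helper is approximated here with torch.topk, and the gm_scale/top_k defaults are illustrative:

import torch
import torch.nn.functional as F

def fuse_distributions(pert_logits, unpert_logits, gm_scale=0.9, top_k=10):
    pert_probs = F.softmax(pert_logits, dim=-1)
    unpert_probs = F.softmax(unpert_logits, dim=-1)

    # geometric interpolation: gm_scale near 1 trusts the perturbed model,
    # gm_scale near 0 falls back to the unperturbed model
    fused = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))

    # zero out everything below the k-th largest probability
    # (stand-in for this file's top_k_filter(..., probs=True))
    kth = torch.topk(fused, k=top_k, dim=-1).values[..., -1:]
    fused = torch.where(fused < kth, torch.zeros_like(fused), fused)

    # the product of probabilities no longer sums to 1, so rescale
    return fused / torch.sum(fused, dim=-1, keepdim=True)

# fuse_distributions(torch.randn(1, 50257), torch.randn(1, 50257)) sums to 1 per row

The diff only rescales when the fused mass has dropped to 1 or below, but by Holder's inequality the elementwise geometric mean of two distributions never sums to more than 1 (and top-k zeroing only removes mass), so that guard always fires and the unconditional renormalization used in this sketch is equivalent.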