Unverified Commit 52708d26 authored by chutaklee, committed by GitHub

Fix PPLM (#8779)



* Fix pplm

* fix style

* make style
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 8f07f5c4
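
Context for this fix (editor's note): newer transformers releases return model outputs as dict-like ModelOutput objects instead of plain tuples, and GPT-2's past argument was renamed past_key_values; the PPLM example still used the old tuple unpacking, which this commit updates. A minimal sketch of the new-style call the updated script relies on, assuming a stock GPT-2 checkpoint (the "gpt2" name and the prompt below are placeholders):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# The PPLM script loads the model with output_hidden_states=True;
# without it, lm_output["hidden_states"] would not be populated.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", output_hidden_states=True)
model.eval()

input_ids = tokenizer("PPLM steers GPT-2", return_tensors="pt").input_ids
with torch.no_grad():
    lm_output = model(input_ids)

# Outputs are indexed by key rather than by tuple position.
logits = lm_output["logits"]                  # (batch, seq_len, vocab_size)
past = lm_output["past_key_values"]           # attention key/value cache
last_hidden = lm_output["hidden_states"][-1]  # final layer hidden states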
run_pplm.py
@@ -154,7 +154,8 @@ def perturb_past(
         # Compute hidden using perturbed past
         perturbed_past = list(map(add, past, curr_perturbation))
         _, _, _, curr_length, _ = curr_perturbation[0].shape
-        all_logits, _, all_hidden = model(last, past=perturbed_past)
+        lm_output = model(last, past_key_values=perturbed_past)
+        all_logits, all_hidden = lm_output["logits"], lm_output["hidden_states"]
         hidden = all_hidden[-1]
         new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
         # TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
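
Reviewer note on the hunk above: only the newest token (last) is fed forward together with the perturbed cache, which is the usual incremental-decoding pattern. A short sketch of that pattern, assuming the same dict-style outputs and the model/input from the earlier example:

# One full pass over the prompt builds the key/value cache.
out = model(input_ids)
past = out["past_key_values"]
last = out["logits"][:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick, for illustration only

# Later steps feed just the newest token plus the cache; the model attends
# over the cached keys/values instead of re-encoding the whole prefix.
out = model(last, past_key_values=past)
past = out["past_key_values"]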
@@ -179,7 +180,8 @@ def perturb_past(
             wte = model.resize_token_embeddings()
             for _ in range(horizon_length):
                 inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
-                _, curr_unpert_past, curr_all_hidden = model(past=curr_unpert_past, inputs_embeds=inputs_embeds)
+                lm_output = model(past_key_values=curr_unpert_past, inputs_embeds=inputs_embeds)
+                curr_unpert_past, curr_all_hidden = lm_output["past_key_values"], lm_output["hidden_states"]
                 curr_hidden = curr_all_hidden[-1]
                 new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)

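
Side note: the hunk above calls the model with inputs_embeds instead of token ids, because curr_probs is a soft distribution over the vocabulary rather than a discrete token. A hedged sketch of that soft-embedding trick; get_input_embeddings() is equivalent to the argument-less resize_token_embeddings() call the script uses to fetch the embedding matrix:

import torch

wte = model.get_input_embeddings()                    # GPT-2 token embedding table (vocab, n_embd)
# Illustrative distribution; in the script it comes from softmaxed logits.
curr_probs = torch.softmax(torch.randn(1, 1, wte.num_embeddings), dim=-1)
inputs_embeds = torch.matmul(curr_probs, wte.weight)  # probability-weighted embedding, (1, 1, n_embd)

out = model(inputs_embeds=inputs_embeds)              # token ids and inputs_embeds are mutually exclusive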
@@ -462,9 +464,14 @@ def generate_text_pplm(
         if past is None and output_so_far is not None:
             last = output_so_far[:, -1:]
             if output_so_far.shape[1] > 1:
-                _, past, _ = model(output_so_far[:, :-1])
+                past = model(output_so_far[:, :-1])["past_key_values"]

-        unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
+        lm_output = model(output_so_far)
+        unpert_logits, unpert_past, unpert_all_hidden = (
+            lm_output["logits"],
+            lm_output["past_key_values"],
+            lm_output["hidden_states"],
+        )
         unpert_last_hidden = unpert_all_hidden[-1]

         # check if we are above grad max length
@@ -507,7 +514,11 @@ def generate_text_pplm(
             else:
                 pert_past = past

-        pert_logits, past, pert_all_hidden = model(last, past=pert_past)
+        lm_output = model(last, past_key_values=pert_past)
+        pert_logits, past = (
+            lm_output["logits"],
+            lm_output["past_key_values"],
+        )
         pert_logits = pert_logits[:, -1, :] / temperature  # + SMALL_CONST

         for token_idx in set(output_so_far[0].tolist()):
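
The pert_logits[:, -1, :] / temperature line above is plain temperature scaling before sampling, and the loop that follows it applies a repetition penalty to tokens already generated. A tiny illustration of what the scaling does (values are made up):

import torch

logits = torch.tensor([2.0, 1.0, 0.5])
for temperature in (0.7, 1.0, 1.5):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(temperature, probs.tolist())  # temperature < 1 sharpens, > 1 flattens the distribution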
run_pplm_discrim_train.py
@@ -64,7 +64,7 @@ class Discriminator(torch.nn.Module):
     def avg_representation(self, x):
         mask = x.ne(0).unsqueeze(2).repeat(1, 1, self.embed_size).float().to(self.device).detach()
-        hidden, _ = self.encoder.transformer(x)
+        hidden = self.encoder.transformer(x)["last_hidden_state"]
         masked_hidden = hidden * mask
         avg_hidden = torch.sum(masked_hidden, dim=1) / (torch.sum(mask, dim=1).detach() + EPSILON)
         return avg_hidden

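
For readers skimming the second file: avg_representation mean-pools the encoder's last_hidden_state while masking out padding (token id 0), so pad positions do not dilute the average that the discriminator head sees. A standalone sketch of the same pooling with toy tensors (EPSILON matches the constant defined in the script):

import torch

EPSILON = 1e-10
hidden = torch.randn(2, 5, 8)        # (batch, seq_len, embed_size), stand-in for last_hidden_state
x = torch.tensor([[5, 9, 3, 0, 0],   # toy token ids; 0 is the pad id in this script
                  [7, 2, 0, 0, 0]])

mask = x.ne(0).unsqueeze(2).float()  # (batch, seq_len, 1), broadcasts across the embedding dim
avg_hidden = (hidden * mask).sum(dim=1) / (mask.sum(dim=1) + EPSILON)  # (batch, embed_size)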