Commit ffc29354 authored by piero's avatar piero Committed by Julien Chaumond
Browse files

Fix to make unconditioned generation work. Identical output as before.

parent 9f693a0c
......@@ -481,6 +481,7 @@ def generate_text_pplm(
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
grad_norms = None
last = None
unpert_discrim_loss = 0
loss_in_time = []
for i in trange(length, ascii=True):
......@@ -491,6 +492,7 @@ def generate_text_pplm(
# run model forward to obtain unperturbed
if past is None and output_so_far is not None:
last = output_so_far[:, -1:]
if output_so_far.shape[1] > 1:
_, past, _ = model(output_so_far[:, :-1])
unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
......@@ -510,6 +512,7 @@ def generate_text_pplm(
accumulated_hidden = unpert_last_hidden[:, :-1, :]
accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
if past is not None:
pert_past, _, grad_norms, loss_this_iter = perturb_past(
past,
model,
......@@ -531,6 +534,8 @@ def generate_text_pplm(
gamma=gamma,
)
loss_in_time.append(loss_this_iter)
else:
pert_past = past
pert_logits, past, pert_all_hidden = model(last, past=pert_past)
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment