Commit ffc29354 authored by piero's avatar piero Committed by Julien Chaumond
Browse files

Fix for making unditioned generation work. Identical output as before.

parent 9f693a0c
......@@ -481,6 +481,7 @@ def generate_text_pplm(
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
grad_norms = None
last = None
unpert_discrim_loss = 0
loss_in_time = []
for i in trange(length, ascii=True):
......@@ -491,7 +492,8 @@ def generate_text_pplm(
# run model forward to obtain unperturbed
if past is None and output_so_far is not None:
last = output_so_far[:, -1:]
_, past, _ = model(output_so_far[:, :-1])
if output_so_far.shape[1] > 1:
_, past, _ = model(output_so_far[:, :-1])
unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
unpert_last_hidden = unpert_all_hidden[-1]
......@@ -510,27 +512,30 @@ def generate_text_pplm(
accumulated_hidden = unpert_last_hidden[:, :-1, :]
accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
pert_past, _, grad_norms, loss_this_iter = perturb_past(
past,
model,
last,
unpert_past=unpert_past,
unpert_logits=unpert_logits,
accumulated_hidden=accumulated_hidden,
grad_norms=grad_norms,
stepsize=current_stepsize,
classifier=classifier,
label_class=label_class,
one_hot_bows_vectors=one_hot_bows_vectors,
loss_type=loss_type,
num_iterations=num_iterations,
kl_scale=kl_scale,
window_length=window_length,
horizon_length=horizon_length,
decay=decay,
gamma=gamma,
)
loss_in_time.append(loss_this_iter)
if past is not None:
pert_past, _, grad_norms, loss_this_iter = perturb_past(
past,
model,
last,
unpert_past=unpert_past,
unpert_logits=unpert_logits,
accumulated_hidden=accumulated_hidden,
grad_norms=grad_norms,
stepsize=current_stepsize,
classifier=classifier,
label_class=label_class,
one_hot_bows_vectors=one_hot_bows_vectors,
loss_type=loss_type,
num_iterations=num_iterations,
kl_scale=kl_scale,
window_length=window_length,
horizon_length=horizon_length,
decay=decay,
gamma=gamma,
)
loss_in_time.append(loss_this_iter)
else:
pert_past = past
pert_logits, past, pert_all_hidden = model(last, past=pert_past)
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment