"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b41cc0b86a93c863afc3eddf7af736807218dbe7"
Commit ffc29354 authored by piero, committed by Julien Chaumond

Fix for making unconditioned generation work. Identical output as before.

parent 9f693a0c
@@ -481,6 +481,7 @@ def generate_text_pplm(
     one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
 
     grad_norms = None
+    last = None
     unpert_discrim_loss = 0
     loss_in_time = []
     for i in trange(length, ascii=True):
@@ -491,6 +492,7 @@ def generate_text_pplm(
         # run model forward to obtain unperturbed
         if past is None and output_so_far is not None:
             last = output_so_far[:, -1:]
+            if output_so_far.shape[1] > 1:
                 _, past, _ = model(output_so_far[:, :-1])
 
         unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
@@ -510,6 +512,7 @@ def generate_text_pplm(
         accumulated_hidden = unpert_last_hidden[:, :-1, :]
         accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
 
+        if past is not None:
             pert_past, _, grad_norms, loss_this_iter = perturb_past(
                 past,
                 model,
@@ -531,6 +534,8 @@ def generate_text_pplm(
                 gamma=gamma,
             )
             loss_in_time.append(loss_this_iter)
+        else:
+            pert_past = past
 
         pert_logits, past, pert_all_hidden = model(last, past=pert_past)
         pert_logits = pert_logits[:, -1, :] / temperature  # + SMALL_CONST
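The guards matter because unconditioned generation starts output_so_far from a single start token: there is no prefix to build past from, so the forward pass over output_so_far[:, :-1] and the perturb_past call both have to be skipped on the first iteration. The following is a minimal sketch of that control flow, not the example script itself: it assumes the current transformers GPT-2 API (outputs with .logits and .past_key_values instead of the tuples in the diff above), uses greedy decoding in place of PPLM's perturbed sampling, and fake_perturb_past is a hypothetical stand-in for perturb_past.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()


def fake_perturb_past(past):
    # Hypothetical stand-in for PPLM's perturb_past, which takes gradient
    # steps on the key/value cache; here the cache is returned unchanged.
    return past


# Unconditioned case: output_so_far is just the end-of-text token, shape [1, 1].
output_so_far = torch.tensor([[tokenizer.bos_token_id]])

past = None
last = None
with torch.no_grad():
    for _ in range(20):
        if past is None and output_so_far is not None:
            last = output_so_far[:, -1:]
            # Guard added by the commit: only build the cache when there is
            # more than the single start token to condition on.
            if output_so_far.shape[1] > 1:
                past = model(output_so_far[:, :-1]).past_key_values

        # Second guard added by the commit: skip the perturbation step while
        # past is still None (the first unconditioned iteration).
        if past is not None:
            pert_past = fake_perturb_past(past)
        else:
            pert_past = past

        out = model(last, past_key_values=pert_past)
        past = out.past_key_values
        logits = out.logits[:, -1, :]

        # Greedy pick instead of PPLM's fused perturbed/unperturbed sampling.
        last = torch.argmax(logits, dim=-1, keepdim=True)
        output_so_far = torch.cat([output_so_far, last], dim=1)

print(tokenizer.decode(output_so_far[0]))

On the first pass, last is the lone start token and pert_past is None, so the model call itself creates the cache; from the second token on the loop runs exactly as in the conditioned case, which is why the commit message can say the output is identical to before.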