Commit a59fdd16 authored by Piero Molino, committed by Julien Chaumond

generate_text_pplm now works with batch_size > 1

parent 893d0d64
@@ -231,7 +231,8 @@ def perturb_past(
             prediction = classifier(new_accumulated_hidden /
                                     (curr_length + 1 + horizon_length))
-            label = torch.tensor([class_label], device=device,
+            label = torch.tensor(prediction.shape[0] * [class_label],
+                                 device=device,
                                  dtype=torch.long)
             discrim_loss = ce_loss(prediction, label)
             print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
@@ -508,11 +509,12 @@ def generate_text_pplm(
     gm_scale=0.9,
     kl_scale=0.01,
 ):
-    output_so_far = (
-        torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
-        if context
-        else None
-    )
+    output_so_far = None
+    if context:
+        context_t = torch.tensor(context, device=device, dtype=torch.long)
+        while len(context_t.shape) < 2:
+            context_t = context_t.unsqueeze(0)
+        output_so_far = context_t

     # collect one hot vectors for bags of words
     one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
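
This hunk replaces the single `unsqueeze(0)` with a loop, so both a flat list of token ids (one sequence) and an already-batched list of lists end up as a 2-D `(batch_size, seq_len)` tensor. A minimal sketch of the same pattern in isolation; `to_batched_context` is a hypothetical helper used only for illustration:

    import torch

    def to_batched_context(context, device="cpu"):
        # hypothetical helper: returns None for an empty context, otherwise a
        # 2-D LongTensor of shape (batch_size, seq_len)
        output_so_far = None
        if context:
            context_t = torch.tensor(context, device=device, dtype=torch.long)
            # add leading dimensions until the tensor is batched
            while len(context_t.shape) < 2:
                context_t = context_t.unsqueeze(0)
            output_so_far = context_t
        return output_so_far

    print(to_batched_context([10, 11, 12]).shape)                  # torch.Size([1, 3])
    print(to_batched_context([[10, 11, 12], [13, 14, 15]]).shape)  # torch.Size([2, 3])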