Commit 83b1e6ac authored by Rosanne Liu, committed by Julien Chaumond

fix the loss backward issue

(cherry picked from commit 566468cc984c6ec7e10dfc62b5b4191781a99cd2)
parent 572c24cf
@@ -36,6 +36,7 @@ from tqdm import trange
 from transformers import GPT2Tokenizer
 from transformers.file_utils import cached_path
 from transformers.modeling_gpt2 import GPT2LMHeadModel
+from IPython import embed
 PPLM_BOW = 1
 PPLM_DISCRIM = 2
@@ -246,8 +247,8 @@ def perturb_past(
 inputs_embeds=inputs_embeds
 )
 # get expected hidden states
-unpert_hidden = curr_all_hidden[1]
-accumulated_hidden += torch.sum(unpert_hidden, dim=1)
+unpert_hidden = curr_all_hidden[-1]
+accumulated_hidden += torch.sum(unpert_hidden, dim=1).detach()
 prediction = classifier(
 accumulated_hidden / (curr_length + 1 + horizon_length)
@@ -257,7 +258,7 @@ def perturb_past(
 discrim_loss += ce_loss(prediction, label)
 print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
-if kl_scale > 0.0:
+if kl_scale >= 0.0:
 unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
 unpert_probs = (
 unpert_probs + SMALL_CONST *
@@ -270,7 +271,7 @@ def perturb_past(
 torch.FloatTensor
 ).cuda().detach()
 corrected_probs = probs + correction.detach()
-kl_loss += kl_scale * (
+kl_loss = kl_scale * (
 (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
 )
 print(' kl_loss', (kl_loss).data.cpu().numpy())
@@ -280,7 +281,7 @@ def perturb_past(
 print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
 # compute gradients
-loss.backward(retain_graph=True)
+loss.backward()
 # calculate gradient norms
 if grad_norms is not None and loss_type == PPLM_BOW:
...
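
The pattern behind the fix is easier to see outside the diff: quantities carried across gradient-update steps are detached, the KL term is recomputed each step with `=` rather than accumulated with `+=`, and `loss.backward()` therefore no longer needs `retain_graph=True`. Below is a minimal, self-contained sketch of that pattern; it is not the repository's actual run_pplm.py code, and the toy forward pass, tensor shapes, and the `perturbation` variable are invented for illustration.

import torch
import torch.nn.functional as F

SMALL_CONST = 1e-15
kl_scale = 0.01
hidden_dim, vocab_size = 8, 50  # hypothetical sizes

# stand-ins for the unperturbed model outputs
unpert_logits = torch.randn(1, 1, vocab_size)
proj = torch.randn(hidden_dim, vocab_size)

accumulated_hidden = torch.zeros(1, hidden_dim)
perturbation = torch.zeros(1, hidden_dim, requires_grad=True)

for _ in range(3):
    # pretend forward pass that depends on the current perturbation
    curr_hidden = torch.randn(1, 5, hidden_dim) + perturbation.unsqueeze(1)
    logits = curr_hidden.mean(dim=1) @ proj

    # detach before accumulating, so earlier iterations' graphs can be freed
    accumulated_hidden = accumulated_hidden + torch.sum(curr_hidden, dim=1).detach()

    probs = F.softmax(logits, dim=-1)
    unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1) + SMALL_CONST
    corrected_probs = probs + SMALL_CONST

    # recompute the KL term each step (`=`) instead of accumulating it (`+=`)
    kl_loss = kl_scale * (
        (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
    )

    loss = kl_loss  # the real loss also includes the BoW / discriminator terms
    # nothing from previous iterations is retained, so plain backward() suffices
    loss.backward()
    print("kl_loss:", kl_loss.item(), "grad norm:", perturbation.grad.norm().item())
    perturbation.grad.zero_()

With the `+=` accumulation and undetached hidden states, each `backward()` would have had to traverse the graphs of all previous iterations, which is why the old code needed `retain_graph=True`; detaching and recomputing per step removes that requirement.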