"vscode:/vscode.git/clone" did not exist on "8bcb37bfb80d77e06001f989ad982c9961a69c31"
Commit 0dd796e3 authored by Catalin Voss's avatar Catalin Voss
Browse files

Also fix loss function issue with the double head models

parent 472857c4
......@@ -698,8 +698,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = []
if lm_labels is not None:
shift_logits = lm_logits[:, :-1]
shift_labels = lm_labels[:, 1:]
loss_fct = CrossEntropyLoss(ignore_index=-1)
losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
losses.append(loss_fct(shift_logits.view(-1,
shift_logits.size(-1)), shift_labels.view(-1)))
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
......
......@@ -811,8 +811,11 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = []
if lm_labels is not None:
shift_logits = lm_logits[:, :-1]
shift_labels = lm_labels[:, 1:]
loss_fct = CrossEntropyLoss(ignore_index=-1)
losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
losses.append(loss_fct(shift_logits.view(-1,
shift_logits.size(-1)), shift_labels.view(-1)))
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment