"docs/vscode:/vscode.git/clone" did not exist on "9ad830596d91fee96402e9af4b32a167e9f349dd"
Unverified Commit 94980b52 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #404 from CatalinVoss/fix_lm_loss

Fix Language Modeling Loss
parents 9ca25ce8 01520d54
......@@ -617,8 +617,14 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
lm_logits = self.lm_head(hidden_states)
if lm_labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[:, :-1].contiguous()
shift_labels = lm_labels[:, 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
return loss
return lm_logits, presents
......@@ -690,8 +696,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = []
if lm_labels is not None:
shift_logits = lm_logits[:, :-1].contiguous()
shift_labels = lm_labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(ignore_index=-1)
losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
losses.append(loss_fct(shift_logits.view(-1,
shift_logits.size(-1)), shift_labels.view(-1)))
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
......
......@@ -716,8 +716,14 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
lm_logits = self.lm_head(hidden_states)
if lm_labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[:, :-1].contiguous()
shift_labels = lm_labels[:, 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
return loss
return lm_logits
......@@ -803,8 +809,11 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = []
if lm_labels is not None:
shift_logits = lm_logits[:, :-1].contiguous()
shift_labels = lm_labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(ignore_index=-1)
losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
losses.append(loss_fct(shift_logits.view(-1,
shift_logits.size(-1)), shift_labels.view(-1)))
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment