Commit 7797d21b authored by Catalin Voss

Fix GPT2 language modeling loss computation

parent f3e54048
@@ -617,8 +617,16 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[:, :-1]
+            shift_labels = lm_labels[:, 1:]
+            # In tensorflow, it's [batch, d_0, d_1, ..., d_{r-1}, num_classes]
+            # in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
+            # We just flatten the tokens out this way.
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                            shift_labels.view(-1))
             return loss
         return lm_logits, presents
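For reference, below is a minimal, self-contained sketch of the shifted next-token loss that this diff introduces. The toy tensor shapes (batch 2, sequence length 5, vocabulary 10) are assumptions for illustration only and are not taken from the repository; reshape is used instead of view because the shifted slices are not contiguous.

import torch
from torch.nn import CrossEntropyLoss

# Toy shapes (assumed for illustration): batch of 2 sequences, 5 tokens, vocab of 10.
lm_logits = torch.randn(2, 5, 10)
lm_labels = torch.randint(0, 10, (2, 5))

# Shift so that the logits at position i are scored against the label at position i + 1
# (tokens < n predict n), as in the change above.
shift_logits = lm_logits[:, :-1]
shift_labels = lm_labels[:, 1:]

# CrossEntropyLoss expects [N, num_classes] logits and [N] targets, so flatten the
# batch and sequence dimensions together. reshape (rather than view) handles the
# non-contiguous shifted slices.
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1))
print(loss.item())

Without the shift (the removed line), the loss scores each position against its own token rather than the next one, which is what this commit corrects.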