Commit 01520d54 authored by Catalin Voss

Remove my unhelpful comments :)

parent fda2f623
@@ -621,9 +621,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             shift_logits = lm_logits[:, :-1].contiguous()
             shift_labels = lm_labels[:, 1:].contiguous()
-            # In tensorflow, it's [batch, d_0, d_1, ..., d_{r-1}, num_classes]
-            # in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
-            # We just flatten the tokens out this way.
+            # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
@@ -720,9 +720,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
             shift_logits = lm_logits[:, :-1].contiguous()
             shift_labels = lm_labels[:, 1:].contiguous()
-            # In tensorflow, it's [batch, d_0, d_1, ..., d_{r-1}, num_classes]
-            # in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
-            # We just flatten the tokens out this way.
+            # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
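The change is identical in both heads: the logits are shifted left and the labels right so that position t predicts token t+1, then both are flattened because PyTorch's CrossEntropyLoss expects class scores as the second dimension of an (N, num_classes) input, unlike TensorFlow's trailing-class layout. A minimal standalone sketch of that loss computation (the batch, sequence, and vocabulary sizes here are illustrative, not from the source):

# Sketch of the shifted language-modeling loss used above.
# Shapes (batch=2, seq_len=5, vocab=10) are assumptions for illustration.
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 2, 5, 10
lm_logits = torch.randn(batch, seq_len, vocab)          # model output scores
lm_labels = torch.randint(0, vocab, (batch, seq_len))   # target token ids

# Predict token t+1 from position t: drop the last logit and the first label.
shift_logits = lm_logits[:, :-1].contiguous()   # (batch, seq_len-1, vocab)
shift_labels = lm_labels[:, 1:].contiguous()    # (batch, seq_len-1)

# Flatten batch and time together so CrossEntropyLoss sees (N, num_classes)
# against (N,); label -1 marks positions to ignore (e.g. padding).
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1))
print(loss.item())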