Commit 472857c4 authored by Catalin Voss

Fix typo causing a syntax error (sorry, copy/paste from my repo)

parent 2e6f5ffb
@@ -625,7 +625,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             # in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
             # We just flatten the tokens out this way.
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
             return loss
         return lm_logits, presents
...
@@ -724,7 +724,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
             # in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
             # We just flatten the tokens out this way.
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
             return loss
         return lm_logits
...
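
For context, both hunks apply the same one-character fix: the original line was missing the comma between the two arguments to loss_fct, which made the call a syntax error. A minimal standalone sketch of what the corrected call computes is below. The tensor shapes and the shift construction are illustrative placeholders (they mirror the surrounding model code, which is not shown in this diff): logits at position t are scored against the label at position t+1, then everything is flattened so CrossEntropyLoss sees [batch*(seq-1), vocab] predictions against [batch*(seq-1)] targets, with ignore_index=-1 masking out padded positions.

import torch
from torch.nn import CrossEntropyLoss

# Illustrative shapes only; in the model these come from input_ids/lm_labels.
batch, seq_len, vocab = 2, 8, 100
lm_logits = torch.randn(batch, seq_len, vocab)
lm_labels = torch.randint(0, vocab, (batch, seq_len))

# Shift so that tokens < n predict n: drop the last logit and the first label.
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = lm_labels[..., 1:].contiguous()

# The corrected call from this commit: flatten tokens, then average the
# cross-entropy over all positions; labels equal to -1 are ignored.
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1))
print(loss.item())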