Commit 5938f31f authored by Catalin Voss's avatar Catalin Voss
Browse files

Fix c/p typo from my experiment code

parent 7797d21b
......@@ -619,7 +619,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
if lm_labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[:, :-1]
shift_labels = torch_batch[:, 1:]
shift_labels = lm_labels[:, 1:]
# In tensorflow, it's [batch, d_0, d_1, ..., d_{r-1}, num_classes]
# in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment