Unverified Commit 95b6bef6 authored by MichelBartels's avatar MichelBartels Committed by GitHub
Browse files

Align logits and labels in OPT (#17237)

parent a5d18396
@@ -951,9 +951,12 @@ class OPTForCausalLM(OPTPreTrainedModel):
         loss = None
         if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
             loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

         if not return_dict:
             output = (logits,) + outputs[1:]
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment