Unverified Commit a7ff2f23 authored by Erich Schubert's avatar Erich Schubert Committed by GitHub
Browse files

Move misplaced line (#29117)

Move misplaced line, improve code comment
parent 9094abe8
...@@ -1176,11 +1176,11 @@ class MistralForCausalLM(MistralPreTrainedModel): ...@@ -1176,11 +1176,11 @@ class MistralForCausalLM(MistralPreTrainedModel):
shift_logits = logits[..., :-1, :].contiguous() shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous() shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens # Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1) shift_labels = shift_labels.view(-1)
# Enable model parallelism # Ensure tensors are on the same device
shift_labels = shift_labels.to(shift_logits.device) shift_labels = shift_labels.to(shift_logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits, shift_labels) loss = loss_fct(shift_logits, shift_labels)
if not return_dict: if not return_dict:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment