Unverified Commit 337225ec authored by Elabonga Atuo's avatar Elabonga Atuo Committed by GitHub
Browse files

feat(model parallelism): move labels to the same device as logits for M2M100 (#22850)

moved logits for m2m_100
parent 6bd8ae26
...@@ -1353,6 +1353,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel): ...@@ -1353,6 +1353,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
masked_lm_loss = None masked_lm_loss = None
if labels is not None: if labels is not None:
# move labels to the correct device to enable PP
labels = labels.to(lm_logits.device)
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment