"examples/vscode:/vscode.git/clone" did not exist on "ac98a88fbc6377f93e8b7fbd244b0c3331bb82a0"
Unverified commit d7a4f5be, authored by Younes Belkada and committed by GitHub
Browse files

[`T5`] Enable naive Pipeline Parallelism training for T5 (#22535)

* enable PP for T5

* make fixup

* fix failing tests
parent cab048fb
......@@ -1778,6 +1778,8 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
# move labels to correct device to enable PP
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
......
......@@ -1746,6 +1746,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
# move labels to correct device to enable PP
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment