Unverified Commit 0224aaf6 authored by Mayank Agarwal's avatar Mayank Agarwal Committed by GitHub
Browse files

Enable naive Pipeline Parallelism training for Gpt neox japanese and san japanese (#22702)

Move labels to same device as logits
parent 28c19ab5
...@@ -682,6 +682,9 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel): ...@@ -682,6 +682,9 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
lm_loss = None lm_loss = None
if labels is not None: if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# we are doing next-token prediction; shift prediction scores and input ids by one # we are doing next-token prediction; shift prediction scores and input ids by one
shift_logits = lm_logits[:, :-1, :].contiguous() shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous() labels = labels[:, 1:].contiguous()
......
...@@ -1236,6 +1236,9 @@ class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel): ...@@ -1236,6 +1236,9 @@ class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
router_probs = None router_probs = None
aux_loss = None aux_loss = None
if labels is not None: if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
loss_fct = nn.CrossEntropyLoss(ignore_index=-100) loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
if output_router_logits: if output_router_logits:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment