"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4df6b59318db0b74eae0651229b59576e8ee326d"
Unverified Commit a92e0ad2 authored by кѳѳsнī, committed by GitHub

Enable training Llama with model or pipeline parallelism (#22329)



* Llama - Move target tokens to final pipeline device if needed

* Update src/transformers/models/llama/modeling_llama.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_llama.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 502fec77
src/transformers/models/llama/modeling_llama.py
...
@@ -783,7 +783,11 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
             loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model/pipeline parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
...
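The change follows a common pattern for model or pipeline parallelism: when the model's layers are spread across devices, the logits produced by the final pipeline stage and the labels supplied by the caller can end up on different devices, and CrossEntropyLoss then raises a device-mismatch error. A minimal standalone sketch of the pattern (tensor shapes and the two-device scenario are illustrative assumptions, not part of the commit):

```python
import torch
from torch.nn import CrossEntropyLoss

# Stand-ins for the shifted logits/labels in the causal LM loss.
# In a pipeline-parallel run, shift_logits would live on the last stage's
# device (e.g. cuda:1) while shift_labels might still be on cuda:0 or cpu.
vocab_size = 32000
shift_logits = torch.randn(8, vocab_size)
shift_labels = torch.randint(0, vocab_size, (8,))

loss_fct = CrossEntropyLoss()

# Moving the labels to the logits' device is what the commit adds; it is a
# no-op when both tensors are already on the same device, and it prevents
# "Expected all tensors to be on the same device" errors otherwise.
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
```

Moving the small labels tensor to the logits' device (rather than the reverse) keeps the transfer cheap and leaves the large logits tensor where the last pipeline stage produced it.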