Unverified Commit 91d6a593 authored by SUSHMANTH REDDY, committed by GitHub

moved labels to the same device as logits for OPT, CodeGen, GPT-J and Pix2Struct models (#22872)

* moved labels to the same device as logits for OPT model

* moved labels to the same device as logits for CodeGen model

* Update modeling_codegen.py

* moved labels to the same device as logits for GPT-J and Pix2Struct models

* Update modeling_pix2struct.py
parent 4116d1ec
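All four diffs below make the same one-line change: when a model is split across several devices (model parallelism), the logits come back on whatever device holds the `lm_head`, while the labels arrive on the caller's device, and the loss computation then fails with a device mismatch. A minimal sketch of the problem and of the fix, using stand-in tensors rather than any of the patched models (shapes, vocabulary size and devices are illustrative assumptions; the mismatch only materializes when a GPU is available):

```python
import torch
import torch.nn as nn

# Stand-ins for what happens inside the patched forward() methods: under model
# parallelism the lm_head, and therefore lm_logits, may sit on a GPU other than
# the device the caller prepared the labels on.
logits_device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
lm_logits = torch.randn(2, 8, 100, device=logits_device)  # (batch, seq_len, vocab_size)
labels = torch.randint(0, 100, (2, 8))                     # left on the CPU, as a caller might pass them

# The line this commit adds to each model: move labels to wherever the logits ended up.
labels = labels.to(lm_logits.device)

# Usual causal-LM loss (the shift is illustrated after the first diff below).
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = nn.CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
print(loss.device)  # same device as the logits
```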
@@ -693,6 +693,8 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
     loss = None
     if labels is not None:
+        # move labels to correct device to enable model parallelism
+        labels = labels.to(lm_logits.device)
         # Shift so that tokens < n predict n
         shift_logits = lm_logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
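The two added lines slot in just before the standard causal-LM shift. As the existing comment says, tokens < n predict n; a tiny sketch of how the shift above aligns predictions with targets (token ids and shapes are made up for illustration):

```python
import torch

# Toy demonstration of the shift used in the CodeGen/GPT-J/OPT loss:
# position i of shift_logits is scored against token i + 1 of the labels.
labels = torch.tensor([[10, 11, 12, 13]])  # pretend token ids, also used as labels
lm_logits = torch.randn(1, 4, 20)          # (batch, seq_len, vocab_size)

shift_logits = lm_logits[..., :-1, :]      # predictions made at positions 0..2
shift_labels = labels[..., 1:]             # targets are the tokens at positions 1..3

print(shift_logits.shape)  # torch.Size([1, 3, 20])
print(shift_labels)        # tensor([[11, 12, 13]])
```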
@@ -876,6 +876,8 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
     loss = None
     if labels is not None:
+        # move labels to correct device to enable model parallelism
+        labels = labels.to(lm_logits.device)
         # Shift so that tokens < n predict n
         shift_logits = lm_logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
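For context, one way the mismatch can arise in practice is loading GPT-J with a device map that spreads its layers over several GPUs. A hedged usage sketch, assuming the EleutherAI/gpt-j-6B checkpoint, the accelerate package and at least two visible GPUs; the exact placement of the lm_head depends on the device map that gets generated, so this is illustrative rather than a guaranteed reproduction:

```python
import torch
from transformers import AutoTokenizer, GPTJForCausalLM

# With device_map="auto", accelerate may place the embeddings and the lm_head on
# different GPUs, so the logits can come back on a device other than the one the
# input_ids/labels were prepared on.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = GPTJForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B", torch_dtype=torch.float16, device_map="auto"
)

inputs = tokenizer("def hello_world():", return_tensors="pt").to("cuda:0")
# Before this commit, labels living on cuda:0 while lm_logits sat on another GPU
# could make CrossEntropyLoss raise a device-mismatch error; the model now moves
# the labels itself.
out = model(**inputs, labels=inputs["input_ids"])
print(out.loss)
```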
@@ -951,6 +951,8 @@ class OPTForCausalLM(OPTPreTrainedModel):
     loss = None
     if labels is not None:
+        # move labels to correct device to enable model parallelism
+        labels = labels.to(logits.device)
         # Shift so that tokens < n predict n
         shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
@@ -1552,6 +1552,8 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
     loss = None
     if labels is not None:
+        # move labels to correct device to enable model parallelism
+        labels = labels.to(logits.device)
         loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean", label_smoothing=0.1)
         masked_labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
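The Pix2Struct text decoder computes its loss slightly differently from the causal LMs above: padding positions are mapped to the ignore index and label smoothing is applied. A self-contained sketch of that pattern combined with the same device move (the pad token id of 0 and all tensor values are illustrative assumptions; the real model reads the pad id from its config):

```python
import torch
import torch.nn as nn

pad_token_id = 0                          # illustrative stand-in for config.pad_token_id
logits = torch.randn(2, 5, 30)            # (batch, seq_len, vocab_size)
labels = torch.tensor([[4, 7, 9, 0, 0],
                       [3, 8, 0, 0, 0]])  # sequences padded with pad_token_id

# Same device move as in the other models, then mask padding so it does not
# contribute to the loss.
labels = labels.to(logits.device)
loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean", label_smoothing=0.1)
masked_labels = labels.masked_fill(labels == pad_token_id, -100)

loss = loss_fct(logits.view(-1, logits.size(-1)), masked_labels.view(-1))
print(loss)
```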