"vscode:/vscode.git/clone" did not exist on "e33da0eb32c00796764ccc01d9b71c19c3662fec"
Unverified Commit 91d6a593 authored by SUSHMANTH REDDY, committed by GitHub
Browse files

moved labels to the same device as logits for OPT, CODEGEN, gptj and pix2struct models (#22872)

* moved labels to the same device as logits for OPT model

* moved labels to the same device as logits for CODEGEN model

* Update modeling_codegen.py

* moved labels to the same device as logits for gptj and pix2struct model

* Update modeling_pix2struct.py
parent 4116d1ec
......@@ -693,6 +693,8 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
......
......@@ -876,6 +876,8 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
......
......@@ -951,6 +951,8 @@ class OPTForCausalLM(OPTPreTrainedModel):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
......
......@@ -1552,6 +1552,8 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean", label_smoothing=0.1)
masked_labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment