Unverified Commit 151425dd authored by Shahad Mahmud's avatar Shahad Mahmud Committed by GitHub
Browse files

Model parallelism: Moving labels to same devices as the logits are (#22691)

Model parallelism correct labels device
parent 6daa9cb5
......@@ -999,6 +999,8 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
labels = labels.to(shifted_prediction_scores.device)
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
......@@ -1114,6 +1116,8 @@ class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel):
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
labels = labels.to(prediction_scores.device)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
......@@ -1224,6 +1228,8 @@ class Data2VecTextForSequenceClassification(Data2VecTextPreTrainedModel):
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
......@@ -1337,6 +1343,8 @@ class Data2VecTextForMultipleChoice(Data2VecTextPreTrainedModel):
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
labels = labels.to(reshaped_logits.device)
loss = loss_fct(reshaped_logits, labels)
if not return_dict:
......@@ -1421,6 +1429,8 @@ class Data2VecTextForTokenClassification(Data2VecTextPreTrainedModel):
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
labels = labels.to(logits.device)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
......
......@@ -1032,6 +1032,8 @@ class EsmForMaskedLM(EsmPreTrainedModel):
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
labels = labels.to(prediction_scores.device)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
......@@ -1131,6 +1133,8 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
......@@ -1228,6 +1232,8 @@ class EsmForTokenClassification(EsmPreTrainedModel):
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
labels = labels.to(logits.device)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
......
......@@ -1863,6 +1863,8 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
labels = labels.to(prediction_scores.device)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
......@@ -1952,6 +1954,8 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
......@@ -2217,6 +2221,8 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
labels = labels.to(logits.device)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
......@@ -2329,6 +2335,8 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel):
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
labels = labels.to(reshaped_logits.device)
loss = loss_fct(reshaped_logits, labels)
if not return_dict:
......
......@@ -2074,6 +2074,8 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment