Unverified commit 151425dd, authored by Shahad Mahmud and committed by GitHub

Model parallelism: moving labels to the same device as the logits (#22691)

Moves labels to the correct device to enable model parallelism.
parent 6daa9cb5
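
For context, a minimal sketch of the failure this commit fixes (a sketch, assuming a machine with at least two CUDA devices; the shapes, vocabulary size, and device names are illustrative and not taken from the diff): when a model is split across devices, the logits produced by the head on one GPU no longer share a device with the labels the caller passed in, and CrossEntropyLoss raises a device-mismatch RuntimeError. Moving the labels to logits.device just before the loss, as every hunk below does, resolves it.

import torch
from torch.nn import CrossEntropyLoss

# Logits come from a head placed on the second GPU; the labels arrive on the first.
logits = torch.randn(4, 50265, device="cuda:1")          # e.g. lm_head output under model parallelism
labels = torch.randint(0, 50265, (4,), device="cuda:0")  # batch prepared on another device

loss_fct = CrossEntropyLoss()
# Without the next line, loss_fct(logits, labels) raises
# "Expected all tensors to be on the same device".
labels = labels.to(logits.device)  # the one-line fix applied across the hunks below
loss = loss_fct(logits, labels)
print(loss.item())
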
@@ -999,6 +999,8 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
             shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
             labels = labels[:, 1:].contiguous()
             loss_fct = CrossEntropyLoss()
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(shifted_prediction_scores.device)
             lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

         if not return_dict:
@@ -1114,6 +1116,8 @@ class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel):
         masked_lm_loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(prediction_scores.device)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

         if not return_dict:
@@ -1224,6 +1228,8 @@ class Data2VecTextForSequenceClassification(Data2VecTextPreTrainedModel):
         loss = None
         if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
             if self.config.problem_type is None:
                 if self.num_labels == 1:
                     self.config.problem_type = "regression"
@@ -1337,6 +1343,8 @@ class Data2VecTextForMultipleChoice(Data2VecTextPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(reshaped_logits.device)
             loss = loss_fct(reshaped_logits, labels)

         if not return_dict:
@@ -1421,6 +1429,8 @@ class Data2VecTextForTokenClassification(Data2VecTextPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

         if not return_dict:
...
@@ -1032,6 +1032,8 @@ class EsmForMaskedLM(EsmPreTrainedModel):
         masked_lm_loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(prediction_scores.device)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

         if not return_dict:
@@ -1131,6 +1133,8 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
         loss = None
         if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
             if self.config.problem_type is None:
                 if self.num_labels == 1:
                     self.config.problem_type = "regression"
@@ -1228,6 +1232,8 @@ class EsmForTokenClassification(EsmPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

         if not return_dict:
...
@@ -1863,6 +1863,8 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
         masked_lm_loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(prediction_scores.device)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

         if not return_dict:
@@ -1952,6 +1954,8 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
         loss = None
         if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
             if self.config.problem_type is None:
                 if self.num_labels == 1:
                     self.config.problem_type = "regression"
@@ -2217,6 +2221,8 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

         if not return_dict:
@@ -2329,6 +2335,8 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(reshaped_logits.device)
             loss = loss_fct(reshaped_logits, labels)

         if not return_dict:
...
@@ -2074,6 +2074,8 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-100)
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(lm_logits.device)
             loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
             # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
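
And a hedged usage sketch of what the change enables (assumptions: a multi-GPU machine with accelerate installed so that device_map="auto" works; the checkpoint name and texts are illustrative, not from this commit): labels passed from the CPU no longer have to be moved by the caller, because forward() now aligns them with the logits' device before computing the loss.

from transformers import AutoTokenizer, LongT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base")
# device_map="auto" (accelerate) spreads the layers over the available GPUs,
# so lm_logits may be produced on whichever device holds the LM head.
model = LongT5ForConditionalGeneration.from_pretrained(
    "google/long-t5-local-base", device_map="auto"
)

inputs = tokenizer("summarize: a very long document ...", return_tensors="pt").to(model.device)
labels = tokenizer("a short summary", return_tensors="pt").input_ids  # left on CPU on purpose

# Before this commit, CPU labels crashed CrossEntropyLoss whenever the LM head
# sat on another device; now forward() moves them to lm_logits.device itself.
outputs = model(**inputs, labels=labels)
print(outputs.loss.item())
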
...