"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7d5ce6802ec5bab29d60e3501337d3477f31b866"
Unverified commit 15641892, authored by Kaustubh, committed by GitHub

feat(model parallelism): moving the labels to the same device as the logits for gpt2 and bart (#22591)
parent e577bd0f
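
For context: when a model is split across devices for model parallelism, the final logits live on the last device of the pipeline while user-supplied labels typically arrive on the first one, so the loss computation fails with a device-mismatch error unless the labels are moved first. Below is a minimal sketch of the failure mode and of the one-line fix this commit applies, assuming a machine with two CUDA devices; the toy tensors stand in for a sharded model's outputs and are not part of the diff.

import torch
from torch.nn import CrossEntropyLoss

# Stand-ins for a pipeline-parallel model: the lm head runs on cuda:1,
# so its logits land there, while the labels were prepared on cuda:0.
lm_logits = torch.randn(4, 10, device="cuda:1", requires_grad=True)
labels = torch.randint(0, 10, (4,), device="cuda:0")

loss_fct = CrossEntropyLoss()
# loss_fct(lm_logits, labels) at this point would raise:
# RuntimeError: Expected all tensors to be on the same device

labels = labels.to(lm_logits.device)  # the one-line change this commit adds
loss = loss_fct(lm_logits, labels)
loss.backward()

The diff below applies exactly this move, right before each loss computation, across the affected model classes.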
@@ -1398,6 +1398,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
         masked_lm_loss = None
         if labels is not None:
+            labels = labels.to(lm_logits.device)
             loss_fct = CrossEntropyLoss()
             masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

@@ -1553,6 +1554,7 @@ class BartForSequenceClassification(BartPretrainedModel):
         loss = None
         if labels is not None:
+            labels = labels.to(logits.device)
             if self.config.problem_type is None:
                 if self.config.num_labels == 1:
                     self.config.problem_type = "regression"

@@ -1896,6 +1898,7 @@ class BartForCausalLM(BartPretrainedModel):
         loss = None
         if labels is not None:
+            labels = labels.to(logits.device)
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
@@ -2581,6 +2581,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
         masked_lm_loss = None
         if labels is not None:
+            labels = labels.to(lm_logits.device)
             loss_fct = CrossEntropyLoss()
             masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

@@ -2735,6 +2736,7 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
         loss = None
         if labels is not None:
+            labels = labels.to(logits.device)
             if self.config.problem_type is None:
                 if self.config.num_labels == 1:
                     self.config.problem_type = "regression"
@@ -1596,6 +1596,7 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
         loss = None
         if labels is not None:
+            labels = labels.to(logits.device)
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
@@ -1563,6 +1563,7 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
         loss = None
         if labels is not None:
+            labels = labels.to(logits.device)
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
@@ -1098,6 +1098,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         loss = None
         if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(lm_logits.device)
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()

@@ -1318,6 +1320,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
             mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
         lm_loss = None
         if labels is not None:
+            labels = labels.to(lm_logits.device)
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss()

@@ -1569,6 +1572,7 @@ class GPT2ForTokenClassification(GPT2PreTrainedModel):
         loss = None
         if labels is not None:
+            labels = labels.to(logits.device)
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
@@ -1715,6 +1715,7 @@ class MarianForCausalLM(MarianPreTrainedModel):
         loss = None
         if labels is not None:
+            labels = labels.to(logits.device)
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
@@ -1528,6 +1528,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
         loss = None
         if labels is not None:
+            labels = labels.to(logits.device)
             if self.config.problem_type is None:
                 if self.config.num_labels == 1:
                     self.config.problem_type = "regression"

@@ -1866,6 +1867,7 @@ class MBartForCausalLM(MBartPreTrainedModel):
         loss = None
         if labels is not None:
+            labels = labels.to(logits.device)
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
@@ -1694,6 +1694,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
         loss = None
         if labels is not None:
+            labels = labels.to(logits.device)
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
@@ -1499,6 +1499,7 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel):
         loss = None
         if labels is not None:
+            labels = labels.to(logits.device)
             if self.config.problem_type is None:
                 if self.config.num_labels == 1:
                     self.config.problem_type = "regression"

@@ -1713,6 +1714,7 @@ class PLBartForCausalLM(PLBartPreTrainedModel):
         loss = None
         if labels is not None:
+            labels = labels.to(logits.device)
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
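
As a usage sketch of what the change enables (a hypothetical snippet, not from the commit, assuming accelerate is installed so device_map="auto" can shard the model across the available GPUs):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# With accelerate installed, device_map="auto" may place different blocks
# on different devices, so the lm head's logits need not share a device
# with the inputs.
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")

batch = tokenizer("Hello, my dog is cute", return_tensors="pt").to(model.device)
# The labels no longer have to be moved to the logits' device by hand;
# the patched forward() calls labels.to(lm_logits.device) internally.
outputs = model(**batch, labels=batch["input_ids"])
print(outputs.loss.item())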