Unverified commit aab14120, authored by SUSHMANTH REDDY and committed by GitHub

Move labels to the logits device to enable model parallelism in the LUKE model (#22909)

parent 397720fb
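
Every hunk below makes the same change: when a LUKE model is sharded across devices for model parallelism (for example via accelerate's `device_map="auto"`, which is one common way to do it), the logits computed by the final layers can end up on a different GPU than the labels supplied by the caller, and the loss function would then raise a cross-device error. Moving the labels onto `logits.device` immediately before the loss fixes that. A minimal sketch of the pattern, using a toy classification head rather than the actual LUKE forward pass:

from typing import Optional

import torch
from torch import nn


class TinyHead(nn.Module):
    """Toy stand-in for a LUKE task head; only the label-handling pattern matters here."""

    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states: torch.Tensor, labels: Optional[torch.Tensor] = None):
        logits = self.classifier(hidden_states)
        loss = None
        if labels is not None:
            # Same fix as in the diff: the labels may still sit on the CPU or on
            # another GPU, so move them to wherever the logits ended up.
            labels = labels.to(logits.device)
            loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss, logits

The `.to()` call is a no-op when the labels are already on the right device, so the change is also safe for ordinary single-device training.
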
@@ -1370,6 +1370,8 @@ class LukeForMaskedLM(LukePreTrainedModel):
         mlm_loss = None
         logits = self.lm_head(outputs.last_hidden_state)
         if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
             mlm_loss = self.loss_fn(logits.view(-1, self.config.vocab_size), labels.view(-1))
             if loss is None:
                 loss = mlm_loss
@@ -1505,6 +1507,8 @@ class LukeForEntityClassification(LukePreTrainedModel):
         if labels is not None:
             # When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary
             # cross entropy is used otherwise.
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
             if labels.ndim == 1:
                 loss = nn.functional.cross_entropy(logits, labels)
             else:
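
The comments kept in this hunk describe how the entity classification heads pick their loss: 1-D integer labels mean single-label classification and go through plain cross entropy, while 2-D multi-hot labels mean multi-label classification and go through binary cross entropy with logits. A standalone illustration of that branching, with made-up shapes rather than real LUKE outputs:

import torch
from torch import nn

batch_size, num_labels = 4, 9
logits = torch.randn(batch_size, num_labels)

# 1-D labels: one class index per example -> cross entropy.
single_labels = torch.randint(0, num_labels, (batch_size,))
loss_single = nn.functional.cross_entropy(logits, single_labels)

# 2-D multi-hot labels: several classes may be active -> BCE with logits,
# with the labels cast to the logits dtype (the LUKE code uses type_as for this).
multi_labels = torch.randint(0, 2, (batch_size, num_labels)).type_as(logits)
loss_multi = nn.functional.binary_cross_entropy_with_logits(logits, multi_labels)

print(loss_single.item(), loss_multi.item())
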
@@ -1623,6 +1627,8 @@ class LukeForEntityPairClassification(LukePreTrainedModel):
         if labels is not None:
             # When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary
             # cross entropy is used otherwise.
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
             if labels.ndim == 1:
                 loss = nn.functional.cross_entropy(logits, labels)
             else:
@@ -1765,6 +1771,8 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
 
         loss = None
         if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
             # When the number of dimension of `labels` is 2, cross entropy is used as the loss function. The binary
             # cross entropy is used otherwise.
             if labels.ndim == 2:
@@ -1862,6 +1870,8 @@ class LukeForSequenceClassification(LukePreTrainedModel):
 
         loss = None
         if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
             if self.config.problem_type is None:
                 if self.num_labels == 1:
                     self.config.problem_type = "regression"
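
This hunk only shows the beginning of the standard `problem_type` heuristic that runs after the device move in `LukeForSequenceClassification`: regression when `num_labels == 1`, single-label classification when the labels are integer class indices, and multi-label classification otherwise. A sketch of that selection logic on its own (the helper name is made up for illustration):

import torch

def infer_problem_type(num_labels: int, labels: torch.Tensor) -> str:
    # Mirrors the usual Hugging Face heuristic that follows in the unchanged code.
    if num_labels == 1:
        return "regression"                    # MSELoss
    if labels.dtype in (torch.long, torch.int):
        return "single_label_classification"   # CrossEntropyLoss
    return "multi_label_classification"        # BCEWithLogitsLoss

print(infer_problem_type(1, torch.tensor([0.7])))              # regression
print(infer_problem_type(3, torch.tensor([2, 0])))             # single_label_classification
print(infer_problem_type(3, torch.tensor([[1.0, 0.0, 1.0]])))  # multi_label_classification
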
@@ -1975,6 +1985,8 @@ class LukeForTokenClassification(LukePreTrainedModel):
 
         loss = None
         if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
@@ -2216,6 +2228,8 @@ class LukeForMultipleChoice(LukePreTrainedModel):
 
         loss = None
         if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(reshaped_logits.device)
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
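
In day-to-day use the change means a label tensor no longer has to be placed on the same device as the model's output layers by hand. A usage sketch with the public studio-ousia/luke-base checkpoint, assuming at least one GPU is available; under real model parallelism (e.g. a sharded device map) the same line prevents the analogous GPU-to-GPU mismatch:

import torch
from transformers import LukeTokenizer, LukeForSequenceClassification

tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
model = LukeForSequenceClassification.from_pretrained("studio-ousia/luke-base", num_labels=2)
model.to("cuda:0")  # assumption: a single GPU; the classification head is freshly initialized

inputs = tokenizer("Beyoncé lives in Los Angeles.", return_tensors="pt").to("cuda:0")
labels = torch.tensor([1])  # deliberately left on the CPU

# forward() now moves `labels` to logits.device before computing the loss,
# so the mismatched placement no longer raises a device error.
outputs = model(**inputs, labels=labels)
print(outputs.loss)
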