Unverified Commit aab14120 authored by SUSHMANTH REDDY's avatar SUSHMANTH REDDY Committed by GitHub
Browse files

Moved labels to enable parallelism pipeline in Luke model (#22909)

parent 397720fb
......@@ -1370,6 +1370,8 @@ class LukeForMaskedLM(LukePreTrainedModel):
mlm_loss = None
logits = self.lm_head(outputs.last_hidden_state)
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
mlm_loss = self.loss_fn(logits.view(-1, self.config.vocab_size), labels.view(-1))
if loss is None:
loss = mlm_loss
......@@ -1505,6 +1507,8 @@ class LukeForEntityClassification(LukePreTrainedModel):
if labels is not None:
# When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary
# cross entropy is used otherwise.
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
if labels.ndim == 1:
loss = nn.functional.cross_entropy(logits, labels)
else:
......@@ -1623,6 +1627,8 @@ class LukeForEntityPairClassification(LukePreTrainedModel):
if labels is not None:
# When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary
# cross entropy is used otherwise.
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
if labels.ndim == 1:
loss = nn.functional.cross_entropy(logits, labels)
else:
......@@ -1765,6 +1771,8 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# When the number of dimension of `labels` is 2, cross entropy is used as the loss function. The binary
# cross entropy is used otherwise.
if labels.ndim == 2:
......@@ -1862,6 +1870,8 @@ class LukeForSequenceClassification(LukePreTrainedModel):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
......@@ -1975,6 +1985,8 @@ class LukeForTokenClassification(LukePreTrainedModel):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
......@@ -2216,6 +2228,8 @@ class LukeForMultipleChoice(LukePreTrainedModel):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(reshaped_logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment