Unverified Commit 8a817e1e authored by SUSHMANTH REDDY's avatar SUSHMANTH REDDY Committed by GitHub
Browse files

moved labels to the same device as logits for LILT model (#22898)

parent 515d6a55
...@@ -924,6 +924,8 @@ class LiltForSequenceClassification(LiltPreTrainedModel): ...@@ -924,6 +924,8 @@ class LiltForSequenceClassification(LiltPreTrainedModel):
loss = None loss = None
if labels is not None: if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
if self.config.problem_type is None: if self.config.problem_type is None:
if self.num_labels == 1: if self.num_labels == 1:
self.config.problem_type = "regression" self.config.problem_type = "regression"
...@@ -1046,6 +1048,8 @@ class LiltForTokenClassification(LiltPreTrainedModel): ...@@ -1046,6 +1048,8 @@ class LiltForTokenClassification(LiltPreTrainedModel):
loss = None loss = None
if labels is not None: if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment