Unverified commit 8a817e1e, authored by SUSHMANTH REDDY, committed by GitHub
Browse files

Moved labels to the same device as logits for the LiLT model (#22898)

parent 515d6a55
......@@ -924,6 +924,8 @@ class LiltForSequenceClassification(LiltPreTrainedModel):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
......@@ -1046,6 +1048,8 @@ class LiltForTokenClassification(LiltPreTrainedModel):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment