"templates/vscode:/vscode.git/clone" did not exist on "b2b7fc781438c7d1d551cdac0a44af5ca0399797"
Unverified Commit 8a817e1e authored by SUSHMANTH REDDY's avatar SUSHMANTH REDDY Committed by GitHub
Browse files

moved labels to the same device as logits for LILT model (#22898)

parent 515d6a55
......@@ -924,6 +924,8 @@ class LiltForSequenceClassification(LiltPreTrainedModel):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
......@@ -1046,6 +1048,8 @@ class LiltForTokenClassification(LiltPreTrainedModel):
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment