Unverified Commit 98597725 authored by Shikhar Chauhan's avatar Shikhar Chauhan Committed by GitHub
Browse files

(feat): Moving labels to same device as logits for Deit (#22679)

parent 870d91fb
......@@ -764,6 +764,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment