Unverified Commit 98597725 authored by Shikhar Chauhan's avatar Shikhar Chauhan Committed by GitHub
Browse files

(feat): Moving labels to same device as logits for Deit (#22679)

parent 870d91fb
...@@ -764,6 +764,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel): ...@@ -764,6 +764,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
loss = None loss = None
if labels is not None: if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None: if self.config.problem_type is None:
if self.num_labels == 1: if self.num_labels == 1:
self.config.problem_type = "regression" self.config.problem_type = "regression"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment