Unverified Commit 17503b00 authored by jprivera44's avatar jprivera44 Committed by GitHub
Browse files

Added parallel device usage for GPT-J (#22713)

parent b76e6ebd
......@@ -1012,6 +1012,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
loss = None
if labels is not None:
labels = labels.to(pooled_logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment