"docs/en_US/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "c5f3da0f089a8dbcc20593ec91d17838ad2b8769"
Unverified Commit 17503b00 authored by jprivera44's avatar jprivera44 Committed by GitHub
Browse files

Added parallel device usage for GPT-J (#22713)

parent b76e6ebd
...@@ -1012,6 +1012,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel): ...@@ -1012,6 +1012,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
loss = None loss = None
if labels is not None: if labels is not None:
labels = labels.to(pooled_logits.device)
if self.config.problem_type is None: if self.config.problem_type is None:
if self.num_labels == 1: if self.num_labels == 1:
self.config.problem_type = "regression" self.config.problem_type = "regression"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment