Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d61d7476
Unverified
Commit
d61d7476
authored
May 26, 2023
by
amitportnoy
Committed by
GitHub
May 26, 2023
Browse files
Update trainer.mdx class_weights example (#23787)
class_weights tensor should follow model's device
parent
4d9b76a8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
1 addition
and
1 deletion
+1
-1
docs/source/en/main_classes/trainer.mdx
docs/source/en/main_classes/trainer.mdx
+1
-1
No files found.
docs/source/en/main_classes/trainer.mdx
View file @
d61d7476
...
...
@@ -61,7 +61,7 @@ class CustomTrainer(Trainer):
outputs = model(**inputs)
logits = outputs.get("logits")
# compute custom loss (suppose one has 3 labels with different weights)
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]))
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]
, device=model.device
))
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
return (loss, outputs) if return_outputs else loss
```
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment