chenpangpang / transformers — Commits

Commit 80f72960 (unverified), authored Jan 19, 2022 by NielsRogge, committed via GitHub on Jan 19, 2022
Parent: ac227093

Update Trainer code example (#15070)

* Update code example
* Fix code quality
* Add comment
Changes: 1 changed file with 6 additions and 6 deletions (+6 −6): docs/source/main_classes/trainer.mdx
docs/source/main_classes/trainer.mdx

````diff
@@ -47,22 +47,22 @@ when you use it on other models. When using it on your own model, make sure:
 </Tip>
 
-Here is an example of how to customize [`Trainer`] using a custom loss function for multi-label classification:
+Here is an example of how to customize [`Trainer`] to use a weighted loss (useful when you have an unbalanced training set):
 
 ```python
 from torch import nn
 from transformers import Trainer
 
 
-class MultilabelTrainer(Trainer):
+class CustomTrainer(Trainer):
     def compute_loss(self, model, inputs, return_outputs=False):
         labels = inputs.get("labels")
+        # forward pass
         outputs = model(**inputs)
         logits = outputs.get("logits")
-        loss_fct = nn.BCEWithLogitsLoss()
-        loss = loss_fct(
-            logits.view(-1, self.model.config.num_labels), labels.float().view(-1, self.model.config.num_labels)
-        )
+        # compute custom loss (suppose one has 3 labels with different weights)
+        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]))
+        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
         return (loss, outputs) if return_outputs else loss
 ```
````
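As a side note on the added example: `nn.CrossEntropyLoss(weight=...)` scales each sample's negative log-likelihood by the weight of its true class and, with the default `"mean"` reduction, divides by the sum of those weights (also note the snippet imports only `nn`, so a standalone script would additionally need `import torch` for `torch.tensor`). A minimal pure-Python sketch of that computation, with illustrative logits and the `[1.0, 2.0, 3.0]` weights from the diff (everything except the weight values is a made-up example, not part of the commit):

```python
import math


def weighted_cross_entropy(logits, targets, weights):
    """Mirror of nn.CrossEntropyLoss(weight=...) with the default
    "mean" reduction: each sample's negative log-likelihood is scaled
    by the weight of its true class, then the total is divided by the
    sum of those per-sample weights."""
    total, weight_sum = 0.0, 0.0
    for row, target in zip(logits, targets):
        # negative log-softmax of the target logit
        log_norm = math.log(sum(math.exp(x) for x in row))
        nll = log_norm - row[target]
        total += weights[target] * nll
        weight_sum += weights[target]
    return total / weight_sum


# Two samples, three classes, class weights [1.0, 2.0, 3.0] as in the diff
logits = [[2.0, 0.5, 0.1], [0.2, 0.3, 2.5]]
targets = [0, 2]
loss = weighted_cross_entropy(logits, targets, [1.0, 2.0, 3.0])
```

With uniform weights this reduces to ordinary cross-entropy; raising a class's weight makes mistakes on that class count for more, which is why it helps on unbalanced training sets.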