"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "25c451e5a044969eb91e1e481574a2bfca5130ca"
Unverified commit 80f72960, authored by NielsRogge, committed by GitHub

Update Trainer code example (#15070)

* Update code example

* Fix code quality

* Add comment
parent ac227093
````diff
@@ -47,22 +47,23 @@ when you use it on other models. When using it on your own model, make sure:
 
 </Tip>
 
-Here is an example of how to customize [`Trainer`] using a custom loss function for multi-label classification:
+Here is an example of how to customize [`Trainer`] to use a weighted loss (useful when you have an unbalanced training set):
 
 ```python
+import torch
 from torch import nn
 from transformers import Trainer
 
 
-class MultilabelTrainer(Trainer):
+class CustomTrainer(Trainer):
     def compute_loss(self, model, inputs, return_outputs=False):
         labels = inputs.get("labels")
+        # forward pass
         outputs = model(**inputs)
         logits = outputs.get("logits")
-        loss_fct = nn.BCEWithLogitsLoss()
-        loss = loss_fct(
-            logits.view(-1, self.model.config.num_labels), labels.float().view(-1, self.model.config.num_labels)
-        )
+        # compute custom loss (suppose one has 3 labels with different weights)
+        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]))
+        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
         return (loss, outputs) if return_outputs else loss
 ```
 
````
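For context, the `CustomTrainer` subclass in the diff is used like any other [`Trainer`]. Below is a minimal, hypothetical training sketch; the checkpoint name, toy dataset, and `TrainingArguments` values are illustrative assumptions on top of this commit, not part of it:

```python
import torch
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments

# Hypothetical 3-label setup, matching the 3-element weight tensor in compute_loss above.
checkpoint = "bert-base-uncased"  # illustrative checkpoint choice
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=3)

# Toy training data; a real dataset would be far larger, and the class
# weights would typically reflect inverse label frequencies.
texts = ["great film", "it was fine", "terrible plot"]
labels = [2, 1, 0]
encodings = tokenizer(texts, truncation=True, padding=True)
train_dataset = Dataset.from_dict({**encodings, "labels": labels})

training_args = TrainingArguments(
    output_dir="weighted-loss-demo",  # illustrative value
    num_train_epochs=1,
    per_device_train_batch_size=2,
)

# CustomTrainer is the subclass defined in the diff above.
trainer = CustomTrainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()
```

One caveat: as written, the weight tensor in `compute_loss` is created on CPU, so when training on GPU one would likely construct it on the same device as the logits, e.g. `torch.tensor([1.0, 2.0, 3.0], device=model.device)`, to avoid a device-mismatch error.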