Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
12b66215
Unverified
Commit
12b66215
authored
Mar 05, 2021
by
lewtun
Committed by
GitHub
Mar 05, 2021
Browse files
Fix example of custom Trainer to reflect signature of compute_loss (#10537)
parent
093b88f4
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
8 deletions
+14
-8
docs/source/main_classes/trainer.rst
docs/source/main_classes/trainer.rst
+14
-8
No files found.
docs/source/main_classes/trainer.rst
View file @
12b66215
...
...
@@ -23,14 +23,14 @@ customization during training.
The API supports distributed training on multiple GPUs/TPUs, mixed precision through `NVIDIA Apex
<https://github.com/NVIDIA/apex>`__ for PyTorch and :obj:`tf.keras.mixed_precision` for TensorFlow.
Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop support
ing the
previous
features. To inject custom behavior you can subclass them and override the following methods:
Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop
which
support
s
the above
features. To inject custom behavior you can subclass them and override the following methods:
- **get_train_dataloader**/**get_train_tfdataset** -- Creates the training DataLoader (PyTorch) or TF Dataset.
- **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaluation DataLoader (PyTorch) or TF Dataset.
- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
- **log** -- Logs information on the various objects watching training.
- **create_optimizer_and_scheduler** -- Setup
s
the optimizer and learning rate scheduler if they were not passed at
- **create_optimizer_and_scheduler** -- Set
s
up the optimizer and learning rate scheduler if they were not passed at
init.
- **compute_loss** - Computes the loss on a batch of training inputs.
- **training_step** -- Performs a training step.
...
...
@@ -39,17 +39,23 @@ previous features. To inject custom behavior you can subclass them and override
- **evaluate** -- Runs an evaluation loop and returns metrics.
- **predict** -- Returns predictions (with metrics if labels are available) on a test set.
Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss function:
Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss function for multi-label
classification:
.. code-block:: python
import torch
from transformers import Trainer
class MyTrainer(Trainer):
def compute_loss(self, model, inputs):
class MultilabelTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.pop("labels")
outputs = model(**inputs)
logits = outputs[0]
return my_custom_loss(logits, labels)
logits = outputs.logits
loss_fct = torch.nn.BCEWithLogitsLoss()
loss = loss_fct(logits.view(-1, self.model.config.num_labels),
labels.float().view(-1, self.model.config.num_labels))
return (loss, outputs) if return_outputs else loss
Another way to customize the training loop behavior for the PyTorch :class:`~transformers.Trainer` is to use
:doc:`callbacks <callback>` that can inspect the training loop state (for progress reporting, logging on TensorBoard or
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment