chenpangpang / transformers · Commits · 12b66215

Unverified commit 12b66215, authored Mar 05, 2021 by lewtun, committed by GitHub Mar 05, 2021
parent 093b88f4

Fix example of custom Trainer to reflect signature of compute_loss (#10537)

Showing 1 changed file with 14 additions and 8 deletions:
docs/source/main_classes/trainer.rst (+14 -8)
...
@@ -23,14 +23,14 @@ customization during training.

 The API supports distributed training on multiple GPUs/TPUs, mixed precision through `NVIDIA Apex
 <https://github.com/NVIDIA/apex>`__ for PyTorch and :obj:`tf.keras.mixed_precision` for TensorFlow.

-Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop supporting the
-previous features. To inject custom behavior you can subclass them and override the following methods:
+Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop which supports
+the above features. To inject custom behavior you can subclass them and override the following methods:

 - **get_train_dataloader**/**get_train_tfdataset** -- Creates the training DataLoader (PyTorch) or TF Dataset.
 - **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaluation DataLoader (PyTorch) or TF Dataset.
 - **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
 - **log** -- Logs information on the various objects watching training.
-- **create_optimizer_and_scheduler** -- Setups the optimizer and learning rate scheduler if they were not passed at
+- **create_optimizer_and_scheduler** -- Sets up the optimizer and learning rate scheduler if they were not passed at
   init.
 - **compute_loss** - Computes the loss on a batch of training inputs.
 - **training_step** -- Performs a training step.
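The override hooks listed in this hunk follow the template-method pattern: the base training loop calls small methods that a subclass can replace without touching the loop itself. A toy sketch of that pattern, independent of transformers (the names `TinyLoop` and `ScaledLoop` are illustrative, not library classes):

```python
class TinyLoop:
    """Minimal stand-in for a trainer: the loop calls overridable hooks."""

    def get_train_data(self):
        # analogous to get_train_dataloader
        return [1, 2, 3]

    def compute_loss(self, batch):
        # analogous to Trainer.compute_loss
        return float(batch)

    def train(self):
        # the loop itself never changes across subclasses
        return sum(self.compute_loss(b) for b in self.get_train_data())


class ScaledLoop(TinyLoop):
    # Override a single hook; everything else is inherited.
    def compute_loss(self, batch):
        return 2.0 * float(batch)
```

`TinyLoop().train()` returns `6.0` while `ScaledLoop().train()` returns `12.0`: only the hook changed, which is exactly how a custom `compute_loss` slots into `Trainer`.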
...
@@ -39,17 +39,23 @@ previous features. To inject custom behavior you can subclass them and override

 - **evaluate** -- Runs an evaluation loop and returns metrics.
 - **predict** -- Returns predictions (with metrics if labels are available) on a test set.

-Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss function:
+Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss function for multi-label
+classification:

 .. code-block:: python

+    import torch
     from transformers import Trainer

-    class MyTrainer(Trainer):
-        def compute_loss(self, model, inputs):
+    class MultilabelTrainer(Trainer):
+        def compute_loss(self, model, inputs, return_outputs=False):
             labels = inputs.pop("labels")
             outputs = model(**inputs)
-            logits = outputs[0]
-            return my_custom_loss(logits, labels)
+            logits = outputs.logits
+            loss_fct = torch.nn.BCEWithLogitsLoss()
+            loss = loss_fct(logits.view(-1, self.model.config.num_labels),
+                            labels.float().view(-1, self.model.config.num_labels))
+            return (loss, outputs) if return_outputs else loss

 Another way to customize the training loop behavior for the PyTorch :class:`~transformers.Trainer` is to use
 :doc:`callbacks <callback>` that can inspect the training loop state (for progress reporting, logging on TensorBoard or
...
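The `torch.nn.BCEWithLogitsLoss` used in the updated example computes, per element, the numerically stable binary cross-entropy `max(x, 0) - x*y + log(1 + exp(-|x|))`, averaged over all elements by default. A pure-Python sketch of that formula (the helper `bce_with_logits` is hypothetical, written here only to make the math concrete; it is not a transformers or torch API):

```python
import math

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy with logits over a batch of rows.

    Computes the same quantity as torch.nn.BCEWithLogitsLoss with its
    default reduction="mean", using the numerically stable form
    max(x, 0) - x*y + log(1 + exp(-|x|)) for each logit x and target y.
    """
    total = 0.0
    count = 0
    for logit_row, target_row in zip(logits, targets):
        for x, y in zip(logit_row, target_row):
            # stable rewrite of -(y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x)))
            total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
            count += 1
    return total / count
```

For example, a zero logit against either target gives `log(2) ≈ 0.6931`, since `sigmoid(0) = 0.5`. Note that unlike a softmax cross-entropy, each label is scored independently, which is why this loss suits multi-label classification.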