Commit 4cbd50e6 (unverified), authored Sep 11, 2020 by Sylvain Gugger, committed by GitHub on Sep 11, 2020

Compute loss method (#7074)

parent ae736163
Showing 2 changed files with 28 additions and 8 deletions (+28 -8):

- docs/source/main_classes/trainer.rst (+13 -0)
- src/transformers/trainer.py (+15 -8)
docs/source/main_classes/trainer.rst (view file @ 4cbd50e6)
...
...
@@ -21,12 +21,25 @@ previous features. To inject custom behavior you can subclass them and override
- **setup_wandb** -- Setups wandb (see `here <https://docs.wandb.com/huggingface>`__ for more information).
- **create_optimizer_and_scheduler** -- Setups the optimizer and learning rate scheduler if they were not passed at
init.
- **compute_loss** -- Computes the loss on a batch of training inputs.
- **training_step** -- Performs a training step.
- **prediction_step** -- Performs an evaluation/test step.
- **run_model** (TensorFlow only) -- Basic pass through the model.
- **evaluate** -- Runs an evaluation loop and returns metrics.
- **predict** -- Returns predictions (with metrics if labels are available) on a test set.
Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss function:
.. code-block:: python

    from transformers import Trainer

    class MyTrainer(Trainer):
        def compute_loss(self, model, inputs):
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs[0]
            return my_custom_loss(logits, labels)
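A minimal usage sketch of the class above (assumptions, not part of this commit: this ``my_custom_loss`` definition, and ``model``, ``training_args``, and ``train_dataset`` created as usual):

.. code-block:: python

    import torch.nn.functional as F

    def my_custom_loss(logits, labels):
        # Plain cross-entropy over flattened logits; swap in any loss you need.
        return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))

    trainer = MyTrainer(model=model, args=training_args, train_dataset=train_dataset)
    trainer.train()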
``Trainer``
~~~~~~~~~~~
...
...
src/transformers/trainer.py (view file @ 4cbd50e6)
...
...
@@ -1024,15 +1024,9 @@ class Trainer:
         if self.args.fp16 and _use_native_amp:
             with autocast():
-                outputs = model(**inputs)
-                loss = outputs[0]
+                loss = self.compute_loss(model, inputs)
         else:
-            outputs = model(**inputs)
-            # We don't use .loss here since the model may return tuples instead of ModelOutput.
-            loss = outputs[0]
-
-        if self.args.past_index >= 0:
-            self._past = outputs[self.args.past_index]
+            loss = self.compute_loss(model, inputs)

         if self.args.n_gpu > 1:
             loss = loss.mean()  # mean() to average on multi-gpu parallel training
...
...
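For readers skimming the diff: after this commit, the section of ``training_step`` touched above reduces to the following (a sketch reconstructed from the hunk; the rest of the method is elided):

    if self.args.fp16 and _use_native_amp:
        with autocast():
            loss = self.compute_loss(model, inputs)
    else:
        loss = self.compute_loss(model, inputs)

    if self.args.n_gpu > 1:
        loss = loss.mean()  # mean() to average on multi-gpu parallel training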
@@ -1050,6 +1044,19 @@ class Trainer:
         return loss.detach()

+    def compute_loss(self, model, inputs):
+        """
+        How the loss is computed by Trainer. By default, all models return the loss in the first element.
+
+        Subclass and override for custom behavior.
+        """
+        outputs = model(**inputs)
+        # Save past state if it exists
+        if self.args.past_index >= 0:
+            self._past = outputs[self.args.past_index]
+        # We don't use .loss here since the model may return tuples instead of ModelOutput.
+        return outputs[0]
+
     def is_local_master(self) -> bool:
         """
         Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
...
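A quick sketch of what the new hook enables (a hypothetical example, not part of this commit: label-smoothed cross-entropy in place of the model's built-in loss; the ``epsilon`` value is an assumption):

    import torch.nn.functional as F
    from transformers import Trainer

    class SmoothedTrainer(Trainer):
        def compute_loss(self, model, inputs):
            # Hypothetical override: compute label-smoothed cross-entropy
            # instead of returning the loss the model computes internally.
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs[0]
            log_probs = F.log_softmax(logits, dim=-1)
            nll = F.nll_loss(log_probs.view(-1, log_probs.size(-1)), labels.view(-1))
            # Uniform-prior term: average negative log-probability over all classes.
            smooth = -log_probs.mean(dim=-1).mean()
            epsilon = 0.1  # smoothing factor (assumption)
            return (1 - epsilon) * nll + epsilon * smooth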