chenpangpang / transformers · Commits · 86caab1e

Unverified commit 86caab1e, authored Jul 31, 2020 by Sylvain Gugger, committed by GitHub on Jul 31, 2020
Parent: 603cd81a

Harmonize both Trainers API (#6157)

* Harmonize both Trainers API
* Fix test
* main_prcess -> process_zero

Showing 3 changed files with 125 additions and 97 deletions:

docs/source/main_classes/trainer.rst    +17   -0
src/transformers/trainer.py             +92   -83
src/transformers/trainer_tf.py          +16   -14
docs/source/main_classes/trainer.rst
@@ -11,6 +11,23 @@ customization during training.
 
 The API supports distributed training on multiple GPUs/TPUs, mixed precision through `NVIDIA Apex
 <https://github.com/NVIDIA/apex>`__ for PyTorch and :obj:`tf.keras.mixed_precision` for TensorFlow.
+
+Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop supporting the
+previous features. To inject custom behavior you can subclass them and override the following methods:
+
+- **get_train_dataloader**/**get_train_tfdataset** -- Creates the training DataLoader (PyTorch) or TF Dataset.
+- **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaluation DataLoader (PyTorch) or TF Dataset.
+- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
+- **log** -- Logs information on the various objects watching training.
+- **setup_wandb** -- Sets up wandb (see `here <https://docs.wandb.com/huggingface>`__ for more information).
+- **create_optimizer_and_scheduler** -- Sets up the optimizer and learning rate scheduler if they were not passed at
+  init.
+- **training_step** -- Performs a training step.
+- **prediction_step** -- Performs an evaluation/test step.
+- **run_model** (TensorFlow only) -- Basic pass through the model.
+- **evaluate** -- Runs an evaluation loop and returns metrics.
+- **predict** -- Returns predictions (with metrics if labels are available) on a test set.
+
 ``Trainer``
 ~~~~~~~~~~~
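For readers skimming this diff, a minimal sketch of the customization path the new docs section describes: subclass the trainer and override one of the listed hooks. The `MyTrainer` class and the `my_extra_metric` key are illustrative only and not part of this commit.

    from transformers import Trainer

    class MyTrainer(Trainer):
        # Override one of the hooks listed in the docs above to inject custom behavior.
        def log(self, logs, iterator=None):
            # Attach an extra value before the default TensorBoard/wandb/console handling runs.
            logs["my_extra_metric"] = 0.0  # placeholder value, for illustration only
            super().log(logs, iterator)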
src/transformers/trainer.py
@@ -172,18 +172,6 @@ class Trainer:
         :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
     """

-    model: PreTrainedModel
-    args: TrainingArguments
-    data_collator: DataCollator
-    train_dataset: Optional[Dataset]
-    eval_dataset: Optional[Dataset]
-    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
-    prediction_loss_only: bool
-    tb_writer: Optional["SummaryWriter"] = None
-    optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None
-    global_step: Optional[int] = None
-    epoch: Optional[float] = None
-
     def __init__(
         self,
         model: PreTrainedModel,
@@ -194,7 +182,7 @@ class Trainer:
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         prediction_loss_only=False,
         tb_writer: Optional["SummaryWriter"] = None,
-        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None,
+        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
     ):
         self.model = model.to(args.device)
         self.args = args
@@ -203,10 +191,9 @@ class Trainer:
         self.eval_dataset = eval_dataset
         self.compute_metrics = compute_metrics
         self.prediction_loss_only = prediction_loss_only
-        self.optimizers = optimizers
-        if tb_writer is not None:
-            self.tb_writer = tb_writer
-        elif is_tensorboard_available() and self.is_world_master():
+        self.optimizer, self.lr_scheduler = optimizers
+        self.tb_writer = tb_writer
+        if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
             self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
         if not is_tensorboard_available():
             logger.warning(
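As context for the `self.optimizer, self.lr_scheduler = optimizers` unpacking above, a hedged sketch of the caller side: the harmonized argument is now an `(optimizer, scheduler)` pair defaulting to `(None, None)`. The `build_trainer` helper and the hyperparameters are illustrative, not part of this commit.

    import torch
    from transformers import Trainer, TrainingArguments

    def build_trainer(model, train_dataset):
        args = TrainingArguments(output_dir="output")
        # Pass a ready-made pair to replace the default AdamW + linear schedule.
        optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
        return Trainer(
            model=model,
            args=args,
            train_dataset=train_dataset,
            optimizers=(optimizer, scheduler),
        )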
@@ -221,7 +208,7 @@ class Trainer:
             )
         set_seed(self.args.seed)
         # Create output directory if needed
-        if self.is_world_master():
+        if self.is_world_process_zero():
             os.makedirs(self.args.output_dir, exist_ok=True)
         if is_torch_tpu_available():
             # Set an xla_device flag on the model's config.
@@ -236,6 +223,8 @@ class Trainer:
                 ),
                 FutureWarning,
             )

+        self.global_step = None
+        self.epoch = None
         if self.args.fp16 and _use_native_amp:
             self.scaler = torch.cuda.amp.GradScaler()
@@ -333,18 +322,14 @@ class Trainer:
             drop_last=self.args.dataloader_drop_last,
         )

-    def get_optimizers(
-        self, num_training_steps: int
-    ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
+    def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
         Setup the optimizer and the learning rate scheduler.

         We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
         Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
         """
-        if self.optimizers is not None:
-            return self.optimizers
-        # Prepare optimizer and schedule (linear warmup and decay)
+        if self.optimizer is None:
             no_decay = ["bias", "LayerNorm.weight"]
             optimizer_grouped_parameters = [
                 {
@@ -356,16 +341,16 @@ class Trainer:
                     "weight_decay": 0.0,
                 },
             ]
-        optimizer = AdamW(
+            self.optimizer = AdamW(
                 optimizer_grouped_parameters,
                 lr=self.args.learning_rate,
                 betas=(self.args.adam_beta1, self.args.adam_beta2),
                 eps=self.args.adam_epsilon,
             )
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
+        if self.lr_scheduler is None:
+            self.lr_scheduler = get_linear_schedule_with_warmup(
+                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
             )
-        return optimizer, scheduler

     def setup_wandb(self):
         """
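The docstring above names two customization routes: pass a tuple through `optimizers`, or subclass and override this method. A minimal sketch of the second route under the new contract, where the method fills in `self.optimizer` and `self.lr_scheduler` instead of returning them as `get_optimizers` did. The `SGDTrainer` name and the SGD/constant-schedule choice are illustrative only.

    import torch
    from transformers import Trainer

    class SGDTrainer(Trainer):
        def create_optimizer_and_scheduler(self, num_training_steps):
            # Only fill in the pieces that were not already passed via `optimizers`.
            if self.optimizer is None:
                self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.learning_rate)
            if self.lr_scheduler is None:
                self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda step: 1.0)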
@@ -390,7 +375,7 @@ class Trainer:
             )
             return self._setup_wandb()

-        if self.is_world_master():
+        if self.is_world_process_zero():
             logger.info(
                 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
             )
@@ -426,7 +411,7 @@ class Trainer:
             t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
             num_train_epochs = self.args.num_train_epochs

-        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)
+        self.create_optimizer_and_scheduler(num_training_steps=t_total)

         # Check if saved optimizer or scheduler states exist
         if (
@@ -435,16 +420,16 @@ class Trainer:
             and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
         ):
             # Load in optimizer and scheduler states
-            optimizer.load_state_dict(
+            self.optimizer.load_state_dict(
                 torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
             )
-            scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
+            self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

         model = self.model
         if self.args.fp16 and _use_apex:
             if not is_apex_available():
                 raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
-            model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)
+            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

         # multi-gpu training (should be after apex fp16 initialization)
         if self.args.n_gpu > 1:
@@ -506,7 +491,7 @@ class Trainer:
         logging_loss = 0.0
         model.zero_grad()
         train_iterator = trange(
-            epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()
+            epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_process_zero()
         )
         for epoch in train_iterator:
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
@@ -516,9 +501,9 @@ class Trainer:
                 parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                     self.args.device
                 )
-                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master())
+                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_process_zero())
             else:
-                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
+                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_process_zero())

             # Reset the past mems state at the beginning of each epoch if necessary.
             if self.args.past_index >= 0:
@@ -531,7 +516,7 @@ class Trainer:
                     steps_trained_in_current_epoch -= 1
                     continue

-                tr_loss += self.training_step(model, inputs, optimizer)
+                tr_loss += self.training_step(model, inputs)

                 if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
@@ -539,23 +524,22 @@ class Trainer:
                     and (step + 1) == len(epoch_iterator)
                 ):
                     if self.args.fp16 and _use_native_amp:
-                        self.scaler.unscale_(optimizer)
+                        self.scaler.unscale_(self.optimizer)
                         torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                     elif self.args.fp16 and _use_apex:
-                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
+                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                     else:
                         torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                     if is_torch_tpu_available():
-                        xm.optimizer_step(optimizer)
+                        xm.optimizer_step(self.optimizer)
                     if self.args.fp16 and _use_native_amp:
-                        self.scaler.step(optimizer)
+                        self.scaler.step(self.optimizer)
                         self.scaler.update()
                     else:
-                        optimizer.step()
+                        self.optimizer.step()

-                    scheduler.step()
+                    self.lr_scheduler.step()
                     model.zero_grad()
                     self.global_step += 1
                     self.epoch = epoch + (step + 1) / len(epoch_iterator)
@@ -567,9 +551,9 @@ class Trainer:
                     logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
                     # backward compatibility for pytorch schedulers
                     logs["learning_rate"] = (
-                        scheduler.get_last_lr()[0]
+                        self.lr_scheduler.get_last_lr()[0]
                         if version.parse(torch.__version__) >= version.parse("1.4")
-                        else scheduler.get_lr()[0]
+                        else self.lr_scheduler.get_lr()[0]
                     )
                     logging_loss = tr_loss
@@ -590,16 +574,16 @@ class Trainer:
                         self.save_model(output_dir)

-                        if self.is_world_master():
+                        if self.is_world_process_zero():
                             self._rotate_checkpoints()

                         if is_torch_tpu_available():
                             xm.rendezvous("saving_optimizer_states")
-                            xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
-                            xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
-                        elif self.is_world_master():
-                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
-                            torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+                            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
+                            xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+                        elif self.is_world_process_zero():
+                            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
+                            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

                     if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                         epoch_iterator.close()
@@ -660,7 +644,7 @@ class Trainer:
                 )
             self.tb_writer.flush()
         if is_wandb_available():
-            if self.is_world_master():
+            if self.is_world_process_zero():
                 wandb.log(logs, step=self.global_step)
         output = {**logs, **{"step": self.global_step}}
         if iterator is not None:
@@ -684,11 +668,9 @@ class Trainer:

         return inputs

-    def training_step(
-        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
-    ) -> float:
+    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> float:
         """
-        Perform a training step on :obj:`model` using obj:`inputs` and :obj:`optimizer`.
+        Perform a training step on a batch of inputs.

         Subclass and override to inject custom behavior.
@@ -700,19 +682,16 @@ class Trainer:
                 The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                 argument :obj:`labels`. Check your model's documentation for all accepted arguments.
-            optimizer (:obj:`torch.optim.Optimizer`):
-                The optimizer to use to make a step.

         Return:
-            `float`:
-            The training loss on this batch.
+            :obj:`float`: The training loss on this batch.
         """
         if hasattr(self, "_training_step"):
             warnings.warn(
                 "The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
                 FutureWarning,
             )
-            return self._training_step(model, inputs, optimizer)
+            return self._training_step(model, inputs, self.optimizer)

         model.train()
         inputs = self._prepare_inputs(inputs, model)
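A minimal sketch of a `training_step` override under the new two-argument signature; the optimizer is no longer passed in and is reachable as `self.optimizer` when needed. Mixed-precision and distributed details handled by the built-in implementation are deliberately omitted, and `MyTrainer` is an illustrative name only.

    import torch.nn as nn
    from transformers import Trainer

    class MyTrainer(Trainer):
        def training_step(self, model: nn.Module, inputs) -> float:
            model.train()
            inputs = self._prepare_inputs(inputs, model)
            outputs = model(**inputs)
            loss = outputs[0]  # the loss comes first when labels are provided
            if self.args.gradient_accumulation_steps > 1:
                loss = loss / self.args.gradient_accumulation_steps
            loss.backward()
            return loss.item()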
@@ -738,7 +717,7 @@ class Trainer:
         if self.args.fp16 and _use_native_amp:
             self.scaler.scale(loss).backward()
         elif self.args.fp16 and _use_apex:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
+            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
             loss.backward()
@@ -746,6 +725,22 @@ class Trainer:
         return loss.item()

     def is_local_master(self) -> bool:
+        """
+        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
+        several machines) main process.
+
+        .. warning::
+
+            This method is deprecated, use :meth:`~transformers.Trainer.is_local_process_zero` instead.
+        """
+        warnings.warn("This method is deprecated, use `Trainer.is_local_process_zero()` instead.", FutureWarning)
+        return self.is_local_process_zero()
+
+    def is_local_process_zero(self) -> bool:
+        """
+        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
+        several machines) main process.
+        """
         if is_torch_tpu_available():
             return xm.is_master_ordinal(local=True)
         else:
@@ -753,8 +748,20 @@ class Trainer:

     def is_world_master(self) -> bool:
         """
-        This will be True only in one process, even in distributed mode,
-        even when training on multiple machines.
+        Whether or not this process is the global main process (when training in a distributed fashion on
+        several machines, this is only going to be :obj:`True` for one process).
+
+        .. warning::
+
+            This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead.
+        """
+        warnings.warn("This method is deprecated, use `Trainer.is_world_process_zero()` instead.", FutureWarning)
+        return self.is_world_process_zero()
+
+    def is_world_process_zero(self) -> bool:
+        """
+        Whether or not this process is the global main process (when training in a distributed fashion on
+        several machines, this is only going to be :obj:`True` for one process).
         """
         if is_torch_tpu_available():
             return xm.is_master_ordinal(local=False)
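A short usage sketch of the renamed helper: the deprecated `is_world_master` still works but now emits a `FutureWarning`, so new code should guard main-process-only work like this. The `train_and_save` helper is illustrative, not part of this commit.

    from transformers import Trainer

    def train_and_save(trainer: Trainer, output_dir: str) -> None:
        trainer.train()
        # Replaces the deprecated trainer.is_world_master() guard.
        if trainer.is_world_process_zero():
            trainer.save_model(output_dir)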
@@ -770,7 +777,7 @@ class Trainer:

         if is_torch_tpu_available():
             self._save_tpu(output_dir)
-        elif self.is_world_master():
+        elif self.is_world_process_zero():
             self._save(output_dir)

     def _save_tpu(self, output_dir: Optional[str] = None):
@@ -846,6 +853,7 @@ class Trainer:
         Args:
             eval_dataset (:obj:`Dataset`, `optional`):
                 Pass a dataset if you wish to override :obj:`self.eval_dataset`.
+
         Returns:
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
         """
@@ -871,6 +879,7 @@ class Trainer:
         Args:
             test_dataset (:obj:`Dataset`):
                 Dataset to run the predictions on.
+
         Returns:
             `NamedTuple`:
             predictions (:obj:`np.ndarray`):
src/transformers/trainer_tf.py
@@ -63,17 +63,6 @@ class TFTrainer:
         an instance of :class:`~transformers.WarmUp`.
     """

-    model: TFPreTrainedModel
-    args: TFTrainingArguments
-    train_dataset: Optional[tf.data.Dataset]
-    eval_dataset: Optional[tf.data.Dataset]
-    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
-    prediction_loss_only: bool
-    tb_writer: Optional[tf.summary.SummaryWriter] = None
-    optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = (None, None)
-    global_step: Optional[int] = None
-    epoch_logging: Optional[float] = None
-
     def __init__(
         self,
         model: TFPreTrainedModel,
@@ -325,6 +314,15 @@ class TFTrainer:
         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

     def log(self, logs: Dict[str, float]) -> None:
+        """
+        Log :obj:`logs` on the various objects watching training.
+
+        Subclass and override this method to inject custom behavior.
+
+        Args:
+            logs (:obj:`Dict[str, float]`):
+                The values to log.
+        """
         if hasattr(self, "_log"):
             warnings.warn(
                 "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
@@ -356,6 +354,7 @@ class TFTrainer:
         Args:
             eval_dataset (:class:`~tf.data.Dataset`, `optional`):
                 Pass a dataset if you wish to override :obj:`self.eval_dataset`.
+
         Returns:
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
         """
@@ -577,9 +576,12 @@ class TFTrainer:
         Subclass and override this method if you want to inject some custom behavior.

         Args:
-            features: the batched features.
-            labels: the batched labels.
-            training: run the model in training mode or not
+            features (:obj:`tf.Tensor`): A batch of input features.
+            labels (:obj:`tf.Tensor`): A batch of labels.
+            training (:obj:`bool`): Whether or not to run the model in training mode.
+
+        Returns:
+            A tuple of two :obj:`tf.Tensor`: The loss and logits.
         """
         if hasattr(self, "_run_model"):
             warnings.warn(
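On the TensorFlow side, the newly documented `run_model` contract (it returns the loss and the logits) can be used from a subclass roughly as below; the `MyTFTrainer` name and the extra L2 penalty are illustrative only, not part of this commit.

    import tensorflow as tf
    from transformers import TFTrainer

    class MyTFTrainer(TFTrainer):
        def run_model(self, features, labels, training):
            # Reuse the default forward pass, then add a small penalty on the logits.
            loss, logits = super().run_model(features, labels, training)
            loss = loss + 1e-4 * tf.reduce_sum(tf.square(logits))
            return loss, logits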