chenpangpang/transformers, commit 86caab1e

Unverified commit 86caab1e, authored Jul 31, 2020 by Sylvain Gugger, committed via GitHub on Jul 31, 2020.

Harmonize both Trainers API (#6157)

* Harmonize both Trainers API
* Fix test
* main_prcess -> process_zero
Parent: 603cd81a
Showing 3 changed files with 125 additions and 97 deletions (+125 −97):
docs/source/main_classes/trainer.rst  +17 −0
src/transformers/trainer.py           +92 −83
src/transformers/trainer_tf.py        +16 −14
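At a glance, the harmonized API renames `get_optimizers` to `create_optimizer_and_scheduler` (which now fills `self.optimizer`/`self.lr_scheduler` instead of returning a tuple) and `is_world_master`/`is_local_master` to `is_world_process_zero`/`is_local_process_zero`. A minimal sketch of the renamed calls, assuming a transformers build that includes this commit; the tiny BERT config and the output directory are placeholders:

import torch
from transformers import BertConfig, BertForSequenceClassification, Trainer, TrainingArguments

# Tiny throwaway model so the sketch is cheap to instantiate.
config = BertConfig(num_hidden_layers=1, hidden_size=64, num_attention_heads=2, intermediate_size=128)
model = BertForSequenceClassification(config)
trainer = Trainer(model=model, args=TrainingArguments(output_dir="out"))

# Old API (removed by this commit): optimizer, scheduler = trainer.get_optimizers(num_training_steps=1000)
trainer.create_optimizer_and_scheduler(num_training_steps=1000)  # fills trainer.optimizer / trainer.lr_scheduler

# Old API: trainer.is_world_master() / trainer.is_local_master()
if trainer.is_world_process_zero():
    print("global main process")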
docs/source/main_classes/trainer.rst

@@ -11,6 +11,23 @@ customization during training.
 The API supports distributed training on multiple GPUs/TPUs, mixed precision through `NVIDIA Apex
 <https://github.com/NVIDIA/apex>`__ for PyTorch and :obj:`tf.keras.mixed_precision` for TensorFlow.
 
+Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop supporting the
+previous features. To inject custom behavior you can subclass them and override the following methods:
+
+- **get_train_dataloader**/**get_train_tfdataset** -- Creates the training DataLoader (PyTorch) or TF Dataset.
+- **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaluation DataLoader (PyTorch) or TF Dataset.
+- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
+- **log** -- Logs information on the various objects watching training.
+- **setup_wandb** -- Sets up wandb (see `here <https://docs.wandb.com/huggingface>`__ for more information).
+- **create_optimizer_and_scheduler** -- Sets up the optimizer and learning rate scheduler if they were not passed at
+  init.
+- **training_step** -- Performs a training step.
+- **prediction_step** -- Performs an evaluation/test step.
+- **run_model** (TensorFlow only) -- Basic pass through the model.
+- **evaluate** -- Runs an evaluation loop and returns metrics.
+- **predict** -- Returns predictions (with metrics if labels are available) on a test set.
+
 ``Trainer``
 ~~~~~~~~~~~
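To make the subclass-and-override pattern above concrete, here is a rough sketch; the subclass name and the print-based logging are invented, and the optional `iterator` argument mirrors the `log` signature visible further down in this diff:

from typing import Dict
from transformers import Trainer


class PrintLoggingTrainer(Trainer):
    """Overrides one hook from the list above; everything else keeps stock Trainer behavior."""

    def log(self, logs: Dict[str, float], iterator=None) -> None:
        # Send metrics to stdout instead of TensorBoard/wandb.
        print({**logs, "step": self.global_step})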
src/transformers/trainer.py

@@ -172,18 +172,6 @@ class Trainer:
     :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
     """
 
-    model: PreTrainedModel
-    args: TrainingArguments
-    data_collator: DataCollator
-    train_dataset: Optional[Dataset]
-    eval_dataset: Optional[Dataset]
-    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
-    prediction_loss_only: bool
-    tb_writer: Optional["SummaryWriter"] = None
-    optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None
-    global_step: Optional[int] = None
-    epoch: Optional[float] = None
-
     def __init__(
         self,
         model: PreTrainedModel,
@@ -194,7 +182,7 @@ class Trainer:
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         prediction_loss_only=False,
         tb_writer: Optional["SummaryWriter"] = None,
-        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = None,
+        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
     ):
         self.model = model.to(args.device)
         self.args = args
@@ -203,10 +191,9 @@ class Trainer:
         self.eval_dataset = eval_dataset
         self.compute_metrics = compute_metrics
         self.prediction_loss_only = prediction_loss_only
-        self.optimizers = optimizers
-        if tb_writer is not None:
-            self.tb_writer = tb_writer
-        elif is_tensorboard_available() and self.is_world_master():
+        self.optimizer, self.lr_scheduler = optimizers
+        self.tb_writer = tb_writer
+        if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
             self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
         if not is_tensorboard_available():
             logger.warning(
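Since `optimizers` now defaults to `(None, None)` and is unpacked into `self.optimizer` and `self.lr_scheduler`, a custom optimizer/scheduler pair is supplied as a 2-tuple. A rough sketch, assuming a transformers build with this commit; the tiny config, learning rate, and constant schedule are placeholders:

import torch
from transformers import BertConfig, BertForSequenceClassification, Trainer, TrainingArguments

config = BertConfig(num_hidden_layers=1, hidden_size=64, num_attention_heads=2, intermediate_size=128)
model = BertForSequenceClassification(config)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)  # constant LR

# The pair ends up on trainer.optimizer / trainer.lr_scheduler.
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out"),
    optimizers=(optimizer, scheduler),
)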
@@ -221,7 +208,7 @@ class Trainer:
             )
         set_seed(self.args.seed)
         # Create output directory if needed
-        if self.is_world_master():
+        if self.is_world_process_zero():
             os.makedirs(self.args.output_dir, exist_ok=True)
         if is_torch_tpu_available():
             # Set an xla_device flag on the model's config.

@@ -236,6 +223,8 @@ class Trainer:
                 ),
                 FutureWarning,
             )
+        self.global_step = None
+        self.epoch = None
         if self.args.fp16 and _use_native_amp:
             self.scaler = torch.cuda.amp.GradScaler()
@@ -333,39 +322,35 @@ class Trainer:
             drop_last=self.args.dataloader_drop_last,
         )
 
-    def get_optimizers(
-        self, num_training_steps: int
-    ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
+    def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
         Setup the optimizer and the learning rate scheduler.
 
         We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
         Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
         """
-        if self.optimizers is not None:
-            return self.optimizers
-        # Prepare optimizer and schedule (linear warmup and decay)
-        no_decay = ["bias", "LayerNorm.weight"]
-        optimizer_grouped_parameters = [
-            {
-                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
-                "weight_decay": self.args.weight_decay,
-            },
-            {
-                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
-                "weight_decay": 0.0,
-            },
-        ]
-        optimizer = AdamW(
-            optimizer_grouped_parameters,
-            lr=self.args.learning_rate,
-            betas=(self.args.adam_beta1, self.args.adam_beta2),
-            eps=self.args.adam_epsilon,
-        )
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
-        )
-        return optimizer, scheduler
+        if self.optimizer is None:
+            no_decay = ["bias", "LayerNorm.weight"]
+            optimizer_grouped_parameters = [
+                {
+                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
+                    "weight_decay": 0.0,
+                },
+            ]
+            self.optimizer = AdamW(
+                optimizer_grouped_parameters,
+                lr=self.args.learning_rate,
+                betas=(self.args.adam_beta1, self.args.adam_beta2),
+                eps=self.args.adam_epsilon,
+            )
+        if self.lr_scheduler is None:
+            self.lr_scheduler = get_linear_schedule_with_warmup(
+                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
+            )
 
     def setup_wandb(self):
         """
@@ -390,7 +375,7 @@ class Trainer:
             )
             return self._setup_wandb()
 
-        if self.is_world_master():
+        if self.is_world_process_zero():
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )

@@ -426,7 +411,7 @@ class Trainer:
             t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
             num_train_epochs = self.args.num_train_epochs
 
-        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)
+        self.create_optimizer_and_scheduler(num_training_steps=t_total)
 
         # Check if saved optimizer or scheduler states exist
         if (
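As the log message above notes, the automatic Weights & Biases integration can be switched off through an environment variable; set it early, before the Trainer's wandb setup runs. A minimal sketch:

import os

# Per the Trainer's own log message: disable automatic wandb logging.
os.environ["WANDB_DISABLED"] = "true"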
@@ -435,16 +420,16 @@ class Trainer:
             and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
         ):
             # Load in optimizer and scheduler states
-            optimizer.load_state_dict(
+            self.optimizer.load_state_dict(
                 torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
             )
-            scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
+            self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
 
         model = self.model
         if self.args.fp16 and _use_apex:
             if not is_apex_available():
                 raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
-            model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)
+            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
 
         # multi-gpu training (should be after apex fp16 initialization)
         if self.args.n_gpu > 1:
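The optimizer and scheduler states are restored from optimizer.pt and scheduler.pt when a checkpoint directory is passed to train(). A small sketch, with a hypothetical checkpoint path:

from transformers import Trainer


def resume_training(trainer: Trainer, checkpoint_dir: str) -> None:
    # checkpoint_dir is expected to contain optimizer.pt and scheduler.pt (written by
    # the checkpointing code further down in this diff); train() loads them into
    # trainer.optimizer / trainer.lr_scheduler before continuing.
    trainer.train(model_path=checkpoint_dir)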
@@ -506,7 +491,7 @@ class Trainer:
         logging_loss = 0.0
         model.zero_grad()
         train_iterator = trange(
-            epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()
+            epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_process_zero()
         )
         for epoch in train_iterator:
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):

@@ -516,9 +501,9 @@ class Trainer:
                 parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                     self.args.device
                 )
-                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master())
+                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_process_zero())
             else:
-                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
+                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_process_zero())
 
             # Reset the past mems state at the beginning of each epoch if necessary.
             if self.args.past_index >= 0:
@@ -531,7 +516,7 @@ class Trainer:
                     steps_trained_in_current_epoch -= 1
                     continue
 
-                tr_loss += self.training_step(model, inputs, optimizer)
+                tr_loss += self.training_step(model, inputs)
 
                 if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps

@@ -539,23 +524,22 @@ class Trainer:
                     and (step + 1) == len(epoch_iterator)
                 ):
                     if self.args.fp16 and _use_native_amp:
-                        self.scaler.unscale_(optimizer)
+                        self.scaler.unscale_(self.optimizer)
                         torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                     elif self.args.fp16 and _use_apex:
-                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
+                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                     else:
                         torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
 
                     if is_torch_tpu_available():
-                        xm.optimizer_step(optimizer)
+                        xm.optimizer_step(self.optimizer)
 
                     if self.args.fp16 and _use_native_amp:
-                        self.scaler.step(optimizer)
+                        self.scaler.step(self.optimizer)
                         self.scaler.update()
                     else:
-                        optimizer.step()
+                        self.optimizer.step()
 
-                    scheduler.step()
+                    self.lr_scheduler.step()
                     model.zero_grad()
                     self.global_step += 1
                     self.epoch = epoch + (step + 1) / len(epoch_iterator)
@@ -567,9 +551,9 @@ class Trainer:
                         logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
                         # backward compatibility for pytorch schedulers
                         logs["learning_rate"] = (
-                            scheduler.get_last_lr()[0]
+                            self.lr_scheduler.get_last_lr()[0]
                             if version.parse(torch.__version__) >= version.parse("1.4")
-                            else scheduler.get_lr()[0]
+                            else self.lr_scheduler.get_lr()[0]
                         )
                         logging_loss = tr_loss

@@ -590,16 +574,16 @@ class Trainer:
                         self.save_model(output_dir)
 
-                        if self.is_world_master():
+                        if self.is_world_process_zero():
                             self._rotate_checkpoints()
 
                         if is_torch_tpu_available():
                             xm.rendezvous("saving_optimizer_states")
-                            xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
-                            xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
-                        elif self.is_world_master():
-                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
-                            torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+                            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
+                            xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
+                        elif self.is_world_process_zero():
+                            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
+                            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
 
                 if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                     epoch_iterator.close()
@@ -660,7 +644,7 @@ class Trainer:
             )
             self.tb_writer.flush()
         if is_wandb_available():
-            if self.is_world_master():
+            if self.is_world_process_zero():
                 wandb.log(logs, step=self.global_step)
         output = {**logs, **{"step": self.global_step}}
         if iterator is not None:
@@ -684,11 +668,9 @@ class Trainer:
         return inputs
 
-    def training_step(
-        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
-    ) -> float:
+    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> float:
         """
-        Perform a training step on :obj:`model` using obj:`inputs` and :obj:`optimizer`.
+        Perform a training step on a batch of inputs.
 
         Subclass and override to inject custom behavior.

@@ -700,19 +682,16 @@ class Trainer:
                 The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                 argument :obj:`labels`. Check your model's documentation for all accepted arguments.
-            optimizer (:obj:`torch.optim.Optimizer`):
-                The optimizer to use to make a step.
 
         Return:
-            `float`:
-            The training loss on this batch.
+            :obj:`float`: The training loss on this batch.
         """
         if hasattr(self, "_training_step"):
             warnings.warn(
                 "The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
                 FutureWarning,
             )
-            return self._training_step(model, inputs, optimizer)
+            return self._training_step(model, inputs, self.optimizer)
 
         model.train()
         inputs = self._prepare_inputs(inputs, model)

@@ -738,7 +717,7 @@ class Trainer:
         if self.args.fp16 and _use_native_amp:
             self.scaler.scale(loss).backward()
         elif self.args.fp16 and _use_apex:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
+            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
             loss.backward()
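With the harmonized signature, `training_step` no longer receives the optimizer; a subclass reaches it through `self.optimizer` if needed. A rough sketch of an override under the new two-argument signature (the subclass name and the extra logging are invented):

from typing import Any, Dict, Union

import torch
from torch import nn
from transformers import Trainer


class VerboseTrainer(Trainer):
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> float:
        # Delegate to the stock implementation, then add custom behavior.
        loss = super().training_step(model, inputs)
        if self.global_step is not None and self.global_step % 100 == 0:
            print(f"step {self.global_step}: loss {loss:.4f}")
        return loss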
@@ -746,6 +725,22 @@ class Trainer:
         return loss.item()
 
     def is_local_master(self) -> bool:
+        """
+        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
+        several machines) main process.
+
+        .. warning::
+
+            This method is deprecated, use :meth:`~transformers.Trainer.is_local_process_zero` instead.
+        """
+        warnings.warn("This method is deprecated, use `Trainer.is_local_process_zero()` instead.", FutureWarning)
+        return self.is_local_process_zero()
+
+    def is_local_process_zero(self) -> bool:
+        """
+        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
+        several machines) main process.
+        """
         if is_torch_tpu_available():
             return xm.is_master_ordinal(local=True)
         else:

@@ -753,8 +748,20 @@ class Trainer:
     def is_world_master(self) -> bool:
         """
-        This will be True only in one process, even in distributed mode,
-        even when training on multiple machines.
+        Whether or not this process is the global main process (when training in a distributed fashion on
+        several machines, this is only going to be :obj:`True` for one process).
+
+        .. warning::
+
+            This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead.
+        """
+        warnings.warn("This method is deprecated, use `Trainer.is_world_process_zero()` instead.", FutureWarning)
+        return self.is_world_process_zero()
+
+    def is_world_process_zero(self) -> bool:
+        """
+        Whether or not this process is the global main process (when training in a distributed fashion on
+        several machines, this is only going to be :obj:`True` for one process).
         """
         if is_torch_tpu_available():
             return xm.is_master_ordinal(local=False)
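The old `is_local_master`/`is_world_master` names keep working but now emit a `FutureWarning`; new code should guard rank-dependent side effects with the `*_process_zero` variants. A small sketch (the output directory is a placeholder):

from transformers import Trainer


def save_final_model(trainer: Trainer, output_dir: str = "out/final") -> None:
    # Only the global main process writes the model to disk.
    if trainer.is_world_process_zero():
        trainer.save_model(output_dir)
    # Per-machine work (e.g. writing a local cache) can use the local check instead.
    if trainer.is_local_process_zero():
        print("local main process on this machine")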
@@ -770,7 +777,7 @@ class Trainer:
         if is_torch_tpu_available():
             self._save_tpu(output_dir)
-        elif self.is_world_master():
+        elif self.is_world_process_zero():
             self._save(output_dir)
 
     def _save_tpu(self, output_dir: Optional[str] = None):

@@ -846,6 +853,7 @@ class Trainer:
         Args:
             eval_dataset (:obj:`Dataset`, `optional`):
                 Pass a dataset if you wish to override :obj:`self.eval_dataset`.
+
         Returns:
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
         """

@@ -871,6 +879,7 @@ class Trainer:
         Args:
             test_dataset (:obj:`Dataset`):
                 Dataset to run the predictions on.
+
         Returns:
             `NamedTuple`:
             predictions (:obj:`np.ndarray`):
src/transformers/trainer_tf.py

@@ -63,17 +63,6 @@ class TFTrainer:
         an instance of :class:`~transformers.WarmUp`.
     """
 
-    model: TFPreTrainedModel
-    args: TFTrainingArguments
-    train_dataset: Optional[tf.data.Dataset]
-    eval_dataset: Optional[tf.data.Dataset]
-    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
-    prediction_loss_only: bool
-    tb_writer: Optional[tf.summary.SummaryWriter] = None
-    optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = (None, None)
-    global_step: Optional[int] = None
-    epoch_logging: Optional[float] = None
-
     def __init__(
         self,
         model: TFPreTrainedModel,
@@ -325,6 +314,15 @@ class TFTrainer:
         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
 
     def log(self, logs: Dict[str, float]) -> None:
+        """
+        Log :obj:`logs` on the various objects watching training.
+
+        Subclass and override this method to inject custom behavior.
+
+        Args:
+            logs (:obj:`Dict[str, float]`):
+                The values to log.
+        """
         if hasattr(self, "_log"):
             warnings.warn(
                 "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
@@ -356,6 +354,7 @@ class TFTrainer:
         Args:
             eval_dataset (:class:`~tf.data.Dataset`, `optional`):
                 Pass a dataset if you wish to override :obj:`self.eval_dataset`.
+
         Returns:
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
         """
@@ -577,9 +576,12 @@ class TFTrainer:
         Subclass and override this method if you want to inject some custom behavior.
 
         Args:
-            features: the batched features.
-            labels: the batched labels.
-            training: run the model in training mode or not
+            features (:obj:`tf.Tensor`): A batch of input features.
+            labels (:obj:`tf.Tensor`): A batch of labels.
+            training (:obj:`bool`): Whether or not to run the model in training mode.
+
+        Returns:
+            A tuple of two :obj:`tf.Tensor`: The loss and logits.
         """
         if hasattr(self, "_run_model"):
             warnings.warn(
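On the TensorFlow side, the matching hook is `run_model`, which must return the loss and logits as documented above. A rough sketch of an override, with an invented subclass that simply rescales the loss:

from transformers import TFTrainer


class ScaledLossTFTrainer(TFTrainer):
    def run_model(self, features, labels, training):
        # Keep the default forward pass, then post-process its outputs.
        loss, logits = super().run_model(features, labels, training)
        return loss * 0.5, logits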