chenpangpang / transformers · Commits · 7cb52f53

Unverified commit 7cb52f53, authored Jun 29, 2020 by Julien Plu, committed by GitHub on Jun 29, 2020.

Fix LR decay in TF Trainer (#5269)

* Recover old PR
* Apply style
* Trigger CI
parent 321c05ab

Showing 1 changed file with 61 additions and 14 deletions:

src/transformers/trainer_tf.py (+61, -14)
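Note on the fix: before this change, the learning-rate schedule was built from self.train_steps (the per-epoch step count, as the old create_optimizer call and the old `% self.train_steps` break condition below show), so the decay horizon covered only a single epoch. The commit computes the true total number of optimization steps (t_total) and threads it into get_optimizers as num_training_steps. Below is a minimal illustrative sketch of why the horizon matters, using a plain Keras polynomial-decay schedule and hypothetical step counts; it is not the exact schedule create_optimizer builds internally.

# Sketch only (values are hypothetical, not from the commit): a decay schedule
# whose horizon covers one epoch flattens the LR long before training ends,
# while one sized to the full run keeps decaying.
import tensorflow as tf

steps_per_epoch = 100
num_epochs = 3
t_total = steps_per_epoch * num_epochs

per_epoch_horizon = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=5e-5, decay_steps=steps_per_epoch, end_learning_rate=0.0
)
full_run_horizon = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=5e-5, decay_steps=t_total, end_learning_rate=0.0
)

print(per_epoch_horizon(150).numpy())  # 0.0 -- decay already finished after epoch 1
print(full_run_horizon(150).numpy())   # ~2.5e-05 -- still halfway through its decay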
src/transformers/trainer_tf.py
@@ -3,6 +3,7 @@
 import logging
+import math
 import os
 import random
 from typing import Callable, Dict, Optional, Tuple

 import numpy as np
@@ -21,6 +22,12 @@ if is_wandb_available():
 logger = logging.getLogger(__name__)


+def set_seed(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    tf.random.set_seed(seed)
+
+
 class TFTrainer:
     model: TFPreTrainedModel
     args: TFTrainingArguments
@@ -59,6 +66,7 @@ class TFTrainer:
             self.tb_writer = tb_writer
         else:
             self.tb_writer = tf.summary.create_file_writer(self.args.logging_dir)
+
         if is_wandb_available():
             self._setup_wandb()
         else:
@@ -67,6 +75,8 @@ class TFTrainer:
                 "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
             )

+        set_seed(self.args.seed)
+
     def get_train_tfdataset(self) -> tf.data.Dataset:
         if self.train_dataset is None:
             raise ValueError("Trainer: training requires a train_dataset.")
@@ -109,7 +119,7 @@
         return self.args.strategy.experimental_distribute_dataset(ds)

     def get_optimizers(
-        self,
+        self, num_training_steps: int,
     ) -> Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]:
         """
         Setup the optimizer and the learning rate scheduler.
@@ -123,7 +133,7 @@ class TFTrainer:
         optimizer, scheduler = create_optimizer(
             self.args.learning_rate,
-            self.train_steps,
+            num_training_steps,
             self.args.warmup_steps,
             adam_epsilon=self.args.adam_epsilon,
             weight_decay_rate=self.args.weight_decay,
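For reference, a minimal standalone sketch of the fixed call shape shown above, with hypothetical values in place of the TrainingArguments fields (assuming create_optimizer is importable from the package root, as in the library's TF examples):

# Hypothetical values; mirrors the snippet above rather than the full TFTrainer setup.
from transformers import create_optimizer

t_total = 3000  # total optimization steps for the whole run (steps_per_epoch * num_train_epochs)

optimizer, lr_scheduler = create_optimizer(
    5e-5,      # self.args.learning_rate
    t_total,   # num_training_steps -- previously self.train_steps, i.e. a single epoch
    300,       # self.args.warmup_steps
    adam_epsilon=1e-8,
    weight_decay_rate=0.01,
)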
@@ -238,14 +248,19 @@ class TFTrainer:
         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

     def _log(self, logs: Dict[str, float]) -> None:
+        logs["epoch"] = self.epoch_logging
+
         if self.tb_writer:
             with self.tb_writer.as_default():
                 for k, v in logs.items():
                     tf.summary.scalar(k, v, step=self.global_step)
             self.tb_writer.flush()

         if is_wandb_available():
             wandb.log(logs, step=self.global_step)

         output = {**logs, **{"step": self.global_step}}

         logger.info(output)

     def evaluate(
@@ -260,6 +275,7 @@ class TFTrainer:
         logs = {**output.metrics}
+        logs["epoch"] = self.epoch_logging

         self._log(logs)

         return output.metrics
@@ -275,25 +291,45 @@ class TFTrainer:
         self.gradient_accumulator.reset()

+        if self.args.max_steps > 0:
+            t_total = self.args.max_steps
+            steps_per_epoch = self.args.max_steps
+        else:
+            if self.args.dataloader_drop_last:
+                approx = math.floor
+            else:
+                approx = math.ceil
+
+            steps_per_epoch = approx(
+                self.num_train_examples / (self.args.train_batch_size * self.args.gradient_accumulation_steps)
+            )
+            t_total = steps_per_epoch * self.args.num_train_epochs
+
         with self.args.strategy.scope():
-            optimizer, lr_scheduler = self.get_optimizers()
+            optimizer, lr_scheduler = self.get_optimizers(num_training_steps=t_total)
             iterations = optimizer.iterations
             self.global_step = iterations.numpy()
             folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
             ckpt = tf.train.Checkpoint(optimizer=optimizer, model=self.model)
             self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)

             if self.model.ckpt_manager.latest_checkpoint:
+                epochs_trained = self.global_step // (self.num_train_examples // self.args.gradient_accumulation_steps)
+                steps_trained_in_current_epoch = self.global_step % (
+                    self.num_train_examples // self.args.gradient_accumulation_steps
+                )
+
+                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
+                logger.info("  Continuing training from epoch %d", epochs_trained)
+                logger.info("  Continuing training from global step %d", self.global_step)
+                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
                 logger.info(
                     "Checkpoint file %s found and restoring from checkpoint", self.model.ckpt_manager.latest_checkpoint
                 )

                 ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()
-
-            if iterations.numpy() > 0:
-                logger.info("Start the training from the last checkpoint")
-                start_epoch = (iterations.numpy() // self.train_steps) + 1
-            else:
-                start_epoch = 1
+            else:
+                epochs_trained = 1

             tf.summary.experimental.set_step(iterations)
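A worked example of the step accounting introduced above, with hypothetical dataset and batch sizes; train_batch_size here stands for the combined per-update batch size, as in the expression above:

# Hypothetical numbers, not taken from the commit.
import math

num_train_examples = 10000
train_batch_size = 32
gradient_accumulation_steps = 2
num_train_epochs = 3
dataloader_drop_last = False

approx = math.floor if dataloader_drop_last else math.ceil
steps_per_epoch = approx(num_train_examples / (train_batch_size * gradient_accumulation_steps))
t_total = steps_per_epoch * num_train_epochs

print(steps_per_epoch, t_total)  # 157 471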
@@ -311,17 +347,23 @@ class TFTrainer:
             logger.info("***** Running training *****")
             logger.info("  Num examples = %d", self.num_train_examples)
             logger.info("  Num Epochs = %d", epochs)
-            logger.info("  Total optimization steps = %d", self.train_steps)
+            logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
+            logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", self.args.train_batch_size)
+            logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
+            logger.info("  Total optimization steps = %d", t_total)

-            for epoch_iter in range(start_epoch, int(epochs + 1)):
+            for epoch_iter in range(epochs_trained, int(epochs + 1)):
                 for step, training_loss in enumerate(self._training_steps(train_ds, optimizer)):
                     self.global_step = iterations.numpy()
-                    self.epoch_logging = epoch_iter - 1 + (step + 1) / self.train_steps
+                    self.epoch_logging = epoch_iter - 1 + (step + 1) / steps_per_epoch

                     if self.args.debug:
                         logs = {}
                         logs["loss"] = training_loss.numpy()
                         logs["epoch"] = self.epoch_logging

                         self._log(logs)

                     if self.global_step == 1 and self.args.debug:
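With the change above, the logged epoch is now a fraction of steps_per_epoch rather than self.train_steps, so the counter advances in step with the actual length of an epoch. A small worked example with hypothetical values:

# Hypothetical values illustrating the fractional epoch counter set above.
steps_per_epoch = 157
epoch_iter = 2   # 1-based epoch counter from the loop above
step = 78        # 0-based step index within the epoch

epoch_logging = epoch_iter - 1 + (step + 1) / steps_per_epoch
print(round(epoch_logging, 3))  # 1.503 -> roughly halfway through the second epoch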
@@ -333,18 +375,23 @@ class TFTrainer:
                     if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                         self.evaluate()

-                    if self.global_step % self.args.logging_steps == 0:
+                    if (
+                        self.global_step % self.args.logging_steps == 0
+                        or self.global_step == 1
+                        and self.args.logging_first_step
+                    ):
                         logs = {}
                         logs["loss"] = training_loss.numpy()
                         logs["learning_rate"] = lr_scheduler(self.global_step).numpy()
                         logs["epoch"] = self.epoch_logging

                         self._log(logs)

                     if self.global_step % self.args.save_steps == 0:
                         ckpt_save_path = self.model.ckpt_manager.save()

                         logger.info("Saving checkpoint for step {} at {}".format(self.global_step, ckpt_save_path))

-                    if self.global_step % self.train_steps == 0:
+                    if self.args.max_steps > 0 and self.global_step % self.args.max_steps == 0:
                         break

     def _training_steps(self, ds, optimizer):
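The widened logging condition above relies on Python operator precedence (and binds tighter than or): a step is logged when it is a multiple of logging_steps, or when it is the very first step and logging_first_step is set. A standalone sketch of the same predicate with hypothetical argument values:

# Hypothetical defaults; mirrors: step % logging_steps == 0 or (step == 1 and logging_first_step)
def should_log(global_step, logging_steps=500, logging_first_step=True):
    return global_step % logging_steps == 0 or global_step == 1 and logging_first_step

print(should_log(1))    # True  -- first step is logged because logging_first_step is set
print(should_log(250))  # False
print(should_log(500))  # True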