chenpangpang / transformers / Commits / 3a03bab9

Unverified commit 3a03bab9, authored Sep 18, 2020 by Yih-Dar, committed by GitHub on Sep 18, 2020
Parent: ee9eae4e

Fix a few countings (steps / epochs) in trainer_tf.py (#7175)

Showing 1 changed file with 53 additions and 25 deletions:

src/transformers/trainer_tf.py  (+53 / -25)
@@ -478,43 +478,58 @@ class TFTrainer:
         self.gradient_accumulator.reset()

+        num_update_steps_per_epoch = self.num_train_examples / self.total_train_batch_size
+
+        # In fact, ``self.args.dataloader_drop_last`` has no effect in `trainer_tf.py`, because
+        # the dataset is repeated before being batched.
+        # It has the effect only when TPU is used which requires explicit tensor shape in order to make
+        # the gradient accumulation implementation work.
+        approx = math.floor if self.args.dataloader_drop_last else math.ceil
+        num_update_steps_per_epoch = approx(num_update_steps_per_epoch)
+
+        # At least one update for each epoch.
+        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+        self.steps_per_epoch = num_update_steps_per_epoch
+
         if self.args.max_steps > 0:
             t_total = self.args.max_steps
-            self.steps_per_epoch = self.args.max_steps
+            epochs = (self.args.max_steps // self.steps_per_epoch) + int(
+                self.args.max_steps % self.steps_per_epoch > 0
+            )
         else:
-            approx = math.floor if self.args.dataloader_drop_last else math.ceil
-            self.steps_per_epoch = approx(self.num_train_examples / self.total_train_batch_size)
             t_total = self.steps_per_epoch * self.args.num_train_epochs
+            epochs = self.args.num_train_epochs
+
+        # Since ``self.args.num_train_epochs`` can be `float`, we make ``epochs`` be a `float` always.
+        epochs = float(epochs)

         with self.args.strategy.scope():
             self.create_optimizer_and_scheduler(num_training_steps=t_total)
+            iterations = self.optimizer.iterations
+            self.global_step = iterations.numpy()
             folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
             ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
             self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)

-            iterations = self.optimizer.iterations
+            epochs_trained = 0
+            steps_trained_in_current_epoch = 0
             if self.model.ckpt_manager.latest_checkpoint:
-                epochs_trained = self.global_step // (
-                    self.num_train_examples // self.args.gradient_accumulation_steps
-                )
-                steps_trained_in_current_epoch = self.global_step % (
-                    self.num_train_examples // self.args.gradient_accumulation_steps
-                )
-
-                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
-                logger.info("  Continuing training from epoch %d", epochs_trained)
-                logger.info("  Continuing training from global step %d", self.global_step)
-                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
-
                 logger.info(
                     "Checkpoint file %s found and restoring from checkpoint", self.model.ckpt_manager.latest_checkpoint
                 )
                 ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()
-            else:
-                epochs_trained = 1

-            tf.summary.experimental.set_step(iterations)
-            self.global_step = iterations.numpy()
+                self.global_step = iterations.numpy()

-            epochs = 1 if self.args.max_steps > 0 else self.args.num_train_epochs
+                epochs_trained = self.global_step // self.steps_per_epoch
+                steps_trained_in_current_epoch = self.global_step % self.steps_per_epoch
+
+                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
+                logger.info("  Continuing training from epoch %d", epochs_trained)
+                logger.info("  Continuing training from global step %d", self.global_step)
+                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+
+            tf.summary.experimental.set_step(self.global_step)

             if self.args.fp16:
                 policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
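The counting in this hunk is plain arithmetic on the dataset size, batch size, and step cap. A minimal standalone sketch of the same relations, with illustrative values of my own (not taken from the commit), shows how `steps_per_epoch` and `epochs` come out once `max_steps` is set:

import math

# Illustrative values (assumptions, not from the commit):
num_train_examples = 1000
total_train_batch_size = 64      # per-device batch size * replicas * accumulation steps
dataloader_drop_last = False
max_steps = 70                   # > 0 means "train for exactly this many update steps"

# Same rounding rule as the new code: ceil keeps the last partial batch, floor drops it.
approx = math.floor if dataloader_drop_last else math.ceil
steps_per_epoch = max(approx(num_train_examples / total_train_batch_size), 1)  # ceil(1000 / 64) = 16

# Epochs needed to reach max_steps, i.e. a ceiling division written with // and %.
epochs = (max_steps // steps_per_epoch) + int(max_steps % steps_per_epoch > 0)  # 70 // 16 + 1 = 5

print(steps_per_epoch, float(epochs))  # 16 5.0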
@@ -527,6 +542,7 @@ class TFTrainer:
             logger.info("***** Running training *****")
             logger.info("  Num examples = %d", self.num_train_examples)
+            # TODO: We might want to print a more precise ``epochs`` if self.args.max_steps > 0 ?
             logger.info("  Num Epochs = %d", epochs)
             logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
             logger.info(
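On the TODO added above: when ``max_steps`` stops training partway through the last epoch, the epoch count that is actually trained is fractional. A hypothetical illustration of what a more precise log line could look like (my own sketch with assumed values, not what the commit does):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Assumed values: 70 capped update steps over epochs of 16 steps each.
max_steps, steps_per_epoch = 70, 16
precise_epochs = max_steps / steps_per_epoch  # 4.375 epochs actually trained

# Hypothetical fractional variant of the "Num Epochs" log line.
logger.info("  Num Epochs = %.2f", precise_epochs)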
@@ -539,17 +555,23 @@ class TFTrainer:
             self.train_loss = tf.keras.metrics.Sum()
             start_time = datetime.datetime.now()

-            for epoch_iter in range(epochs_trained, int(epochs + 1)):
+            for epoch_iter in range(epochs_trained, int(epochs)):
                 # Reset the past mems state at the beginning of each epoch if necessary.
                 if self.args.past_index >= 0:
                     self._past = None

                 for step, batch in enumerate(train_ds):
-                    self.global_step = iterations.numpy()
-                    self.epoch_logging = epoch_iter - 1 + (step + 1) / self.steps_per_epoch
+
+                    # Skip past any already trained steps if resuming training
+                    if steps_trained_in_current_epoch > 0:
+                        steps_trained_in_current_epoch -= 1
+                        continue

                     self.distributed_training_steps(batch)

+                    self.global_step = iterations.numpy()
+                    self.epoch_logging = epoch_iter + (step + 1) / self.steps_per_epoch
+
                     training_loss = self.train_loss.result() / (step + 1)

                     if self.args.debug:
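The resume logic in this hunk is integer arithmetic on the restored global step plus a skip counter in the inner loop. A small standalone sketch of the same behaviour, with illustrative numbers of my own (not part of the commit):

# Illustrative resume arithmetic (assumed values, not from the commit).
steps_per_epoch = 16
global_step = 38                                                 # restored from the checkpoint

epochs_trained = global_step // steps_per_epoch                  # 2 full epochs already done
steps_trained_in_current_epoch = global_step % steps_per_epoch   # 6 steps already done in epoch index 2

# The loop starts at epoch index `epochs_trained` and, inside that epoch, consumes
# (and ignores) the first `steps_trained_in_current_epoch` batches before training
# again, mirroring the `continue` in the diff above.
for epoch_iter in range(epochs_trained, 4):       # pretend epochs == 4
    for step in range(steps_per_epoch):           # stand-in for enumerate(train_ds)
        if steps_trained_in_current_epoch > 0:
            steps_trained_in_current_epoch -= 1
            continue
        global_step += 1                          # one optimizer update per remaining step

print(global_step)  # 38 + (2 * 16 - 6) = 64, i.e. exactly 4 epochs of 16 steps in total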
@@ -566,13 +588,13 @@ class TFTrainer:
                     )

                     if (
-                        self.global_step > 0
+                        self.args.eval_steps > 0
                         and self.args.evaluate_during_training
                         and self.global_step % self.args.eval_steps == 0
                     ):
                         self.evaluate()

-                    if (self.global_step > 0 and self.global_step % self.args.logging_steps == 0) or (
+                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                         self.global_step == 1 and self.args.logging_first_step
                     ):
                         logs = {}
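The guard change in this hunk swaps a check on the step counter for a check on the interval itself: an interval of 0 now disables the feature, and the short-circuit also keeps the modulo from ever dividing by zero. A small helper sketch of the same condition (my own illustration, not code from the commit):

def should_trigger(global_step: int, every_n_steps: int) -> bool:
    # An interval of 0 (or negative) means "never"; checking it first also
    # short-circuits before the modulo, so `global_step % 0` is never evaluated.
    return every_n_steps > 0 and global_step % every_n_steps == 0

print(should_trigger(100, 50))  # True  -> time to evaluate / log / save
print(should_trigger(101, 50))  # False
print(should_trigger(100, 0))   # False -> feature disabled, no ZeroDivisionError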
@@ -582,16 +604,22 @@ class TFTrainer:
                         self.log(logs)

-                    if self.global_step > 0 and self.global_step % self.args.save_steps == 0:
+                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                         ckpt_save_path = self.model.ckpt_manager.save()
                         logger.info("Saving checkpoint for step {} at {}".format(self.global_step, ckpt_save_path))

-                    if self.global_step > 0 and self.global_step % self.steps_per_epoch == 0:
+                    if self.args.max_steps > 0 and self.global_step >= t_total:
+                        break
+
+                    if self.global_step % self.steps_per_epoch == 0:
                         break

                 self.train_loss.reset_states()

+                if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
+                    break
+
             end_time = datetime.datetime.now()

             logger.info("Training took: {}".format(str(end_time - start_time)))
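The new stopping conditions in this last hunk need a break at both loop levels, since Python's ``break`` only exits the innermost loop: the step loop stops at an epoch boundary or at ``t_total``, and the epoch loop then stops once ``max_steps`` is reached. A minimal standalone sketch of the same pattern, with assumed values of my own (not the commit's code):

# Illustrative early-stopping pattern (assumed values, not from the commit).
steps_per_epoch = 16
max_steps = 70          # plays the role of t_total when > 0
epochs = 5
global_step = 0

for epoch_iter in range(epochs):
    for step in range(10**9):        # stand-in for the endlessly repeated tf.data dataset
        global_step += 1             # one optimizer update
        if max_steps > 0 and global_step >= max_steps:
            break                    # stop the current epoch once max_steps is reached
        if global_step % steps_per_epoch == 0:
            break                    # end of a (virtual) epoch on the repeated dataset
    if max_steps > 0 and global_step >= max_steps:
        break                        # also stop the epoch loop, otherwise training resumes

print(global_step)  # 70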