Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d0ef3913
Commit
d0ef3913
authored
Jun 16, 2020
by
A. Unique TensorFlower
Browse files
Update model_training_utils for BERT.
PiperOrigin-RevId: 316801831
parent
6da061c0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
6 deletions
+11
-6
official/nlp/bert/model_training_utils.py
official/nlp/bert/model_training_utils.py
+11
-6
No files found.
official/nlp/bert/model_training_utils.py
View file @
d0ef3913
...
@@ -160,9 +160,10 @@ def run_customized_training_loop(
...
@@ -160,9 +160,10 @@ def run_customized_training_loop(
init_checkpoint: Optional checkpoint to load to `sub_model` returned by
init_checkpoint: Optional checkpoint to load to `sub_model` returned by
`model_fn`.
`model_fn`.
custom_callbacks: A list of Keras Callbacks objects to run during
custom_callbacks: A list of Keras Callbacks objects to run during
training. More specifically, `on_batch_begin()`, `on_batch_end()`,
training. More specifically, `on_train_begin(), on_train_end(),
`on_epoch_begin()`, `on_epoch_end()` methods are invoked during
on_batch_begin()`, `on_batch_end()`, `on_epoch_begin()`,
training. Note that some metrics may be missing from `logs`.
`on_epoch_end()` methods are invoked during training.
Note that some metrics may be missing from `logs`.
run_eagerly: Whether to run model training in pure eager execution. This
run_eagerly: Whether to run model training in pure eager execution. This
should be disable for TPUStrategy.
should be disable for TPUStrategy.
sub_model_export_name: If not None, will export `sub_model` returned by
sub_model_export_name: If not None, will export `sub_model` returned by
...
@@ -246,8 +247,6 @@ def run_customized_training_loop(
...
@@ -246,8 +247,6 @@ def run_customized_training_loop(
raise
ValueError
(
raise
ValueError
(
'if `metric_fn` is specified, metric_fn must be a callable.'
)
'if `metric_fn` is specified, metric_fn must be a callable.'
)
callback_list
=
tf
.
keras
.
callbacks
.
CallbackList
(
custom_callbacks
)
total_training_steps
=
steps_per_epoch
*
epochs
total_training_steps
=
steps_per_epoch
*
epochs
train_iterator
=
_get_input_iterator
(
train_input_fn
,
strategy
)
train_iterator
=
_get_input_iterator
(
train_input_fn
,
strategy
)
eval_loss_metric
=
tf
.
keras
.
metrics
.
Mean
(
'training_loss'
,
dtype
=
tf
.
float32
)
eval_loss_metric
=
tf
.
keras
.
metrics
.
Mean
(
'training_loss'
,
dtype
=
tf
.
float32
)
...
@@ -263,6 +262,9 @@ def run_customized_training_loop(
...
@@ -263,6 +262,9 @@ def run_customized_training_loop(
raise
ValueError
(
'sub_model_export_name is specified as %s, but '
raise
ValueError
(
'sub_model_export_name is specified as %s, but '
'sub_model is None.'
%
sub_model_export_name
)
'sub_model is None.'
%
sub_model_export_name
)
callback_list
=
tf
.
keras
.
callbacks
.
CallbackList
(
callbacks
=
custom_callbacks
,
model
=
model
)
optimizer
=
model
.
optimizer
optimizer
=
model
.
optimizer
if
init_checkpoint
:
if
init_checkpoint
:
...
@@ -451,7 +453,8 @@ def run_customized_training_loop(
...
@@ -451,7 +453,8 @@ def run_customized_training_loop(
checkpoint_name
=
'ctl_step_{step}.ckpt'
checkpoint_name
=
'ctl_step_{step}.ckpt'
logs
=
{}
logs
=
{}
while
current_step
<
total_training_steps
:
callback_list
.
on_train_begin
()
while
current_step
<
total_training_steps
and
not
model
.
stop_training
:
if
current_step
%
steps_per_epoch
==
0
:
if
current_step
%
steps_per_epoch
==
0
:
callback_list
.
on_epoch_begin
(
callback_list
.
on_epoch_begin
(
int
(
current_step
/
steps_per_epoch
)
+
1
)
int
(
current_step
/
steps_per_epoch
)
+
1
)
...
@@ -564,4 +567,6 @@ def run_customized_training_loop(
...
@@ -564,4 +567,6 @@ def run_customized_training_loop(
if
not
_should_export_summary
(
strategy
):
if
not
_should_export_summary
(
strategy
):
tf
.
io
.
gfile
.
rmtree
(
summary_dir
)
tf
.
io
.
gfile
.
rmtree
(
summary_dir
)
callback_list
.
on_train_end
()
return
model
return
model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment