Commit d0ef3913 authored by A. Unique TensorFlower

Update model_training_utils for BERT.

PiperOrigin-RevId: 316801831
parent 6da061c0
@@ -160,9 +160,10 @@ def run_customized_training_loop(
     init_checkpoint: Optional checkpoint to load to `sub_model` returned by
       `model_fn`.
     custom_callbacks: A list of Keras Callbacks objects to run during
-      training. More specifically, `on_batch_begin()`, `on_batch_end()`,
-      `on_epoch_begin()`, `on_epoch_end()` methods are invoked during
-      training. Note that some metrics may be missing from `logs`.
+      training. More specifically, `on_train_begin()`, `on_train_end()`,
+      `on_batch_begin()`, `on_batch_end()`, `on_epoch_begin()`,
+      `on_epoch_end()` methods are invoked during training.
+      Note that some metrics may be missing from `logs`.
     run_eagerly: Whether to run model training in pure eager execution. This
       should be disable for TPUStrategy.
     sub_model_export_name: If not None, will export `sub_model` returned by
@@ -246,8 +247,6 @@ def run_customized_training_loop(
     raise ValueError(
         'if `metric_fn` is specified, metric_fn must be a callable.')
 
-  callback_list = tf.keras.callbacks.CallbackList(custom_callbacks)
-
   total_training_steps = steps_per_epoch * epochs
   train_iterator = _get_input_iterator(train_input_fn, strategy)
   eval_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
@@ -263,6 +262,9 @@ def run_customized_training_loop(
       raise ValueError('sub_model_export_name is specified as %s, but '
                        'sub_model is None.' % sub_model_export_name)
 
+    callback_list = tf.keras.callbacks.CallbackList(
+        callbacks=custom_callbacks, model=model)
+
     optimizer = model.optimizer
 
     if init_checkpoint:
@@ -451,7 +453,8 @@ def run_customized_training_loop(
   checkpoint_name = 'ctl_step_{step}.ckpt'
 
   logs = {}
-  while current_step < total_training_steps:
+  callback_list.on_train_begin()
+  while current_step < total_training_steps and not model.stop_training:
     if current_step % steps_per_epoch == 0:
       callback_list.on_epoch_begin(
           int(current_step / steps_per_epoch) + 1)
@@ -564,4 +567,6 @@ def run_customized_training_loop(
     if not _should_export_summary(strategy):
       tf.io.gfile.rmtree(summary_dir)
 
+  callback_list.on_train_end()
+
   return model
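
The hooks added above let a callback passed via `custom_callbacks` observe the whole run and stop it early: the `CallbackList` is now built with `model=model`, so each callback's `self.model` points at the training model, and the loop re-checks `model.stop_training` on every iteration. Below is a minimal sketch of such a callback; `StopAfterSteps` and its `max_steps` argument are hypothetical and not part of this commit or of tf.keras.

import tensorflow as tf

# Hypothetical example callback (not part of this commit): it counts batches
# and requests an early stop by setting `self.model.stop_training`, which the
# updated `while ... and not model.stop_training` condition picks up.
class StopAfterSteps(tf.keras.callbacks.Callback):
  """Stops the custom training loop after a fixed number of batches."""

  def __init__(self, max_steps):
    super(StopAfterSteps, self).__init__()
    self._max_steps = max_steps
    self._steps_seen = 0

  def on_train_begin(self, logs=None):
    # Reached through the new callback_list.on_train_begin() call.
    self._steps_seen = 0

  def on_batch_end(self, batch, logs=None):
    self._steps_seen += 1
    if self._steps_seen >= self._max_steps:
      # The training loop exits at its next condition check.
      self.model.stop_training = True

  def on_train_end(self, logs=None):
    # Reached through the new callback_list.on_train_end() call.
    print('Stopped after %d batches.' % self._steps_seen)

Such a callback would be passed as, for example, `custom_callbacks=[StopAfterSteps(1000)]` to `run_customized_training_loop`. Before this change, `on_train_begin()` and `on_train_end()` were never dispatched and `model.stop_training` had no effect on the loop.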