Commit 8931d298 authored by Jeremiah Harmsen, committed by A. Unique TensorFlower
Browse files

Remove extraneous evaluation & export which happens on the final epoch.

Preserves current behavior for submodels of exporting duplicates (e.g., for 100 total steps, my_submodel_step_100.ckpt and my_submodel.ckpt are equivalent and both written after the final epoch).

PiperOrigin-RevId: 306629202
parent 47d10833
...@@ -443,31 +443,34 @@ def run_customized_training_loop( ...@@ -443,31 +443,34 @@ def run_customized_training_loop(
train_summary_writer.flush() train_summary_writer.flush()
logging.info(training_status) logging.info(training_status)
# Saves model checkpoints and run validation steps at every epoch end.
if current_step % steps_per_epoch == 0: if current_step % steps_per_epoch == 0:
# To avoid repeated model saving, we do not save after the last # Save a submodel with the step in the file name after each epoch.
# step of training. if sub_model_export_name:
_save_checkpoint(
strategy, sub_model_checkpoint, model_dir,
'%s_step_%d.ckpt' % (sub_model_export_name, current_step))
# Save model checkpoints and run validation steps after each epoch
# (with the exception of the final epoch which is handled after the
# training loop).
if current_step < total_training_steps: if current_step < total_training_steps:
_save_checkpoint(strategy, checkpoint, model_dir, _save_checkpoint(strategy, checkpoint, model_dir,
checkpoint_name.format(step=current_step)) checkpoint_name.format(step=current_step))
if sub_model_export_name: if eval_input_fn:
_save_checkpoint( logging.info('Running evaluation after step: %s.', current_step)
strategy, sub_model_checkpoint, model_dir, _run_evaluation(current_step,
'%s_step_%d.ckpt' % (sub_model_export_name, current_step)) _get_input_iterator(eval_input_fn, strategy))
if eval_input_fn: # Re-initialize evaluation metric.
logging.info('Running evaluation after step: %s.', current_step) for metric in eval_metrics + model.metrics:
_run_evaluation(current_step, metric.reset_states()
_get_input_iterator(eval_input_fn, strategy))
# Re-initialize evaluation metric.
for metric in eval_metrics + model.metrics:
metric.reset_states()
_save_checkpoint(strategy, checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if sub_model_export_name: if sub_model_export_name:
_save_checkpoint(strategy, sub_model_checkpoint, model_dir, _save_checkpoint(strategy, sub_model_checkpoint, model_dir,
'%s.ckpt' % sub_model_export_name) '%s.ckpt' % sub_model_export_name)
_save_checkpoint(strategy, checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if eval_input_fn: if eval_input_fn:
logging.info('Running final evaluation after training is complete.') logging.info('Running final evaluation after training is complete.')
_run_evaluation(current_step, _run_evaluation(current_step,
......
...@@ -156,6 +156,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -156,6 +156,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
eval_input_fn=input_fn, eval_input_fn=input_fn,
eval_steps=10, eval_steps=10,
init_checkpoint=None, init_checkpoint=None,
sub_model_export_name='my_submodel_name',
metric_fn=metric_fn, metric_fn=metric_fn,
custom_callbacks=None, custom_callbacks=None,
run_eagerly=run_eagerly) run_eagerly=run_eagerly)
...@@ -188,7 +189,20 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase): ...@@ -188,7 +189,20 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
distribution, model_dir, steps_per_loop=10, run_eagerly=False) distribution, model_dir, steps_per_loop=10, run_eagerly=False)
# Two checkpoints should be saved after two epochs. # Two checkpoints should be saved after two epochs.
self.assertNotEmpty(tf.io.gfile.glob(os.path.join(model_dir, 'ctl_step_*'))) files = map(os.path.basename,
tf.io.gfile.glob(os.path.join(model_dir, 'ctl_step_*index')))
self.assertCountEqual(['ctl_step_20.ckpt-1.index',
'ctl_step_40.ckpt-2.index'], files)
# Three submodel checkpoints should be saved after two epochs (one after
# each epoch plus one final).
files = map(os.path.basename,
tf.io.gfile.glob(os.path.join(model_dir,
'my_submodel_name*index')))
self.assertCountEqual(['my_submodel_name.ckpt-3.index',
'my_submodel_name_step_20.ckpt-1.index',
'my_submodel_name_step_40.ckpt-2.index'], files)
self.assertNotEmpty( self.assertNotEmpty(
tf.io.gfile.glob( tf.io.gfile.glob(
os.path.join(model_dir, 'summaries/training_summary*'))) os.path.join(model_dir, 'summaries/training_summary*')))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment