Commit f0141859 authored by Yeqing Li, committed by A. Unique TensorFlower
Browse files

Add condition to avoid saving duplicate checkpoints.

PiperOrigin-RevId: 277156216
parent 803b2540
......@@ -452,7 +452,6 @@ class DistributedExecutor(object):
save_freq = FLAGS.save_checkpoint_freq
else:
save_freq = iterations_per_loop
last_save_checkpoint_step = 0
params = self._params
strategy = self._strategy
......@@ -507,6 +506,7 @@ class DistributedExecutor(object):
test_step = self._create_test_step(strategy, model, metric=eval_metric)
logging.info('Training started')
last_save_checkpoint_step = current_step
while current_step < total_steps:
num_steps = _steps_to_run(current_step, total_steps, iterations_per_loop)
......@@ -569,8 +569,9 @@ class DistributedExecutor(object):
train_metric.reset_states()
# Reaches the end of training and saves the last checkpoint.
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if last_save_checkpoint_step < total_steps:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if test_step:
logging.info('Running final evaluation after training is complete.')
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment