Commit f0141859 authored by Yeqing Li, committed by A. Unique TensorFlower

Add condition to avoid saving duplicate checkpoints.

PiperOrigin-RevId: 277156216
parent 803b2540
@@ -452,7 +452,6 @@ class DistributedExecutor(object):
       save_freq = FLAGS.save_checkpoint_freq
     else:
       save_freq = iterations_per_loop
-    last_save_checkpoint_step = 0
     params = self._params
     strategy = self._strategy
@@ -507,6 +506,7 @@ class DistributedExecutor(object):
     test_step = self._create_test_step(strategy, model, metric=eval_metric)

     logging.info('Training started')
+    last_save_checkpoint_step = current_step
     while current_step < total_steps:
       num_steps = _steps_to_run(current_step, total_steps, iterations_per_loop)
@@ -569,6 +569,7 @@ class DistributedExecutor(object):
       train_metric.reset_states()

     # Reaches the end of training and saves the last checkpoint.
+    if last_save_checkpoint_step < total_steps:
       _save_checkpoint(checkpoint, model_dir,
                        checkpoint_name.format(step=current_step))
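The idea of the change: initialize last_save_checkpoint_step from current_step rather than 0, and skip the final save when a checkpoint was already written at the last step, so a run that ends exactly on a save boundary (or resumes from a restored checkpoint) does not write a duplicate. Below is a minimal, self-contained sketch of that pattern; the function and variable names are assumptions for illustration, not the executor's actual code.

import os
import tensorflow as tf

def train_and_checkpoint(model, model_dir, total_steps, iterations_per_loop, save_freq):
  """Runs a training loop and saves checkpoints without duplicating the final one."""
  checkpoint = tf.train.Checkpoint(model=model)
  prefix = os.path.join(model_dir, 'ckpt')
  current_step = 0
  # Start tracking from the current step so a freshly restored run does not
  # immediately re-save the checkpoint it was restored from.
  last_save_checkpoint_step = current_step
  while current_step < total_steps:
    num_steps = min(iterations_per_loop, total_steps - current_step)
    # ... run `num_steps` training steps on `model` here ...
    current_step += num_steps
    if current_step - last_save_checkpoint_step >= save_freq:
      checkpoint.save(prefix)
      last_save_checkpoint_step = current_step
  # Save the final checkpoint only if the loop did not already save one at
  # `total_steps`; this guard is what avoids the duplicate save.
  if last_save_checkpoint_step < total_steps:
    checkpoint.save(prefix)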