Commit 88a6cf61 authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 377381104
parent 069bdd28
......@@ -284,8 +284,11 @@ class ProgressiveTrainer(trainer_lib.Trainer):
checkpoint_interval=checkpoint_interval,
)
# Make sure we export the last checkpoint.
last_checkpoint = (
self.global_step.numpy() == self._config.trainer.train_steps)
checkpoint_path = self._export_ckpt_manager.save(
checkpoint_number=self.global_step.numpy(),
check_interval=True)
check_interval=not last_checkpoint)
if checkpoint_path:
logging.info('Checkpoints exported: %s.', checkpoint_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment