Commit 9f88ce51 authored by Le Hou, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 374238894
parent fd45760c
......@@ -51,6 +51,9 @@ class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
export_checkpoint_interval: An int. The number of steps between exporting
checkpoints. If None (by default), will use the same value as
TrainerConfig.checkpoint_interval.
export_max_to_keep: The maximum number of exported checkpoints to keep.
If None (by default), will use the same value as
TrainerConfig.max_to_keep.
export_only_final_stage_ckpt: A bool. Whether to just export checkpoints
during the final progressive training stage. In other words, whether to
not export small, partial models. In many cases, it is not meaningful to
......@@ -59,6 +62,7 @@ class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
progressive: Optional[policies.ProgressiveConfig] = None
export_checkpoint: bool = True
export_checkpoint_interval: Optional[int] = None
export_max_to_keep: Optional[int] = None
export_only_final_stage_ckpt: bool = True
......@@ -98,6 +102,7 @@ class ProgressiveTrainer(trainer_lib.Trainer):
# Directory for non-progressive checkpoint
self._export_ckpt_dir = os.path.join(ckpt_dir, 'exported_ckpts')
tf.io.gfile.makedirs(self._export_ckpt_dir)
self._export_ckpt_manager = None
# Receive other checkpoint export, e.g., best checkpoint exporter.
# TODO(lehou): unify the checkpoint exporting logic, although the default
......@@ -194,6 +199,10 @@ class ProgressiveTrainer(trainer_lib.Trainer):
# Setting `self._train_iter` to None will rebuild the dataset iterator.
self._train_iter = None
# Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
# manager used for exporting.
self._export_ckpt_manager = None
return logs
def _update_pt_stage_from_ckpt(self, ckpt_file):
......@@ -226,6 +235,10 @@ class ProgressiveTrainer(trainer_lib.Trainer):
# Setting `self._train_iter` to None will rebuild the dataset iterator.
self._train_iter = None
# Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
# manager used for exporting.
self._export_ckpt_manager = None
def _maybe_export_non_progressive_checkpoint(self, export_ckpt_dir):
"""Export checkpoints in non-progressive format.
......@@ -244,17 +257,7 @@ class ProgressiveTrainer(trainer_lib.Trainer):
logging.info('Not exporting checkpoints until the last stage.')
return
global_step_np = self.global_step.numpy()
if self.config.trainer.export_checkpoint_interval is None:
step_interval = self.config.trainer.checkpoint_interval
else:
step_interval = self.config.trainer.export_checkpoint_interval
if global_step_np % step_interval != 0 and (
global_step_np < self._config.trainer.train_steps):
logging.info('Not exporting checkpoints in global step: %d.',
global_step_np)
return
if self._export_ckpt_manager is None:
# Create a checkpoint object just now, to make sure we use
# progressive_policy.cur_model and progressive_policy.cur_optimizer of the
# current stage.
......@@ -267,7 +270,22 @@ class ProgressiveTrainer(trainer_lib.Trainer):
model=self.model,
optimizer=self.optimizer,
**checkpoint_items)
file_prefix = os.path.join(export_ckpt_dir,
'ckpt-{}'.format(global_step_np))
checkpoint.save(file_prefix=file_prefix)
logging.info('Checkpoints exported: %s.', file_prefix)
max_to_keep = self.config.trainer.export_max_to_keep or (
self.config.trainer.max_to_keep)
checkpoint_interval = self.config.trainer.export_checkpoint_interval or (
self.config.trainer.checkpoint_interval)
self._export_ckpt_manager = tf.train.CheckpointManager(
checkpoint,
directory=export_ckpt_dir,
checkpoint_name='ckpt',
step_counter=self.global_step,
max_to_keep=max_to_keep,
checkpoint_interval=checkpoint_interval,
)
checkpoint_path = self._export_ckpt_manager.save(
checkpoint_number=self.global_step.numpy(),
check_interval=True)
if checkpoint_path:
logging.info('Checkpoints exported: %s.', checkpoint_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment