Commit 9f88ce51 authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 374238894
parent fd45760c
@@ -51,6 +51,9 @@ class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
     export_checkpoint_interval: A bool. The number of steps between exporting
       checkpoints. If None (by default), will use the same value as
       TrainerConfig.checkpoint_interval.
+    export_max_to_keep: The maximum number of exported checkpoints to keep.
+      If None (by default), will use the same value as
+      TrainerConfig.max_to_keep.
     export_only_final_stage_ckpt: A bool. Whether to just export checkpoints
       during the final progressive training stage. In other words, whether to
       not export small, partial models. In many cases, it is not meaningful to
@@ -59,6 +62,7 @@ class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
   progressive: Optional[policies.ProgressiveConfig] = None
   export_checkpoint: bool = True
   export_checkpoint_interval: Optional[int] = None
+  export_max_to_keep: Optional[int] = None
   export_only_final_stage_ckpt: bool = True
@@ -98,6 +102,7 @@ class ProgressiveTrainer(trainer_lib.Trainer):
     # Directory for non-progressive checkpoint
     self._export_ckpt_dir = os.path.join(ckpt_dir, 'exported_ckpts')
     tf.io.gfile.makedirs(self._export_ckpt_dir)
+    self._export_ckpt_manager = None

     # Receive other checkpoint export, e.g, best checkpoint exporter.
     # TODO(lehou): unify the checkpoint exporting logic, although the default
@@ -194,6 +199,10 @@ class ProgressiveTrainer(trainer_lib.Trainer):
       # Setting `self._train_iter` to None will rebuild the dataset iterator.
       self._train_iter = None
+      # Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
+      # for exporting.
+      self._export_ckpt_manager = None
     return logs
   def _update_pt_stage_from_ckpt(self, ckpt_file):
@@ -226,6 +235,10 @@ class ProgressiveTrainer(trainer_lib.Trainer):
     # Setting `self._train_iter` to None will rebuild the dataset iterator.
     self._train_iter = None
+    # Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
+    # for exporting.
+    self._export_ckpt_manager = None
   def _maybe_export_non_progressive_checkpoint(self, export_ckpt_dir):
     """Export checkpoints in non-progressive format.
@@ -244,30 +257,35 @@ class ProgressiveTrainer(trainer_lib.Trainer):
       logging.info('Not exporting checkpoints until the last stage.')
       return

-    global_step_np = self.global_step.numpy()
-    if self.config.trainer.export_checkpoint_interval is None:
-      step_interval = self.config.trainer.checkpoint_interval
-    else:
-      step_interval = self.config.trainer.export_checkpoint_interval
-    if global_step_np % step_interval != 0 and (
-        global_step_np < self._config.trainer.train_steps):
-      logging.info('Not exporting checkpoints in global step: %d.',
-                   global_step_np)
-      return
-
-    # Create a checkpoint object just now, to make sure we use
-    # progressive_policy.cur_model and progressive_policy.cur_optimizer of the
-    # current stage.
-    if hasattr(self.model, 'checkpoint_items'):
-      checkpoint_items = self.model.checkpoint_items
-    else:
-      checkpoint_items = {}
-    checkpoint = tf.train.Checkpoint(
-        global_step=self.global_step,
-        model=self.model,
-        optimizer=self.optimizer,
-        **checkpoint_items)
-    file_prefix = os.path.join(export_ckpt_dir,
-                               'ckpt-{}'.format(global_step_np))
-    checkpoint.save(file_prefix=file_prefix)
-    logging.info('Checkpoints exported: %s.', file_prefix)
+    if self._export_ckpt_manager is None:
+      # Create a checkpoint object just now, to make sure we use
+      # progressive_policy.cur_model and progressive_policy.cur_optimizer of the
+      # current stage.
+      if hasattr(self.model, 'checkpoint_items'):
+        checkpoint_items = self.model.checkpoint_items
+      else:
+        checkpoint_items = {}
+      checkpoint = tf.train.Checkpoint(
+          global_step=self.global_step,
+          model=self.model,
+          optimizer=self.optimizer,
+          **checkpoint_items)
+
+      max_to_keep = self.config.trainer.export_max_to_keep or (
+          self.config.trainer.max_to_keep)
+      checkpoint_interval = self.config.trainer.export_checkpoint_interval or (
+          self.config.trainer.checkpoint_interval)
+      self._export_ckpt_manager = tf.train.CheckpointManager(
+          checkpoint,
+          directory=export_ckpt_dir,
+          checkpoint_name='ckpt',
+          step_counter=self.global_step,
+          max_to_keep=max_to_keep,
+          checkpoint_interval=checkpoint_interval,
+      )
+
+    checkpoint_path = self._export_ckpt_manager.save(
+        checkpoint_number=self.global_step.numpy(),
+        check_interval=True)
+    if checkpoint_path:
+      logging.info('Checkpoints exported: %s.', checkpoint_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment