Commit ea1353c6 authored by Le Hou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 374238894
parent 76df72b4
@@ -51,6 +51,9 @@ class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
     export_checkpoint_interval: An int. The number of steps between exporting
       checkpoints. If None (by default), will use the same value as
       TrainerConfig.checkpoint_interval.
+    export_max_to_keep: The maximum number of exported checkpoints to keep.
+      If None (by default), will use the same value as
+      TrainerConfig.max_to_keep.
     export_only_final_stage_ckpt: A bool. Whether to only export checkpoints
       during the final progressive training stage. In other words, whether to
       skip exporting small, partial models. In many cases, it is not meaningful to
@@ -59,6 +62,7 @@ class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
   progressive: Optional[policies.ProgressiveConfig] = None
   export_checkpoint: bool = True
   export_checkpoint_interval: Optional[int] = None
+  export_max_to_keep: Optional[int] = None
   export_only_final_stage_ckpt: bool = True
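
As a sketch of the fallback semantics the new docstring describes: a None export-specific value inherits the corresponding base TrainerConfig value via Python's `or`. The dataclass below is a hypothetical stand-in for illustration, not the real config classes.

from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainerConfigStub:  # hypothetical stand-in for ProgressiveTrainerConfig
  checkpoint_interval: int = 1000
  max_to_keep: int = 5
  export_checkpoint_interval: Optional[int] = None
  export_max_to_keep: Optional[int] = None

cfg = TrainerConfigStub()
# A None export-specific value falls through to the base value.
max_to_keep = cfg.export_max_to_keep or cfg.max_to_keep                    # -> 5
ckpt_interval = cfg.export_checkpoint_interval or cfg.checkpoint_interval  # -> 1000

Note that `or` also treats an explicit 0 as unset, which is harmless here since an interval or keep-count of 0 is not meaningful.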
@@ -98,6 +102,7 @@ class ProgressiveTrainer(trainer_lib.Trainer):
     # Directory for non-progressive checkpoints.
     self._export_ckpt_dir = os.path.join(ckpt_dir, 'exported_ckpts')
     tf.io.gfile.makedirs(self._export_ckpt_dir)
+    self._export_ckpt_manager = None

     # Receive other checkpoint exporters, e.g., the best checkpoint exporter.
     # TODO(lehou): unify the checkpoint exporting logic, although the default
@@ -194,6 +199,10 @@ class ProgressiveTrainer(trainer_lib.Trainer):
       # Setting `self._train_iter` to None will rebuild the dataset iterator.
       self._train_iter = None
+      # Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
+      # manager for exporting.
+      self._export_ckpt_manager = None
     return logs

   def _update_pt_stage_from_ckpt(self, ckpt_file):
@@ -226,6 +235,10 @@ class ProgressiveTrainer(trainer_lib.Trainer):
     # Setting `self._train_iter` to None will rebuild the dataset iterator.
     self._train_iter = None
+    # Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
+    # manager for exporting.
+    self._export_ckpt_manager = None

   def _maybe_export_non_progressive_checkpoint(self, export_ckpt_dir):
     """Export checkpoints in non-progressive format.
@@ -244,17 +257,7 @@ class ProgressiveTrainer(trainer_lib.Trainer):
       logging.info('Not exporting checkpoints until the last stage.')
       return

-    global_step_np = self.global_step.numpy()
-    if self.config.trainer.export_checkpoint_interval is None:
-      step_interval = self.config.trainer.checkpoint_interval
-    else:
-      step_interval = self.config.trainer.export_checkpoint_interval
-    if global_step_np % step_interval != 0 and (
-        global_step_np < self._config.trainer.train_steps):
-      logging.info('Not exporting checkpoints in global step: %d.',
-                   global_step_np)
-      return
+    if self._export_ckpt_manager is None:
       # Create a checkpoint object just now, to make sure we use
       # progressive_policy.cur_model and progressive_policy.cur_optimizer of the
       # current stage.
@@ -267,7 +270,22 @@ class ProgressiveTrainer(trainer_lib.Trainer):
           model=self.model,
           optimizer=self.optimizer,
           **checkpoint_items)
-    file_prefix = os.path.join(export_ckpt_dir,
-                               'ckpt-{}'.format(global_step_np))
-    checkpoint.save(file_prefix=file_prefix)
-    logging.info('Checkpoints exported: %s.', file_prefix)
+      max_to_keep = self.config.trainer.export_max_to_keep or (
+          self.config.trainer.max_to_keep)
+      checkpoint_interval = self.config.trainer.export_checkpoint_interval or (
+          self.config.trainer.checkpoint_interval)
+      self._export_ckpt_manager = tf.train.CheckpointManager(
+          checkpoint,
+          directory=export_ckpt_dir,
+          checkpoint_name='ckpt',
+          step_counter=self.global_step,
+          max_to_keep=max_to_keep,
+          checkpoint_interval=checkpoint_interval,
+      )
+    checkpoint_path = self._export_ckpt_manager.save(
+        checkpoint_number=self.global_step.numpy(),
+        check_interval=True)
+    if checkpoint_path:
+      logging.info('Checkpoints exported: %s.', checkpoint_path)
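
The net effect of the last two hunks: the hand-rolled `global_step % step_interval` gating is replaced by `tf.train.CheckpointManager`, which throttles saves via `checkpoint_interval` plus `save(..., check_interval=True)` and prunes old exports via `max_to_keep`. A self-contained sketch of that usage (the directory and numbers are illustrative, not taken from the trainer):

import tensorflow as tf

step = tf.Variable(0, trainable=False, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step, dense=tf.keras.layers.Dense(4))
manager = tf.train.CheckpointManager(
    ckpt,
    directory='/tmp/exported_ckpts',  # illustrative stand-in for export_ckpt_dir
    checkpoint_name='ckpt',
    step_counter=step,                # lets the manager measure elapsed steps
    max_to_keep=3,                    # older exports are deleted automatically
    checkpoint_interval=100,          # minimum steps between saves
)

for _ in range(250):
  step.assign_add(1)
  # With check_interval=True, save() returns None (a no-op) unless at least
  # checkpoint_interval steps have passed since the last successful save.
  path = manager.save(checkpoint_number=int(step.numpy()), check_interval=True)
  if path:
    print('Checkpoint exported:', path)

Compared with the removed modulo check, this also keeps the "did we just save" state inside the manager, so rebuilding it at a stage boundary (by setting it to None) naturally resets the interval bookkeeping.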