Commit bde9cdca authored by Le Hou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 356158181
parent 98839bd2
official/modeling/progressive/trainer.py
@@ -32,6 +32,7 @@ from official.core import base_task
 from official.core import base_trainer as trainer_lib
 from official.core import config_definitions
 from official.modeling.progressive import policies
+from official.modeling.progressive import utils

 ExperimentConfig = config_definitions.ExperimentConfig
@@ -61,26 +62,6 @@ class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
   export_only_final_stage_ckpt: bool = True
-class CheckpointWithHooks(tf.train.Checkpoint):
-  """Same as tf.train.Checkpoint, but supports hooks.
-
-  When running continuous_eval jobs, every time a new checkpoint arrives we
-  have to update our model, optimizer, etc. to match the stage_id of that
-  checkpoint. However, orbit does not inform us when it loads a checkpoint,
-  so we use this class to update the model to the correct stage before the
-  checkpoint is restored.
-  """
-
-  def __init__(self, before_load_hook, **kwargs):
-    self._before_load_hook = before_load_hook
-    super(CheckpointWithHooks, self).__init__(**kwargs)
-
-  # override
-  def read(self, save_path, options=None):
-    self._before_load_hook(save_path)
-    logging.info('Ran before_load_hook.')
-    return super(CheckpointWithHooks, self).read(
-        save_path=save_path, options=options)
-
 @gin.configurable
 class ProgressiveTrainer(trainer_lib.Trainer):
   """Implements the progressive trainer shared for TensorFlow models."""
@@ -124,7 +105,7 @@ class ProgressiveTrainer(trainer_lib.Trainer):
     self._global_step = orbit.utils.create_global_step()
-    self._checkpoint = CheckpointWithHooks(
+    self._checkpoint = utils.CheckpointWithHooks(
         before_load_hook=self._update_pt_stage_from_ckpt,
         global_step=self.global_step,
         **self._task.cur_checkpoint_items)
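The `before_load_hook` wired in here, `_update_pt_stage_from_ckpt`, is not part of this diff. As a minimal sketch of what such a hook can do, assuming hypothetical task helpers `get_stage_id_for_step` and `update_pt_stage` (not the repository's actual API): restore only `global_step` first, derive the stage from it, and switch the model before the full restore runs.

```python
import tensorflow as tf

def _update_pt_stage_from_ckpt(self, ckpt_file):
  """Sketch of a before_load_hook for progressive training."""
  if ckpt_file is None:
    return
  # Restore only global_step; expect_partial() silences warnings about
  # the other, not-yet-restored objects in the checkpoint file.
  ckpt = tf.train.Checkpoint(global_step=self.global_step)
  ckpt.read(ckpt_file).expect_partial()
  # Switch the model/optimizer to the stage that produced this checkpoint,
  # so that variable shapes match when the full restore happens.
  stage_id = self._task.get_stage_id_for_step(self.global_step.numpy())  # hypothetical
  self._task.update_pt_stage(stage_id)  # hypothetical
```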
official/modeling/progressive/utils.py
@@ -14,6 +14,9 @@
"""Util classes and functions."""
from absl import logging
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.training.tracking import tracking
@@ -29,3 +32,25 @@ class VolatileTrackable(tracking.AutoTrackable):
     for k, v in kwargs.items():
       delattr(self, k)  # untrack this object
       setattr(self, k, v)  # track the new object
+
+
+class CheckpointWithHooks(tf.train.Checkpoint):
+  """Same as tf.train.Checkpoint, but supports hooks.
+
+  In progressive training, use this class instead of tf.train.Checkpoint.
+  Because the network architecture changes during progressive training, we
+  need to prepare the model (e.g. switch to the correct architecture) before
+  loading a checkpoint. This class supports a hook that is executed before
+  the checkpoint is loaded.
+  """
+
+  def __init__(self, before_load_hook, **kwargs):
+    self._before_load_hook = before_load_hook
+    super(CheckpointWithHooks, self).__init__(**kwargs)
+
+  # override
+  def read(self, save_path, options=None):
+    self._before_load_hook(save_path)
+    logging.info('Ran before_load_hook.')
+    return super(CheckpointWithHooks, self).read(
+        save_path=save_path, options=options)
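A small usage sketch of the relocated class, assuming `utils` is the module shown above (`official.modeling.progressive.utils`); here the hook only logs, whereas the progressive trainer's hook switches the model to the checkpoint's stage first:

```python
import tensorflow as tf
from official.modeling.progressive import utils

model = tf.keras.layers.Dense(4)
model.build((None, 8))

def before_load(save_path):
  # In progressive training this would rebuild the model for the stage
  # recorded alongside `save_path` before any variables are restored.
  print('About to restore from:', save_path)

ckpt = utils.CheckpointWithHooks(before_load_hook=before_load, model=model)
path = ckpt.write('/tmp/progressive_demo')  # save
ckpt.read(path)  # before_load runs first, then variables are restored
```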
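For context, `VolatileTrackable` (partially visible in the hunk above) is what makes a stage switch checkpoint-safe: deleting the attribute untracks the old object, and setting it tracks the replacement. A minimal sketch, assuming the constructor accepts keyword trackables and that the reassignment method shown above is named `reassign_trackable` (both elided by the hunk, so treat the names as assumptions):

```python
import tensorflow as tf
from official.modeling.progressive import utils

# Track a stage-dependent variable under a swappable container.
holder = utils.VolatileTrackable(kernel=tf.Variable(tf.zeros([4])))
ckpt = tf.train.Checkpoint(holder=holder)

# At the next stage, swap in a larger variable; delattr/setattr makes the
# checkpoint machinery forget the old object and track the new one.
holder.reassign_trackable(kernel=tf.Variable(tf.zeros([8])))
```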