"....github/git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "ecd2f176277db4f074e25a2c3646b04b51cec119"
Commit bde9cdca authored by Le Hou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 356158181
parent 98839bd2
@@ -32,6 +32,7 @@ from official.core import base_task
 from official.core import base_trainer as trainer_lib
 from official.core import config_definitions
 from official.modeling.progressive import policies
+from official.modeling.progressive import utils

 ExperimentConfig = config_definitions.ExperimentConfig
@@ -61,26 +62,6 @@ class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
   export_only_final_stage_ckpt: bool = True


-class CheckpointWithHooks(tf.train.Checkpoint):
-  """Same as tf.train.Checkpoint but supports hooks.
-
-  When running continuous_eval jobs, when a new checkpoint arrives, we have to
-  update our model and optimizer etc. to match the stage_id of the checkpoint.
-  However, when orbit loads a checkpoint, it does not inform us. So we use this
-  class to update our model to the correct stage before checkpoint restore.
-  """
-
-  def __init__(self, before_load_hook, **kwargs):
-    self._before_load_hook = before_load_hook
-    super(CheckpointWithHooks, self).__init__(**kwargs)
-
-  # override
-  def read(self, save_path, options=None):
-    self._before_load_hook(save_path)
-    logging.info('Ran before_load_hook.')
-    super(CheckpointWithHooks, self).read(save_path=save_path, options=options)
-
-
 @gin.configurable
 class ProgressiveTrainer(trainer_lib.Trainer):
   """Implements the progressive trainer shared for TensorFlow models."""
@@ -124,7 +105,7 @@ class ProgressiveTrainer(trainer_lib.Trainer):
     self._global_step = orbit.utils.create_global_step()
-    self._checkpoint = CheckpointWithHooks(
+    self._checkpoint = utils.CheckpointWithHooks(
         before_load_hook=self._update_pt_stage_from_ckpt,
         global_step=self.global_step,
         **self._task.cur_checkpoint_items)
...
@@ -14,6 +14,9 @@
 """Util classes and functions."""

+from absl import logging
+import tensorflow as tf
+
 # pylint: disable=g-direct-tensorflow-import
 from tensorflow.python.training.tracking import tracking
@@ -29,3 +32,25 @@ class VolatileTrackable(tracking.AutoTrackable):
     for k, v in kwargs.items():
       delattr(self, k)  # untrack this object
       setattr(self, k, v)  # track the new object
+
+
+class CheckpointWithHooks(tf.train.Checkpoint):
+  """Same as tf.train.Checkpoint but supports hooks.
+
+  In progressive training, use this class instead of tf.train.Checkpoint.
+  Since the network architecture changes during progressive training, we need to
+  prepare something (like switch to the correct architecture) before loading the
+  checkpoint. This class supports a hook that will be executed before checkpoint
+  loading.
+  """
+
+  def __init__(self, before_load_hook, **kwargs):
+    self._before_load_hook = before_load_hook
+    super(CheckpointWithHooks, self).__init__(**kwargs)
+
+  # override
+  def read(self, save_path, options=None):
+    self._before_load_hook(save_path)
+    logging.info('Ran before_load_hook.')
+    super(CheckpointWithHooks, self).read(save_path=save_path, options=options)
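
For reference, a minimal usage sketch of the relocated utils.CheckpointWithHooks (not part of this commit). The hook _before_load, the toy Keras model, and the /tmp/progressive_demo path are hypothetical; in the trainer the hook slot is filled by self._update_pt_stage_from_ckpt, which moves the model to the stage recorded in an arriving checkpoint before variables are restored.

# Sketch only: exercising CheckpointWithHooks with a hypothetical hook.
# Assumes the tensorflow/models `official` package is importable.
import tensorflow as tf

from official.modeling.progressive import utils


def _before_load(save_path):
  # Hypothetical hook. In ProgressiveTrainer this is _update_pt_stage_from_ckpt,
  # which switches to the architecture of the stage stored in the checkpoint
  # before any variables are restored.
  print('About to restore from', save_path)


model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
model(tf.zeros([1, 8]))  # Build variables so there is something to save.

ckpt = utils.CheckpointWithHooks(before_load_hook=_before_load, model=model)
save_path = ckpt.save('/tmp/progressive_demo/ckpt')

# read() runs the hook first, then defers to tf.train.Checkpoint.read().
ckpt.read(save_path)

Running the hook inside read() means every restore path, including the continuous_eval loading done through orbit (which otherwise gives no notification that a new checkpoint arrived), performs the stage switch automatically.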