# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Progressive Trainer implementation.

The trainer implements the Orbit `StandardTrainable` and `StandardEvaluable`
interfaces. Trainers inside this project should be interchangeable and
independent of model architectures and tasks.
"""

import os
from typing import Any, Optional

# Import libraries
from absl import logging
import dataclasses
import gin
import orbit
import tensorflow as tf

from official.core import base_task
from official.core import base_trainer as trainer_lib
from official.core import config_definitions
from official.modeling.progressive import policies

ExperimentConfig = config_definitions.ExperimentConfig


@dataclasses.dataclass
class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
  """Configuration for the progressive trainer.

  Attributes:
    progressive: A task-specific config. Users can subclass ProgressiveConfig
      and define any task-specific settings in their subclass.
    export_checkpoint: A bool. Whether to export checkpoints in a
      non-progressive manner (without the volatiles wrapper) so that
      down-stream tasks can load checkpoints from a progressive trainer as if
      they were regular checkpoints.
    export_checkpoint_interval: An int. The number of steps between exporting
      checkpoints. If None (the default), use the same value as
      TrainerConfig.checkpoint_interval.
    export_only_final_stage_ckpt: A bool. Whether to export checkpoints only
      during the final progressive training stage, i.e. whether to skip
      exporting small, partial models. In many cases, it is not meaningful to
      finetune a small, partial model in down-stream tasks.
  """
  progressive: Optional[policies.ProgressiveConfig] = None
  export_checkpoint: bool = True
  export_checkpoint_interval: Optional[int] = None
  export_only_final_stage_ckpt: bool = True


class CheckpointWithHooks(tf.train.Checkpoint):
  """Same as tf.train.Checkpoint, but supports hooks.

  When running continuous_eval jobs, every new checkpoint requires us to update
  the model, optimizer, etc. to match the stage_id of that checkpoint. However,
  Orbit does not inform us when it loads a checkpoint, so we use this class to
  update the model to the correct stage before the checkpoint restore.
  """

  def __init__(self, before_load_hook, **kwargs):
    self._before_load_hook = before_load_hook
    super(CheckpointWithHooks, self).__init__(**kwargs)

  # override
  def read(self, save_path, options=None):
    self._before_load_hook(save_path)
    logging.info('Ran before_load_hook.')
    return super(CheckpointWithHooks, self).read(
        save_path=save_path, options=options)
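
# Example usage of `CheckpointWithHooks` (a minimal sketch, not part of the
# library; the hook, model, and checkpoint directory below are hypothetical):
#
#   def _before_load(save_path):
#     logging.info('About to restore from %s.', save_path)
#
#   ckpt = CheckpointWithHooks(
#       before_load_hook=_before_load,
#       model=tf.keras.Sequential([tf.keras.layers.Dense(1)]))
#   ckpt.read(tf.train.latest_checkpoint('/tmp/my_ckpt_dir')).expect_partial()
#
# The hook runs before any variables are restored, which is what lets
# ProgressiveTrainer (below) switch to the right stage first; see
# `_update_pt_stage_from_ckpt`.
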
@gin.configurable
class ProgressiveTrainer(trainer_lib.Trainer):
  """Implements the progressive trainer shared for TensorFlow models."""

  def __init__(
      self,
      config: ExperimentConfig,
      prog_task: base_task.Task,  # also implements ProgressivePolicy.
      ckpt_dir: str = '',
      train: bool = True,
      evaluate: bool = True,
      checkpoint_exporter: Any = None):
    """Initializes the common trainer for TensorFlow models.

    Args:
      config: An `ExperimentConfig` instance specifying the experiment config.
      prog_task: An instance that implements both policies.ProgressivePolicy
        and base_task.Task.
      ckpt_dir: Checkpoint directory.
      train: bool, whether or not this trainer will be used for training.
        Defaults to True.
      evaluate: bool, whether or not this trainer will be used for evaluation.
        Defaults to True.
      checkpoint_exporter: An object that has the `maybe_export_checkpoint`
        interface.
    """
    # Gets the current distribution strategy. If not inside any strategy scope,
    # it gets a single-replica no-op strategy.
    self._strategy = tf.distribute.get_strategy()
    self._config = config
    self._task = prog_task

    # Directory for non-progressive checkpoints.
    self._export_ckpt_dir = os.path.join(ckpt_dir, 'exported_ckpts')
    tf.io.gfile.makedirs(self._export_ckpt_dir)

    # Receives other checkpoint exporters, e.g., a best checkpoint exporter.
    # TODO(lehou): unify the checkpoint exporting logic, although the default
    # setting does not use checkpoint_exporter.
    self._checkpoint_exporter = checkpoint_exporter

    self._global_step = orbit.utils.create_global_step()

    self._checkpoint = CheckpointWithHooks(
        before_load_hook=self._update_pt_stage_from_ckpt,
        global_step=self.global_step,
        **self._task.cur_checkpoint_items)

    self._train_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
    self._validation_loss = tf.keras.metrics.Mean(
        'validation_loss', dtype=tf.float32)
    self._train_metrics = self.task.build_metrics(
        training=True) + self.model.metrics
    self._validation_metrics = self.task.build_metrics(
        training=False) + self.model.metrics

    if train:
      orbit.StandardTrainer.__init__(
          self,
          None,  # Manage train_dataset by ourselves, not by StandardTrainer.
          options=orbit.StandardTrainerOptions(
              use_tf_while_loop=config.trainer.train_tf_while_loop,
              use_tf_function=config.trainer.train_tf_function))

    if evaluate:
      orbit.StandardEvaluator.__init__(
          self,
          None,  # Manage eval_dataset by ourselves, not by StandardEvaluator.
          options=orbit.StandardEvaluatorOptions(
              use_tf_function=config.trainer.eval_tf_function))

  @property
  def model(self):
    return self._task.cur_model

  @property
  def optimizer(self):
    return self._task.cur_optimizer

  # override
  @property
  def train_dataset(self):
    """Overrides StandardTrainer.train_dataset."""
    return self._task.cur_train_dataset

  # override
  @train_dataset.setter
  def train_dataset(self, _):
    raise SyntaxError('Please do not set train_dataset. Progressive training '
                      'relies on the progressive policy to manage the train '
                      'dataset.')

  # override
  @property
  def eval_dataset(self):
    """Overrides StandardEvaluator.eval_dataset."""
    return self._task.cur_eval_dataset

  # override
  @eval_dataset.setter
  def eval_dataset(self, _):
    raise SyntaxError('Please do not set eval_dataset. Progressive training '
                      'relies on the progressive policy to manage the eval '
                      'dataset.')
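
  # Example construction (an illustrative sketch; `experiment_config` and
  # `progressive_task` are hypothetical objects supplied by the caller): the
  # trainer picks up whatever distribution strategy scope it is constructed
  # under, so build the task's model/optimizer and the trainer in the same
  # scope.
  #
  #   strategy = tf.distribute.MirroredStrategy()
  #   with strategy.scope():
  #     trainer = ProgressiveTrainer(
  #         config=experiment_config,
  #         prog_task=progressive_task,
  #         ckpt_dir='/tmp/model_dir')
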
  def train_loop_end(self):
    """See base class."""
    logs = {}
    for metric in self.train_metrics + [self.train_loss]:
      logs[metric.name] = metric.result()
      metric.reset_states()
    if callable(self.optimizer.learning_rate):
      logs['learning_rate'] = self.optimizer.learning_rate(
          self.optimizer.iterations)
    else:
      logs['learning_rate'] = self.optimizer.learning_rate

    self._maybe_export_non_progressive_checkpoint(self._export_ckpt_dir)
    if self._task.is_stage_advancing(self.global_step.numpy()):
      old_train_dataset = self.train_dataset

      # Update progressive properties.
      self._task.update_pt_stage(self.global_step.numpy())

      # Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
      # rebuild the train and eval functions with the updated model.
      self._train_loop_fn = None
      self._eval_loop_fn = None

      if self.train_dataset != old_train_dataset:
        # Setting `self._train_iter` to None will rebuild the dataset iterator.
        self._train_iter = None
    return logs

  def _update_pt_stage_from_ckpt(self, ckpt_file):
    """Updates stage properties based on global_step stored in a ckpt file.

    Before loading variables from a checkpoint file, we need to move to the
    correct stage and build the corresponding model and optimizer, to make sure
    that we restore variables of the right model and optimizer.

    Args:
      ckpt_file: Checkpoint file that will be restored/read from.
    """
    if not ckpt_file:
      return
    ckpt = tf.train.Checkpoint(global_step=self.global_step)
    ckpt.read(ckpt_file).expect_partial().assert_existing_objects_matched()

    if self._task.is_stage_advancing(self.global_step.numpy()):
      old_train_dataset = self.train_dataset

      # Update progressive properties.
      self._task.update_pt_stage(self.global_step.numpy(),
                                 pass_old_model=False)

      # Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
      # rebuild the train and eval functions with the updated model.
      self._train_loop_fn = None
      self._eval_loop_fn = None

      if self.train_dataset != old_train_dataset:
        # Setting `self._train_iter` to None will rebuild the dataset iterator.
        self._train_iter = None
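
  # Restore flow (descriptive note): when Orbit restores a checkpoint,
  # `CheckpointWithHooks.read` calls `_update_pt_stage_from_ckpt` first. That
  # hook reads only `global_step` from the checkpoint, advances to the matching
  # stage (rebuilding the model, optimizer, and train/eval functions), and only
  # then does the full variable restore run against the rebuilt objects.
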
  def _maybe_export_non_progressive_checkpoint(self, export_ckpt_dir):
    """Exports checkpoints in non-progressive format.

    This basically removes the wrapping of self._task.cur_checkpoint_items --
    it saves the model, optimizer, etc. directly. The purpose is to let
    down-stream tasks use these checkpoints.

    Args:
      export_ckpt_dir: A str. The folder of exported checkpoints.
    """
    if not self.config.trainer.export_checkpoint:
      logging.info('Not exporting checkpoints.')
      return
    if not self._task.is_last_stage and (
        self.config.trainer.export_only_final_stage_ckpt):
      logging.info('Not exporting checkpoints until the last stage.')
      return

    global_step_np = self.global_step.numpy()
    if self.config.trainer.export_checkpoint_interval is None:
      step_interval = self.config.trainer.checkpoint_interval
    else:
      step_interval = self.config.trainer.export_checkpoint_interval
    if global_step_np % step_interval != 0:
      logging.info('Not exporting checkpoints at global step: %d.',
                   global_step_np)
      return

    # Create the checkpoint object just now, to make sure we use
    # progressive_policy.cur_model and progressive_policy.cur_optimizer of the
    # current stage.
    if hasattr(self.model, 'checkpoint_items'):
      checkpoint_items = self.model.checkpoint_items
    else:
      checkpoint_items = {}
    checkpoint = tf.train.Checkpoint(
        global_step=self.global_step,
        model=self.model,
        optimizer=self.optimizer,
        **checkpoint_items)

    file_prefix = os.path.join(export_ckpt_dir,
                               'ckpt-{}'.format(global_step_np))
    checkpoint.save(file_prefix=file_prefix)
    logging.info('Checkpoints exported: %s.', file_prefix)
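
# Example of consuming an exported checkpoint (a minimal sketch; the model
# builder and directory below are hypothetical): because the export above drops
# the progressive wrapper, a down-stream task can restore it like any regular
# checkpoint.
#
#   model = build_downstream_model()  # same architecture as the final stage
#   ckpt = tf.train.Checkpoint(model=model)
#   latest = tf.train.latest_checkpoint('/tmp/model_dir/exported_ckpts')
#   ckpt.restore(latest).expect_partial()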