Commit 476f8e62 authored by Hongkun Yu and committed by A. Unique TensorFlower

Prototype loss blowup recovery in the base trainer.

When the loss is NaN, the weights will also be NaN, so the trainer restores the last good checkpoint instead of continuing.

PiperOrigin-RevId: 341886095
parent 39774bc8
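
The rationale for restoring from a checkpoint rather than simply continuing: once a NaN loss is backpropagated, the optimizer writes NaN into the affected weights, and no further gradient step can undo that. Below is a small illustrative sketch (not part of this commit; the toy variable and loss are assumptions) of how the corruption propagates:

# Toy illustration (not from this commit): a NaN loss produces NaN gradients,
# and a single optimizer step makes every affected weight NaN for good.
import tensorflow as tf

w = tf.Variable([0.5, -1.2])
opt = tf.keras.optimizers.SGD(learning_rate=0.01)
with tf.GradientTape() as tape:
  # Simulate a blowup in the forward pass, e.g. an overflow or a log of a
  # negative number somewhere upstream of the loss.
  loss = tf.reduce_sum(w * w) * tf.constant(float('nan'))
grads = tape.gradient(loss, [w])       # gradients are [nan, nan]
opt.apply_gradients(zip(grads, [w]))   # the update writes NaN into the weights
print(w.numpy())                       # -> [nan nan]; only a checkpoint restore helps
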
...@@ -20,6 +20,8 @@ The base trainer implements the Orbit `StandardTrainable` and
interchangeable and independent of model architectures and tasks.
"""
from typing import Optional
from absl import logging
import gin
import orbit
import tensorflow as tf
...@@ -57,6 +59,57 @@ def create_optimizer(trainer_config: TrainerConfig,
  return optimizer

class Recovery:
  """Built-in model blowup recovery module.

  Checks the loss value against the given threshold. If the threshold is
  exceeded (or the loss is NaN), recovers the model by reading the checkpoint
  on disk.
  """

  def __init__(self,
               loss_upper_bound: float,
               checkpoint_manager: tf.train.CheckpointManager,
               recovery_begin_steps: int = 0,
               recovery_max_trials: int = 3):
    self.recover_counter = 0
    self.recovery_begin_steps = recovery_begin_steps
    self.recovery_max_trials = recovery_max_trials
    self.loss_upper_bound = loss_upper_bound
    self.checkpoint_manager = checkpoint_manager

  def should_recover(self, loss_value, global_step):
    if tf.math.is_nan(loss_value):
      return True
    if (global_step >= self.recovery_begin_steps and
        loss_value > self.loss_upper_bound):
      return True
    return False

  def maybe_recover(self, loss_value, global_step):
    """Conditionally recovers the training by triggering checkpoint restoration.

    Args:
      loss_value: the loss value as a float.
      global_step: the number of global training steps.

    Raises:
      RuntimeError: when recovery has happened more than the maximum number of
        trials; the job should crash.
    """
    if not self.should_recover(loss_value, global_step):
      return
    self.recover_counter += 1
    if self.recover_counter > self.recovery_max_trials:
      raise RuntimeError(
          "The loss value is NaN after the training loop and it has happened "
          "%d times." % self.recover_counter)
    # Loads the previous good checkpoint.
    checkpoint_path = self.checkpoint_manager.restore_or_initialize()
    logging.warning(
        "Recovering the model from checkpoint: %s. The loss value becomes "
        "%f at step %d.", checkpoint_path, loss_value, global_step)

@gin.configurable
class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
  """Implements the common trainer shared for TensorFlow models."""
...@@ -90,8 +143,9 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
    self._config = config
    self._task = task
    self._model = model
    self._optimizer = optimizer
    self._checkpoint_exporter = checkpoint_exporter
    self._recovery = None
    # global_step increases by 1 after each training iteration.
    # We should have global_step.numpy() == self.optimizer.iterations.numpy()
...@@ -223,8 +277,22 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
    """Accesses the training checkpoint."""
    return self._checkpoint

  def add_recovery(self, params: TrainerConfig,
                   checkpoint_manager: tf.train.CheckpointManager):
    if params.recovery_max_trials >= 0:
      self._recovery = Recovery(
          loss_upper_bound=params.loss_upper_bound,
          recovery_begin_steps=params.recovery_begin_steps,
          recovery_max_trials=params.recovery_max_trials,
          checkpoint_manager=checkpoint_manager)

  def train_loop_end(self):
    """See base class."""
    # Checks whether the model's numerics are stable and conducts checkpoint
    # recovery accordingly.
    if self._recovery:
      self._recovery.maybe_recover(self.train_loss.result().numpy(),
                                   self.global_step.numpy())
    logs = {}
    for metric in self.train_metrics + [self.train_loss]:
      logs[metric.name] = metric.result()
......
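
For context, here is a minimal sketch of driving the Recovery module above by hand. The toy model, checkpoint directory, and the assumption that Recovery is importable from the trainer module are illustrative; in the trainer itself it is wired up via add_recovery and invoked from train_loop_end as shown above.

# Hypothetical standalone use of the Recovery class defined above.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build((None, 4))
ckpt = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(ckpt, '/tmp/recovery_demo', max_to_keep=2)
manager.save()  # a known-good state to fall back to

recovery = Recovery(
    loss_upper_bound=1e6,
    checkpoint_manager=manager,
    recovery_begin_steps=100,   # ignore the bound during early, noisy steps
    recovery_max_trials=3)

# After each training loop, feed in the latest loss and step. A NaN loss, or a
# loss above the bound once past step 100, restores the checkpoint; more than
# three restorations raises RuntimeError.
recovery.maybe_recover(loss_value=float('nan'), global_step=500)
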
...@@ -15,9 +15,9 @@
# ==============================================================================
"""Tests for tensorflow_models.core.trainers.trainer."""
# pylint: disable=g-direct-tensorflow-import
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
...@@ -107,15 +107,13 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
    elif mixed_precision_dtype == 'float16' and loss_scale is None:
      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
    else:
      self.assertIsInstance(trainer.optimizer,
                            tf.keras.mixed_precision.LossScaleOptimizer)

    metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', metrics)

  def test_export_best_ckpt(self):
    config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            best_checkpoint_export_subdir='best_ckpt',
...@@ -135,6 +133,58 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
    self.assertTrue(
        tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))

  def test_recovery(self):
    config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            loss_upper_bound=0.5,
            recovery_max_trials=2,
            optimizer_config=cfg.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd'
                },
                'learning_rate': {
                    'type': 'constant'
                }
            })))
    model_dir = self.get_temp_dir()
    trainer = self.create_test_trainer(config, model_dir=model_dir)
    checkpoint_manager = tf.train.CheckpointManager(
        trainer.checkpoint, self.get_temp_dir(), max_to_keep=2)
    checkpoint_manager.save()
    trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
    before_weights = trainer.model.get_weights()
    _ = trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
    # The training loss is 1.0 and the upper bound is 0.5, so recovery happens.
    after_weights = trainer.model.get_weights()
    for left, right in zip(before_weights, after_weights):
      self.assertAllEqual(left, right)

    # Let the loss be NaN and max_trials = 0 to see the RuntimeError.
    config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            recovery_max_trials=0,
            optimizer_config=cfg.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd'
                },
                'learning_rate': {
                    'type': 'constant'
                }
            })))
    task = mock_task.MockTask(config.task, logging_dir=model_dir)

    def build_losses(labels, model_outputs, aux_losses=None):
      del labels, model_outputs
      return tf.constant([np.nan], tf.float32) + aux_losses

    task.build_losses = build_losses

    trainer = trainer_lib.Trainer(
        config,
        task,
        model=task.build_model(),
        optimizer=trainer_lib.create_optimizer(config.trainer, config.runtime))
    trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
    with self.assertRaises(RuntimeError):
      _ = trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))

if __name__ == '__main__':
  tf.test.main()
...@@ -209,6 +209,13 @@ class TrainerConfig(base_config.Config):
  best_checkpoint_export_subdir: str = ""
  best_checkpoint_eval_metric: str = ""
  best_checkpoint_metric_comp: str = "higher"
  # Blowup recovery.
  loss_upper_bound: float = 1e6
  recovery_begin_steps: int = 0  # Enforce the loss bound only after these steps.
  # When max trials < 0, no recovery module is added; when max trials = 0, we
  # check the condition and fail the job if it happens; when max trials > 0,
  # we restore the model states.
  recovery_max_trials: int = 0


@dataclasses.dataclass
......
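
To make the three recovery_max_trials modes above concrete, here is a hedged sketch of the corresponding TrainerConfig values. `cfg` refers to the same config module used in the tests above, and the numeric values are purely illustrative, not defaults.

# Illustrative TrainerConfig values for the three recovery modes described in
# the comments above (values are examples only).
no_recovery = cfg.TrainerConfig(recovery_max_trials=-1)  # recovery module disabled
fail_fast = cfg.TrainerConfig(recovery_max_trials=0)     # detect a blowup and crash the job
retry = cfg.TrainerConfig(
    loss_upper_bound=100.0,      # any loss above 100 counts as a blowup
    recovery_begin_steps=1000,   # skip the bound check for the first 1000 steps
    recovery_max_trials=3)       # restore up to 3 times, then raise RuntimeError
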
...@@ -185,6 +185,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
        step_counter=trainer.global_step,
        checkpoint_interval=params.trainer.checkpoint_interval,
        init_fn=trainer.initialize)
    # Adds recovery handling.
    trainer.add_recovery(params.trainer, checkpoint_manager=checkpoint_manager)
  else:
    checkpoint_manager = None
......