Commit 4d94d3d3 authored by Hongkun Yu, committed by A. Unique TensorFlower

Prototype loss blowup recovery in the base trainer.

When the loss is NaN, the model weights will typically be NaN as well, so recovery restores the model from the last good checkpoint on disk.

PiperOrigin-RevId: 341886095
parent e9635cee
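
In outline: at the end of each training loop the trainer checks the training loss and, when it is NaN or above a configured upper bound, restores the last good checkpoint (and eventually fails the job after too many trials). Below is a minimal standalone sketch of that flow, assuming the `Recovery` class introduced in the diff; the toy model and checkpoint directory are made up for illustration.

import tensorflow as tf

# A toy model and a known-good checkpoint to fall back to (illustrative only).
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(
    checkpoint, '/tmp/recovery_demo', max_to_keep=2)
manager.save()

# `Recovery` is the class added to the base trainer in this commit.
recovery = Recovery(
    loss_upper_bound=1e6,
    checkpoint_manager=manager,
    recovery_begin_steps=0,
    recovery_max_trials=3)

# Called at the end of a training loop: a NaN (or too-large) loss triggers a
# restore of the last good checkpoint; after more than recovery_max_trials
# recoveries a RuntimeError is raised instead.
recovery.maybe_recover(loss_value=float('nan'), global_step=10)
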
@@ -20,6 +20,8 @@ The base trainer implements the Orbit `StandardTrainable` and
interchangeable and independent of model architectures and tasks.
"""
from typing import Optional
from absl import logging
import gin
import orbit
import tensorflow as tf
@@ -57,6 +59,57 @@ def create_optimizer(trainer_config: TrainerConfig,
  return optimizer

class Recovery:
  """Built-in model blowup recovery module.

  Checks the loss value against the given threshold. If applicable, recovers
  the model by reading the checkpoint on disk.
  """

  def __init__(self,
               loss_upper_bound: float,
               checkpoint_manager: tf.train.CheckpointManager,
               recovery_begin_steps: int = 0,
               recovery_max_trials: int = 3):
    self.recover_counter = 0
    self.recovery_begin_steps = recovery_begin_steps
    self.recovery_max_trials = recovery_max_trials
    self.loss_upper_bound = loss_upper_bound
    self.checkpoint_manager = checkpoint_manager

  def should_recover(self, loss_value, global_step):
    if tf.math.is_nan(loss_value):
      return True
    if (global_step >= self.recovery_begin_steps and
        loss_value > self.loss_upper_bound):
      return True
    return False
  def maybe_recover(self, loss_value, global_step):
    """Conditionally recovers the training by triggering checkpoint restoration.

    Args:
      loss_value: the loss value as a float.
      global_step: the number of global training steps.

    Raises:
      RuntimeError: when recovery has been attempted more than the allowed
        number of trials; the job should crash.
    """
    if not self.should_recover(loss_value, global_step):
      return
    self.recover_counter += 1
    if self.recover_counter > self.recovery_max_trials:
      raise RuntimeError(
          "The loss value is NaN after the training loop and this has "
          "happened %d times." % self.recover_counter)
    # Loads the previous good checkpoint.
    checkpoint_path = self.checkpoint_manager.restore_or_initialize()
    logging.warning(
        "Recovering the model from checkpoint: %s. The loss value was "
        "%f at step %d.", checkpoint_path, loss_value, global_step)

@gin.configurable
class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
"""Implements the common trainer shared for TensorFlow models."""
@@ -90,8 +143,9 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
    self._config = config
    self._task = task
    self._model = model
    self._optimizer = optimizer
    self._checkpoint_exporter = checkpoint_exporter
    self._recovery = None

    # global_step increases by 1 after each training iteration.
    # We should have global_step.numpy() == self.optimizer.iterations.numpy()
@@ -223,8 +277,22 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
    """Accesses the training checkpoint."""
    return self._checkpoint
  def add_recovery(self, params: TrainerConfig,
                   checkpoint_manager: tf.train.CheckpointManager):
    if params.recovery_max_trials >= 0:
      self._recovery = Recovery(
          loss_upper_bound=params.loss_upper_bound,
          recovery_begin_steps=params.recovery_begin_steps,
          recovery_max_trials=params.recovery_max_trials,
          checkpoint_manager=checkpoint_manager)

  def train_loop_end(self):
    """See base class."""
    # Checks if the model numeric status is stable and conducts the checkpoint
    # recovery accordingly.
    if self._recovery:
      self._recovery.maybe_recover(self.train_loss.result().numpy(),
                                   self.global_step.numpy())

    logs = {}
    for metric in self.train_metrics + [self.train_loss]:
      logs[metric.name] = metric.result()
@@ -15,9 +15,9 @@
# ==============================================================================
"""Tests for tensorflow_models.core.trainers.trainer."""
# pylint: disable=g-direct-tensorflow-import
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
@@ -107,15 +107,13 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
    elif mixed_precision_dtype == 'float16' and loss_scale is None:
      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
    else:
      self.assertIsInstance(trainer.optimizer,
                            tf.keras.mixed_precision.LossScaleOptimizer)

    metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', metrics)

  def test_export_best_ckpt(self):
    config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            best_checkpoint_export_subdir='best_ckpt',
@@ -135,6 +133,58 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
    self.assertTrue(
        tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))
  def test_recovery(self):
    config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            loss_upper_bound=0.5,
            recovery_max_trials=2,
            optimizer_config=cfg.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd'
                },
                'learning_rate': {
                    'type': 'constant'
                }
            })))
    model_dir = self.get_temp_dir()
    trainer = self.create_test_trainer(config, model_dir=model_dir)
    checkpoint_manager = tf.train.CheckpointManager(
        trainer.checkpoint, self.get_temp_dir(), max_to_keep=2)
    checkpoint_manager.save()
    trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
    before_weights = trainer.model.get_weights()
    _ = trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
    # The training loss is 1.0 and upper_bound is 0.5, so the recovery happens.
    after_weights = trainer.model.get_weights()
    for left, right in zip(before_weights, after_weights):
      self.assertAllEqual(left, right)

    # Let the loss be NaN and max_trials = 0 to see RuntimeError.
    config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            recovery_max_trials=0,
            optimizer_config=cfg.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd'
                },
                'learning_rate': {
                    'type': 'constant'
                }
            })))
    task = mock_task.MockTask(config.task, logging_dir=model_dir)

    def build_losses(labels, model_outputs, aux_losses=None):
      del labels, model_outputs
      return tf.constant([np.nan], tf.float32) + aux_losses

    task.build_losses = build_losses
    trainer = trainer_lib.Trainer(
        config,
        task,
        model=task.build_model(),
        optimizer=trainer_lib.create_optimizer(config.trainer, config.runtime))
    trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
    with self.assertRaises(RuntimeError):
      _ = trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
if __name__ == '__main__':
  tf.test.main()
@@ -209,6 +209,13 @@ class TrainerConfig(base_config.Config):
  best_checkpoint_export_subdir: str = ""
  best_checkpoint_eval_metric: str = ""
  best_checkpoint_metric_comp: str = "higher"
  # Blowup recovery.
  loss_upper_bound: float = 1e6
  recovery_begin_steps: int = 0  # Enforces the loss bound after these steps.
  # When max trials < 0, there is no recovery module; when max trials = 0, we
  # check the condition and fail the job when it happens; when max trials > 0,
  # we restore the model states.
  recovery_max_trials: int = 0


@dataclasses.dataclass
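
As a quick reference for the three regimes described in the comment above, a hypothetical config snippet (assuming the same `cfg` alias used in the tests above; the values are made up):

# Hypothetical values; the field names are the TrainerConfig additions above.
trainer_config = cfg.TrainerConfig(
    loss_upper_bound=100.0,     # Any loss above this counts as a blowup.
    recovery_begin_steps=1000,  # Only enforce the bound after 1000 steps.
    recovery_max_trials=3)      # Restore from checkpoint at most 3 times,
                                # then fail the job with a RuntimeError.
# recovery_max_trials=0 checks the condition and fails the job without
# restoring; a negative value disables the recovery module entirely.
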
@@ -185,6 +185,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
        step_counter=trainer.global_step,
        checkpoint_interval=params.trainer.checkpoint_interval,
        init_fn=trainer.initialize)
    # Adds recovery handling.
    trainer.add_recovery(params.trainer, checkpoint_manager=checkpoint_manager)
  else:
    checkpoint_manager = None