"vscode:/vscode.git/clone" did not exist on "5d93a950eeebc198cc6b8c19d404e688e0ef7b39"
Commit a89c4c07 authored by Hongkun Yu, committed by A. Unique TensorFlower

[Cleanup] Remove legacy recovery module inside the tf model garden.

PiperOrigin-RevId: 442311238
parent 339bc9bf
@@ -33,57 +33,6 @@ ExperimentConfig = config_definitions.ExperimentConfig
TrainerConfig = config_definitions.TrainerConfig
class Recovery:
  """Built-in model blowup recovery module.

  Checks the loss value by the given threshold. If applicable, recover the
  model by reading the checkpoint on disk.
  """

  def __init__(self,
               loss_upper_bound: float,
               checkpoint_manager: tf.train.CheckpointManager,
               recovery_begin_steps: int = 0,
               recovery_max_trials: int = 3):
    self.recover_counter = 0
    self.recovery_begin_steps = recovery_begin_steps
    self.recovery_max_trials = recovery_max_trials
    self.loss_upper_bound = loss_upper_bound
    self.checkpoint_manager = checkpoint_manager

  def should_recover(self, loss_value, global_step):
    if tf.math.is_nan(loss_value):
      return True
    if (global_step >= self.recovery_begin_steps and
        loss_value > self.loss_upper_bound):
      return True
    return False

  def maybe_recover(self, loss_value, global_step):
    """Conditionally recovers the training by triggering checkpoint restoration.

    Args:
      loss_value: the loss value as a float.
      global_step: the number of global training steps.

    Raises:
      RuntimeError: when recovery happens more than the max number of trials,
        the job should crash.
    """
    if not self.should_recover(loss_value, global_step):
      return
    self.recover_counter += 1
    if self.recover_counter > self.recovery_max_trials:
      raise RuntimeError(
          "The loss value is NaN or out of range after training loop and "
          f"this happens {self.recover_counter} times.")
    # Loads the previous good checkpoint.
    checkpoint_path = self.checkpoint_manager.restore_or_initialize()
    logging.warning(
        "Recovering the model from checkpoint: %s. The loss value becomes "
        "%f at step %d.", checkpoint_path, loss_value, global_step)
class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
  """Trainer class for both sync and async Strategy."""
@@ -150,30 +150,6 @@ class MockAsyncTrainer(trainer_lib._AsyncTrainer):
    return self.eval_global_step.numpy()
class RecoveryTest(tf.test.TestCase):

  def test_recovery_module(self):
    ckpt = tf.train.Checkpoint(v=tf.Variable(1, dtype=tf.int32))
    model_dir = self.get_temp_dir()
    manager = tf.train.CheckpointManager(ckpt, model_dir, max_to_keep=1)
    recovery_module = trainer_lib.Recovery(
        loss_upper_bound=1.0,
        checkpoint_manager=manager,
        recovery_begin_steps=1,
        recovery_max_trials=1)
    self.assertFalse(recovery_module.should_recover(1.1, 0))
    self.assertFalse(recovery_module.should_recover(0.1, 1))
    self.assertTrue(recovery_module.should_recover(1.1, 2))
    # First triggers the recovery once.
    recovery_module.maybe_recover(1.1, 10)
    # Second time, it raises.
    with self.assertRaisesRegex(
        RuntimeError, 'The loss value is NaN .*'):
      recovery_module.maybe_recover(1.1, 10)
class TrainerTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):