Commit d022a749 authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Improve test coverage. Add tests for Recovery module.

PiperOrigin-RevId: 409277297
parent 4b5560cd
...@@ -75,8 +75,8 @@ class Recovery: ...@@ -75,8 +75,8 @@ class Recovery:
self.recover_counter += 1 self.recover_counter += 1
if self.recover_counter > self.recovery_max_trials: if self.recover_counter > self.recovery_max_trials:
raise RuntimeError( raise RuntimeError(
"The loss value is NaN after training loop and it happens %d times." % "The loss value is NaN or out of range after training loop and "
self.recover_counter) f"this happens {self.recover_counter} times.")
# Loads the previous good checkpoint. # Loads the previous good checkpoint.
checkpoint_path = self.checkpoint_manager.restore_or_initialize() checkpoint_path = self.checkpoint_manager.restore_or_initialize()
logging.warning( logging.warning(
......
...@@ -150,6 +150,30 @@ class MockAsyncTrainer(trainer_lib._AsyncTrainer): ...@@ -150,6 +150,30 @@ class MockAsyncTrainer(trainer_lib._AsyncTrainer):
return self.eval_global_step.numpy() return self.eval_global_step.numpy()
class RecoveryTest(tf.test.TestCase):
  """Exercises `trainer_lib.Recovery` with a throwaway checkpoint manager."""

  def test_recovery_module(self):
    """Checks `should_recover` outcomes and the `maybe_recover` trial limit."""
    checkpoint = tf.train.Checkpoint(v=tf.Variable(1, dtype=tf.int32))
    ckpt_manager = tf.train.CheckpointManager(
        checkpoint, self.get_temp_dir(), max_to_keep=1)
    recovery = trainer_lib.Recovery(
        checkpoint_manager=ckpt_manager,
        loss_upper_bound=1.0,
        recovery_begin_steps=1,
        recovery_max_trials=1)

    # (loss value, step, expected `should_recover` result), checked in order.
    # NOTE(review): the expected values suggest recovery requires both a loss
    # above `loss_upper_bound` and a step past `recovery_begin_steps` --
    # confirm against `Recovery.should_recover`.
    expectations = (
        (1.1, 0, False),
        (0.1, 1, False),
        (1.1, 2, True),
    )
    for loss_value, step, expected in expectations:
      check = self.assertTrue if expected else self.assertFalse
      check(recovery.should_recover(loss_value, step))

    # With `recovery_max_trials=1`, the first recovery attempt is allowed.
    recovery.maybe_recover(1.1, 10)
    # The second attempt pushes the counter past the limit and must raise.
    with self.assertRaisesRegex(RuntimeError, 'The loss value is NaN .*'):
      recovery.maybe_recover(1.1, 10)
class TrainerTest(tf.test.TestCase, parameterized.TestCase): class TrainerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self): def setUp(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment