Commit 08189186 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 389019529
parent d088f0d5
......@@ -16,6 +16,7 @@
import os
from typing import List
from absl import logging
import gin
import orbit
......@@ -119,6 +120,58 @@ class EMACheckpointing:
self._optimizer.swap_weights()
class RecoveryAction:
"""Train action to recover from loss blowup.
Checks the loss value by the given threshold. If applicable, recover the
model by reading the checkpoint on disk.
"""
def __init__(self, checkpoint_manager: tf.train.CheckpointManager):
self.checkpoint_manager = checkpoint_manager
def __call__(self, _):
"""Recovers the training by triggering checkpoint restoration."""
# Loads the previous good checkpoint.
checkpoint_path = self.checkpoint_manager.restore_or_initialize()
logging.warning('Recovering the model from checkpoint: %s.',
checkpoint_path)
class RecoveryCondition:
"""Recovery Condition."""
def __init__(self,
global_step: tf.Variable,
loss_upper_bound: float,
recovery_begin_steps: int = 0,
recovery_max_trials: int = 3):
self.recover_counter = 0
self.recovery_begin_steps = recovery_begin_steps
self.recovery_max_trials = recovery_max_trials
self.loss_upper_bound = loss_upper_bound
self.global_step = global_step
  def __call__(self, outputs: orbit.runner.Output):
    """Returns True if recovery is needed; raises RuntimeError after too many trials."""
    loss_value = outputs['training_loss']
    if tf.math.is_nan(loss_value):
      self.recover_counter += 1
      if self.recover_counter > self.recovery_max_trials:
        raise RuntimeError(
            'The loss value is NaN after the training loop and this has '
            'happened %d times.' % self.recover_counter)
      return True
    if (self.global_step >= self.recovery_begin_steps and
        loss_value > self.loss_upper_bound):
      self.recover_counter += 1
      if self.recover_counter > self.recovery_max_trials:
        raise RuntimeError(
            f'The loss value is {loss_value}, which is larger than the upper '
            f'bound {self.loss_upper_bound}. This has happened '
            f'{self.recover_counter} times.')
      return True
    return False
@gin.configurable
def get_eval_actions(
params: config_definitions.ExperimentConfig,
......@@ -140,9 +193,10 @@ def get_eval_actions(
@gin.configurable
def get_train_actions(params: config_definitions.ExperimentConfig,
trainer: base_trainer.Trainer,
model_dir: str) -> List[orbit.Action]:
def get_train_actions(
params: config_definitions.ExperimentConfig, trainer: base_trainer.Trainer,
model_dir: str,
checkpoint_manager: tf.train.CheckpointManager) -> List[orbit.Action]:
"""Gets train actions for TFM trainer."""
train_actions = []
# Adds pruning callback actions.
......@@ -153,4 +207,16 @@ def get_train_actions(params: config_definitions.ExperimentConfig,
model=trainer.model,
optimizer=trainer.optimizer))
if params.trainer.recovery_max_trials >= 0:
recovery_condition = RecoveryCondition(
global_step=trainer.global_step,
loss_upper_bound=params.trainer.loss_upper_bound,
recovery_begin_steps=params.trainer.recovery_begin_steps,
recovery_max_trials=params.trainer.recovery_max_trials,
)
recover_action = orbit.actions.ConditionalAction(
condition=recovery_condition,
action=RecoveryAction(checkpoint_manager),
)
train_actions.append(recover_action)
return train_actions
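For reference, a minimal usage sketch of the two classes above outside of get_train_actions. This is illustrative only: the checkpoint directory and threshold values are made up for the example, and orbit and tensorflow (as tf) are assumed to be imported as in this module.

# Illustrative sketch, not part of this change.
global_step = orbit.utils.create_global_step()
checkpoint = tf.train.Checkpoint(step=global_step)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, '/tmp/recovery_ckpt', max_to_keep=1)
checkpoint_manager.save()

condition = RecoveryCondition(
    global_step, loss_upper_bound=10.0, recovery_max_trials=2)
recover = RecoveryAction(checkpoint_manager)

outputs = {'training_loss': tf.constant(100.0)}
if condition(outputs):  # True: the loss exceeds loss_upper_bound.
  recover(outputs)      # Restores the latest saved checkpoint.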
......@@ -17,6 +17,8 @@
import os
from absl.testing import parameterized
import numpy as np
import orbit
import tensorflow as tf
from tensorflow.python.distribute import combinations
......@@ -35,17 +37,14 @@ class TestModel(tf.Module):
return self.value
def all_strategy_combinations():
return combinations.combine(
class ActionsTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],)
class ActionsTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(all_strategy_combinations())
],))
def test_ema_checkpointing(self, distribution):
with distribution.scope():
directory = self.create_tempdir()
......@@ -76,6 +75,33 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
# Checks model.value is 0 after swapping.
self.assertEqual(model(), 0)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],))
def test_recovery_condition(self, distribution):
with distribution.scope():
global_step = orbit.utils.create_global_step()
recover_condition = actions.RecoveryCondition(
global_step, loss_upper_bound=0.5, recovery_max_trials=2)
outputs = {'training_loss': 0.6}
self.assertTrue(recover_condition(outputs))
self.assertTrue(recover_condition(outputs))
with self.assertRaises(RuntimeError):
recover_condition(outputs)
global_step = orbit.utils.create_global_step()
recover_condition = actions.RecoveryCondition(
global_step, loss_upper_bound=0.5, recovery_max_trials=2)
outputs = {'training_loss': tf.constant([np.nan], tf.float32)}
self.assertTrue(recover_condition(outputs))
self.assertTrue(recover_condition(outputs))
with self.assertRaises(RuntimeError):
recover_condition(outputs)
if __name__ == '__main__':
tf.test.main()
......@@ -370,6 +370,7 @@ class Trainer(_AsyncTrainer):
"""Accesses the training checkpoint."""
return self._checkpoint
# TODO(yejiayu): Remove this once all deps are fixed.
def add_recovery(self, params: TrainerConfig,
checkpoint_manager: tf.train.CheckpointManager):
if params.recovery_max_trials >= 0:
......@@ -382,11 +383,6 @@ class Trainer(_AsyncTrainer):
def train_loop_end(self):
"""See base class."""
self.join()
# Checks if the model numeric status is stable and conducts the checkpoint
# recovery accordingly.
if self._recovery:
self._recovery.maybe_recover(self.train_loss.result().numpy(),
self.global_step.numpy())
logs = {}
for metric in self.train_metrics + [self.train_loss]:
logs[metric.name] = metric.result()
......
......@@ -19,7 +19,6 @@ import os
import sys
from absl.testing import parameterized
import numpy as np
import orbit
import portpicker
import tensorflow as tf
......@@ -337,61 +336,6 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
self.assertTrue(
tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))
def test_recovery(self):
config = cfg.ExperimentConfig(
trainer=cfg.TrainerConfig(
loss_upper_bound=0.5,
recovery_max_trials=2,
optimizer_config=cfg.OptimizationConfig({
'optimizer': {
'type': 'sgd'
},
'learning_rate': {
'type': 'constant'
}
})))
model_dir = self.get_temp_dir()
trainer = self.create_test_trainer(config, model_dir=model_dir)
checkpoint_manager = tf.train.CheckpointManager(
trainer.checkpoint, self.get_temp_dir(), max_to_keep=2)
checkpoint_manager.save()
trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
before_weights = trainer.model.get_weights()
_ = trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
# The training loss is 1.0 and upper_bound is 0.5, so the recover happens.
after_weights = trainer.model.get_weights()
for left, right in zip(before_weights, after_weights):
self.assertAllEqual(left, right)
# Let's the loss be NaN and max_trials = 0 to see RuntimeError.
config = cfg.ExperimentConfig(
trainer=cfg.TrainerConfig(
recovery_max_trials=0,
optimizer_config=cfg.OptimizationConfig({
'optimizer': {
'type': 'sgd'
},
'learning_rate': {
'type': 'constant'
}
})))
task = mock_task.MockTask(config.task, logging_dir=model_dir)
def build_losses(labels, model_outputs, aux_losses=None):
del labels, model_outputs
return tf.constant([np.nan], tf.float32) + aux_losses
task.build_losses = build_losses
trainer = trainer_lib.Trainer(
config,
task,
model=task.build_model(),
optimizer=task.create_optimizer(config.trainer.optimizer_config,
config.runtime))
trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
with self.assertRaises(RuntimeError):
_ = trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
def test_model_with_compiled_loss(self):
task = mock_task.MockTask()
model = task.build_model()
......
......@@ -87,8 +87,6 @@ def run_experiment(
step_counter=trainer.global_step,
checkpoint_interval=params.trainer.checkpoint_interval,
init_fn=trainer.initialize)
# Adds recovery handling.
trainer.add_recovery(params.trainer, checkpoint_manager=checkpoint_manager)
else:
checkpoint_manager = None
......@@ -105,7 +103,8 @@ def run_experiment(
(save_summary) else None,
summary_interval=params.trainer.summary_interval if
(save_summary) else None,
train_actions=actions.get_train_actions(params, trainer, model_dir),
train_actions=actions.get_train_actions(
params, trainer, model_dir, checkpoint_manager=checkpoint_manager),
eval_actions=actions.get_eval_actions(params, trainer, model_dir))
logging.info('Starts to execute mode: %s', mode)
......
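For downstream code that constructs an orbit.Controller directly instead of going through run_experiment, a minimal sketch of the updated wiring. It assumes existing distribution_strategy, trainer, params, model_dir, and checkpoint_manager objects; it mirrors the call above and is not part of the change itself.

# Illustrative sketch only.
controller = orbit.Controller(
    strategy=distribution_strategy,
    trainer=trainer,
    global_step=trainer.global_step,
    steps_per_loop=params.trainer.steps_per_loop,
    checkpoint_manager=checkpoint_manager,
    train_actions=actions.get_train_actions(
        params, trainer, model_dir, checkpoint_manager=checkpoint_manager),
    eval_actions=actions.get_eval_actions(params, trainer, model_dir))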
......@@ -19,6 +19,7 @@ import os
from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
......@@ -30,6 +31,7 @@ from official.common import registry_imports
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.utils.testing import mock_task
FLAGS = flags.FLAGS
......@@ -114,7 +116,93 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
params=params,
model_dir=model_dir,
run_post_eval=run_post_eval)
print(logs)
@combinations.generate(
combinations.combine(
distribution_strategy=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
flag_mode=['train', 'train_and_eval'],
))
def test_recovery_nan_error(self, distribution_strategy, flag_mode):
model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode=flag_mode,
model_dir=model_dir,
params_override=json.dumps(self._test_config))
with flagsaver.flagsaver(**flags_dict):
params = train_utils.parse_configuration(flags.FLAGS)
train_utils.serialize_config(params, model_dir)
with distribution_strategy.scope():
        task = mock_task.MockTask(params.task, logging_dir=model_dir)

        # Set the loss to NaN to trigger a RuntimeError.
def build_losses(labels, model_outputs, aux_losses=None):
del labels, model_outputs
return tf.constant([np.nan], tf.float32) + aux_losses
task.build_losses = build_losses
with self.assertRaises(RuntimeError):
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=params,
model_dir=model_dir)
@combinations.generate(
combinations.combine(
distribution_strategy=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
flag_mode=['train'],
))
def test_recovery(self, distribution_strategy, flag_mode):
loss_threshold = 1.0
model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode=flag_mode,
model_dir=model_dir,
params_override=json.dumps(self._test_config))
with flagsaver.flagsaver(**flags_dict):
params = train_utils.parse_configuration(flags.FLAGS)
params.trainer.loss_upper_bound = loss_threshold
params.trainer.recovery_max_trials = 1
train_utils.serialize_config(params, model_dir)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
# Saves a checkpoint for reference.
model = task.build_model()
checkpoint = tf.train.Checkpoint(model=model)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, self.get_temp_dir(), max_to_keep=2)
checkpoint_manager.save()
before_weights = model.get_weights()
def build_losses(labels, model_outputs, aux_losses=None):
del labels, model_outputs
return tf.constant([loss_threshold], tf.float32) + aux_losses
task.build_losses = build_losses
model, _ = train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=params,
model_dir=model_dir)
after_weights = model.get_weights()
for left, right in zip(before_weights, after_weights):
self.assertAllEqual(left, right)
def test_parse_configuration(self):
model_dir = self.get_temp_dir()
......
......@@ -43,11 +43,6 @@ class RankingTrainer(base_trainer.Trainer):
def train_loop_end(self) -> Dict[str, float]:
"""See base class."""
self.join()
# Checks if the model numeric status is stable and conducts the checkpoint
# recovery accordingly.
if self._recovery:
self._recovery.maybe_recover(self.train_loss.result().numpy(),
self.global_step.numpy())
logs = {}
for metric in self.train_metrics + [self.train_loss]:
logs[metric.name] = metric.result()
......