Commit 08189186 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 389019529
parent d088f0d5
...@@ -16,6 +16,7 @@
import os
from typing import List
from absl import logging
import gin
import orbit
...@@ -119,6 +120,58 @@ class EMACheckpointing:
    self._optimizer.swap_weights()
class RecoveryAction:
  """Train action to recover from loss blowup.

  Checks the loss value by the given threshold. If applicable, recover the
  model by reading the checkpoint on disk.
  """

  def __init__(self, checkpoint_manager: tf.train.CheckpointManager):
    self.checkpoint_manager = checkpoint_manager

  def __call__(self, _):
    """Recovers the training by triggering checkpoint restoration."""
    # Loads the previous good checkpoint.
    checkpoint_path = self.checkpoint_manager.restore_or_initialize()
    logging.warning('Recovering the model from checkpoint: %s.',
                    checkpoint_path)
class RecoveryCondition:
  """Recovery Condition."""

  def __init__(self,
               global_step: tf.Variable,
               loss_upper_bound: float,
               recovery_begin_steps: int = 0,
               recovery_max_trials: int = 3):
    self.recover_counter = 0
    self.recovery_begin_steps = recovery_begin_steps
    self.recovery_max_trials = recovery_max_trials
    self.loss_upper_bound = loss_upper_bound
    self.global_step = global_step

  def __call__(self, outputs: orbit.runner.Output):
    loss_value = outputs['training_loss']
    if tf.math.is_nan(loss_value):
      self.recover_counter += 1
      if self.recover_counter > self.recovery_max_trials:
        raise RuntimeError(
            'The loss value is NaN after training loop and it happens %d times.'
            % self.recover_counter)
      return True
    if (self.global_step >= self.recovery_begin_steps and
        loss_value > self.loss_upper_bound):
      self.recover_counter += 1
      if self.recover_counter > self.recovery_max_trials:
        raise RuntimeError(
            f'The loss value is {loss_value}, which is larger than the bound {self.loss_upper_bound}, happens {self.recover_counter} times.'
        )
      return True
    return False
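Note: the two classes above are meant to be composed into a single Orbit train action, which get_train_actions does further down in this diff via orbit.actions.ConditionalAction. A minimal sketch of that wiring, with purely illustrative values (the '/tmp/recovery_ckpt' directory and the numeric bounds are placeholders, not defaults):

# Sketch only: composing RecoveryCondition and RecoveryAction by hand.
import orbit
import tensorflow as tf

global_step = orbit.utils.create_global_step()
checkpoint = tf.train.Checkpoint(step=global_step)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, directory='/tmp/recovery_ckpt', max_to_keep=2)  # placeholder dir

recovery_condition = RecoveryCondition(
    global_step=global_step,
    loss_upper_bound=1.0,  # illustrative bound
    recovery_begin_steps=0,
    recovery_max_trials=3)
recover_action = orbit.actions.ConditionalAction(
    condition=recovery_condition,
    action=RecoveryAction(checkpoint_manager))

# Orbit calls the action with the train-loop outputs; the restore only runs
# when the loss is NaN or exceeds loss_upper_bound.
recover_action({'training_loss': tf.constant(0.3)})  # below the bound: no-op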
@gin.configurable
def get_eval_actions(
    params: config_definitions.ExperimentConfig,
...@@ -140,9 +193,10 @@ def get_eval_actions(
@gin.configurable
def get_train_actions(params: config_definitions.ExperimentConfig,
                      trainer: base_trainer.Trainer,
                      model_dir: str) -> List[orbit.Action]:
def get_train_actions(
    params: config_definitions.ExperimentConfig, trainer: base_trainer.Trainer,
    model_dir: str,
    checkpoint_manager: tf.train.CheckpointManager) -> List[orbit.Action]:
  """Gets train actions for TFM trainer."""
  train_actions = []
  # Adds pruning callback actions.
...@@ -153,4 +207,16 @@ def get_train_actions(params: config_definitions.ExperimentConfig,
            model=trainer.model,
            optimizer=trainer.optimizer))
  if params.trainer.recovery_max_trials >= 0:
    recovery_condition = RecoveryCondition(
        global_step=trainer.global_step,
        loss_upper_bound=params.trainer.loss_upper_bound,
        recovery_begin_steps=params.trainer.recovery_begin_steps,
        recovery_max_trials=params.trainer.recovery_max_trials,
    )
    recover_action = orbit.actions.ConditionalAction(
        condition=recovery_condition,
        action=RecoveryAction(checkpoint_manager),
    )
    train_actions.append(recover_action)
  return train_actions
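The recovery action is opt-in and is driven by the three trainer-config fields referenced above: recovery_max_trials (a non-negative value enables the action), loss_upper_bound, and recovery_begin_steps. A hypothetical override that enables it could look like the following; the numbers are illustrative, not recommended defaults.

# Sketch only: trainer-config fields that control the recovery action.
params_override = {
    'trainer': {
        'recovery_max_trials': 3,      # >= 0 enables the recovery action
        'loss_upper_bound': 100.0,     # recover once training_loss exceeds this
        'recovery_begin_steps': 1000,  # loss-bound check applies only after this step
    }
}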
...@@ -17,6 +17,8 @@
import os
from absl.testing import parameterized
import numpy as np
import orbit
import tensorflow as tf
from tensorflow.python.distribute import combinations
...@@ -35,17 +37,14 @@ class TestModel(tf.Module):
    return self.value


def all_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
      ],)


class ActionsTest(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(all_strategy_combinations())
  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],))
  def test_ema_checkpointing(self, distribution):
    with distribution.scope():
      directory = self.create_tempdir()
...@@ -76,6 +75,33 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
      # Checks model.value is 0 after swapping.
      self.assertEqual(model(), 0)
  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],))
  def test_recovery_condition(self, distribution):
    with distribution.scope():
      global_step = orbit.utils.create_global_step()
      recover_condition = actions.RecoveryCondition(
          global_step, loss_upper_bound=0.5, recovery_max_trials=2)
      outputs = {'training_loss': 0.6}
      self.assertTrue(recover_condition(outputs))
      self.assertTrue(recover_condition(outputs))
      with self.assertRaises(RuntimeError):
        recover_condition(outputs)

      global_step = orbit.utils.create_global_step()
      recover_condition = actions.RecoveryCondition(
          global_step, loss_upper_bound=0.5, recovery_max_trials=2)
      outputs = {'training_loss': tf.constant([np.nan], tf.float32)}
      self.assertTrue(recover_condition(outputs))
      self.assertTrue(recover_condition(outputs))
      with self.assertRaises(RuntimeError):
        recover_condition(outputs)
if __name__ == '__main__':
  tf.test.main()
...@@ -370,6 +370,7 @@ class Trainer(_AsyncTrainer):
    """Accesses the training checkpoint."""
    return self._checkpoint

  # TODO(yejiayu): Remove this once all deps are fixed.
  def add_recovery(self, params: TrainerConfig,
                   checkpoint_manager: tf.train.CheckpointManager):
    if params.recovery_max_trials >= 0:
...@@ -382,11 +383,6 @@
  def train_loop_end(self):
    """See base class."""
    self.join()
    # Checks if the model numeric status is stable and conducts the checkpoint
    # recovery accordingly.
    if self._recovery:
      self._recovery.maybe_recover(self.train_loss.result().numpy(),
                                   self.global_step.numpy())
    logs = {}
    for metric in self.train_metrics + [self.train_loss]:
      logs[metric.name] = metric.result()
...
...@@ -19,7 +19,6 @@ import os
import sys
from absl.testing import parameterized
import numpy as np
import orbit
import portpicker
import tensorflow as tf
...@@ -337,61 +336,6 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
    self.assertTrue(
        tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))
  def test_recovery(self):
    config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            loss_upper_bound=0.5,
            recovery_max_trials=2,
            optimizer_config=cfg.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd'
                },
                'learning_rate': {
                    'type': 'constant'
                }
            })))
    model_dir = self.get_temp_dir()
    trainer = self.create_test_trainer(config, model_dir=model_dir)
    checkpoint_manager = tf.train.CheckpointManager(
        trainer.checkpoint, self.get_temp_dir(), max_to_keep=2)
    checkpoint_manager.save()
    trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
    before_weights = trainer.model.get_weights()
    _ = trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
    # The training loss is 1.0 and upper_bound is 0.5, so the recover happens.
    after_weights = trainer.model.get_weights()
    for left, right in zip(before_weights, after_weights):
      self.assertAllEqual(left, right)

    # Let's the loss be NaN and max_trials = 0 to see RuntimeError.
    config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            recovery_max_trials=0,
            optimizer_config=cfg.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd'
                },
                'learning_rate': {
                    'type': 'constant'
                }
            })))
    task = mock_task.MockTask(config.task, logging_dir=model_dir)

    def build_losses(labels, model_outputs, aux_losses=None):
      del labels, model_outputs
      return tf.constant([np.nan], tf.float32) + aux_losses

    task.build_losses = build_losses

    trainer = trainer_lib.Trainer(
        config,
        task,
        model=task.build_model(),
        optimizer=task.create_optimizer(config.trainer.optimizer_config,
                                        config.runtime))
    trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
    with self.assertRaises(RuntimeError):
      _ = trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
  def test_model_with_compiled_loss(self):
    task = mock_task.MockTask()
    model = task.build_model()
...
...@@ -87,8 +87,6 @@ def run_experiment(
        step_counter=trainer.global_step,
        checkpoint_interval=params.trainer.checkpoint_interval,
        init_fn=trainer.initialize)
    # Adds recovery handling.
    trainer.add_recovery(params.trainer, checkpoint_manager=checkpoint_manager)
  else:
    checkpoint_manager = None
...@@ -105,7 +103,8 @@ def run_experiment(
      (save_summary) else None,
      summary_interval=params.trainer.summary_interval if
      (save_summary) else None,
      train_actions=actions.get_train_actions(params, trainer, model_dir),
      train_actions=actions.get_train_actions(
          params, trainer, model_dir, checkpoint_manager=checkpoint_manager),
      eval_actions=actions.get_eval_actions(params, trainer, model_dir))

  logging.info('Starts to execute mode: %s', mode)
...
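With this change, recovery no longer happens inside Trainer.train_loop_end (those lines are removed above and in the ranking trainer below); it runs as one of the Orbit train actions instead. Conceptually, the controller invokes every train action with the outputs of each train loop, roughly as in this simplified sketch (not the actual orbit.Controller code):

# Simplified sketch of how train actions see the loop outputs.
outputs = trainer.train(num_steps)  # e.g. {'training_loss': <tensor>}
for action in train_actions:
  action(outputs)  # RecoveryCondition inspects 'training_loss' here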
...@@ -19,6 +19,7 @@ import os
from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
...@@ -30,6 +31,7 @@ from official.common import registry_imports
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.utils.testing import mock_task

FLAGS = flags.FLAGS
...@@ -114,7 +116,93 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
          params=params,
          model_dir=model_dir,
          run_post_eval=run_post_eval)
      print(logs)
  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train', 'train_and_eval'],
      ))
  def test_recovery_nan_error(self, distribution_strategy, flag_mode):
    model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
        mode=flag_mode,
        model_dir=model_dir,
        params_override=json.dumps(self._test_config))
    with flagsaver.flagsaver(**flags_dict):
      params = train_utils.parse_configuration(flags.FLAGS)
      train_utils.serialize_config(params, model_dir)
      with distribution_strategy.scope():
        # task = task_factory.get_task(params.task, logging_dir=model_dir)
        task = mock_task.MockTask(params.task, logging_dir=model_dir)

        # Set the loss to NaN to trigger RunTimeError.
        def build_losses(labels, model_outputs, aux_losses=None):
          del labels, model_outputs
          return tf.constant([np.nan], tf.float32) + aux_losses

        task.build_losses = build_losses

      with self.assertRaises(RuntimeError):
        train_lib.run_experiment(
            distribution_strategy=distribution_strategy,
            task=task,
            mode=flag_mode,
            params=params,
            model_dir=model_dir)
  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train'],
      ))
  def test_recovery(self, distribution_strategy, flag_mode):
    loss_threshold = 1.0
    model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
        mode=flag_mode,
        model_dir=model_dir,
        params_override=json.dumps(self._test_config))
    with flagsaver.flagsaver(**flags_dict):
      params = train_utils.parse_configuration(flags.FLAGS)
      params.trainer.loss_upper_bound = loss_threshold
      params.trainer.recovery_max_trials = 1
      train_utils.serialize_config(params, model_dir)
      with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=model_dir)

      # Saves a checkpoint for reference.
      model = task.build_model()
      checkpoint = tf.train.Checkpoint(model=model)
      checkpoint_manager = tf.train.CheckpointManager(
          checkpoint, self.get_temp_dir(), max_to_keep=2)
      checkpoint_manager.save()
      before_weights = model.get_weights()

      def build_losses(labels, model_outputs, aux_losses=None):
        del labels, model_outputs
        return tf.constant([loss_threshold], tf.float32) + aux_losses

      task.build_losses = build_losses
      model, _ = train_lib.run_experiment(
          distribution_strategy=distribution_strategy,
          task=task,
          mode=flag_mode,
          params=params,
          model_dir=model_dir)

      after_weights = model.get_weights()
      for left, right in zip(before_weights, after_weights):
        self.assertAllEqual(left, right)
  def test_parse_configuration(self):
    model_dir = self.get_temp_dir()
...
...@@ -43,11 +43,6 @@ class RankingTrainer(base_trainer.Trainer):
  def train_loop_end(self) -> Dict[str, float]:
    """See base class."""
    self.join()
    # Checks if the model numeric status is stable and conducts the checkpoint
    # recovery accordingly.
    if self._recovery:
      self._recovery.maybe_recover(self.train_loss.result().numpy(),
                                   self.global_step.numpy())
    logs = {}
    for metric in self.train_metrics + [self.train_loss]:
      logs[metric.name] = metric.result()
...