Commit 6dbdb08c authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Refactors the run_experiment function for better reusability.

PiperOrigin-RevId: 458550388
parent f1add1bc
...@@ -32,7 +32,29 @@ from official.core import train_utils ...@@ -32,7 +32,29 @@ from official.core import train_utils
maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
class OrbitExperimentRunner:
  """Runs experiment with Orbit training loop.

  The default experiment runner for model garden experiments. User can
  customize the experiment pipeline by subclassing this class and replacing
  components or functions.

  For example, an experiment runner with customized checkpoint manager:

  ```python
  class MyExpRunnerWithExporter(OrbitExperimentRunner):

    def _maybe_build_checkpoint_manager(self):
      return MyCheckpointManager(*args)

  # In user code
  MyExpRunnerWithExporter(**needed_kwargs).run(mode)
  ```

  Similar override can be done to other components.
  """

  def __init__(
      self,
      distribution_strategy: tf.distribute.Strategy,
      task: base_task.Task,
      mode: str,
      params: config_definitions.ExperimentConfig,
      model_dir: str,
      run_post_eval: bool = False,
      save_summary: bool = True,
      train_actions: Optional[List[orbit.Action]] = None,
      eval_actions: Optional[List[orbit.Action]] = None,
      trainer: Optional[base_trainer.Trainer] = None,
      controller_cls=orbit.Controller
  ):
    """Constructor.

    Args:
      distribution_strategy: A distribution strategy.
      task: A Task instance.
      mode: A 'str', specifying the mode. Can be 'train', 'eval',
        'train_and_eval' or 'continuous_eval'.
      params: ExperimentConfig instance.
      model_dir: A 'str', a path to store model checkpoints and summaries.
      run_post_eval: Whether to run post eval once after training, metrics logs
        are returned.
      save_summary: Whether to save train and validation summary.
      train_actions: Optional list of Orbit train actions.
      eval_actions: Optional list of Orbit eval actions.
      trainer: the base_trainer.Trainer instance. It should be created within
        the strategy.scope().
      controller_cls: The controller class to manage the train and eval
        process. Must be an orbit.Controller subclass.
    """
    # Fall back to the ambient strategy so callers may pass None.
    self.strategy = distribution_strategy or tf.distribute.get_strategy()
    self._params = params
    self._model_dir = model_dir
    self._mode = mode
    self._run_post_eval = run_post_eval

    # Build a trainer unless the caller supplied one; evaluation capability is
    # also needed when a post-training eval was requested.
    self._trainer = trainer or self._build_trainer(
        task,
        train='train' in mode,
        evaluate=('eval' in mode) or run_post_eval)
    assert self.trainer is not None
    self._checkpoint_manager = self._maybe_build_checkpoint_manager()
    # The controller only receives a trainer when the mode includes training;
    # the trainer object doubles as the evaluator.
    self._controller = self._build_controller(
        trainer=self.trainer if 'train' in mode else None,
        evaluator=self.trainer,
        save_summary=save_summary,
        train_actions=train_actions,
        eval_actions=eval_actions,
        controller_cls=controller_cls)

  @property
  def params(self) -> config_definitions.ExperimentConfig:
    """The experiment config."""
    return self._params

  @property
  def model_dir(self) -> str:
    """Path where checkpoints and summaries are stored."""
    return self._model_dir

  @property
  def trainer(self) -> base_trainer.Trainer:
    """The underlying trainer (also used as evaluator)."""
    return self._trainer

  @property
  def checkpoint_manager(self) -> tf.train.CheckpointManager:
    """The checkpoint manager, or None when the trainer has no checkpoint."""
    return self._checkpoint_manager

  @property
  def controller(self) -> orbit.Controller:
    """The Orbit controller driving train/eval loops."""
    return self._controller

  def _build_trainer(self, task: base_task.Task, train: bool,
                     evaluate: bool) -> base_trainer.Trainer:
    """Create trainer under the distribution strategy scope."""
    with self.strategy.scope():
      trainer = train_utils.create_trainer(
          self.params,
          task,
          train=train,
          evaluate=evaluate,
          checkpoint_exporter=self._build_best_checkpoint_exporter())
    return trainer

  def _build_best_checkpoint_exporter(self):
    """Maybe create a best-checkpoint exporter from the experiment params."""
    return maybe_create_best_ckpt_exporter(self.params, self.model_dir)

  def _maybe_build_checkpoint_manager(
      self) -> Optional[tf.train.CheckpointManager]:
    """Maybe create a CheckpointManager.

    Returns:
      A tf.train.CheckpointManager when the trainer exposes a checkpoint,
      otherwise None.

    Raises:
      ValueError: if a checkpoint exists but model_dir is None.
    """
    assert self.trainer is not None
    if self.trainer.checkpoint:
      if self.model_dir is None:
        raise ValueError('model_dir must be specified, but got None')
      checkpoint_manager = tf.train.CheckpointManager(
          self.trainer.checkpoint,
          directory=self.model_dir,
          max_to_keep=self.params.trainer.max_to_keep,
          step_counter=self.trainer.global_step,
          checkpoint_interval=self.params.trainer.checkpoint_interval,
          init_fn=self.trainer.initialize)
    else:
      checkpoint_manager = None
    return checkpoint_manager

  def _build_controller(self,
                        trainer,
                        evaluator,
                        save_summary: bool = True,
                        train_actions: Optional[List[orbit.Action]] = None,
                        eval_actions: Optional[List[orbit.Action]] = None,
                        controller_cls=orbit.Controller) -> orbit.Controller:
    """Builds an Orbit controller."""
    train_actions = [] if not train_actions else train_actions
    # Default train/eval actions are only attached when the corresponding
    # component is present.
    if trainer:
      train_actions += actions.get_train_actions(
          self.params,
          trainer,
          self.model_dir,
          checkpoint_manager=self.checkpoint_manager)

    eval_actions = [] if not eval_actions else eval_actions
    if evaluator:
      eval_actions += actions.get_eval_actions(self.params, evaluator,
                                               self.model_dir)

    controller = controller_cls(
        strategy=self.strategy,
        trainer=trainer,
        evaluator=evaluator,
        global_step=self.trainer.global_step,
        steps_per_loop=self.params.trainer.steps_per_loop,
        checkpoint_manager=self.checkpoint_manager,
        # Summary dirs are only set when summaries are requested.
        summary_dir=os.path.join(self.model_dir, 'train') if
        (save_summary) else None,
        eval_summary_dir=os.path.join(
            self.model_dir, self.params.trainer.validation_summary_subdir) if
        (save_summary) else None,
        summary_interval=self.params.trainer.summary_interval if
        (save_summary) else None,
        train_actions=train_actions,
        eval_actions=eval_actions)
    return controller

  def run(self) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
    """Run experiments by mode.

    Returns:
      A 2-tuple of (model, eval_logs).
        model: `tf.keras.Model` instance.
        eval_logs: returns eval metrics logs when run_post_eval is set to True,
          otherwise, returns {}.

    Raises:
      NotImplementedError: if `mode` is not one of the supported modes.
    """
    mode = self._mode
    params = self.params
    logging.info('Starts to execute mode: %s', mode)
    with self.strategy.scope():
      if mode == 'train' or mode == 'train_and_post_eval':
        self.controller.train(steps=params.trainer.train_steps)
      elif mode == 'train_and_eval':
        self.controller.train_and_evaluate(
            train_steps=params.trainer.train_steps,
            eval_steps=params.trainer.validation_steps,
            eval_interval=params.trainer.validation_interval)
      elif mode == 'eval':
        self.controller.evaluate(steps=params.trainer.validation_steps)
      elif mode == 'continuous_eval':

        def timeout_fn():
          # Stop continuous eval once training has reached its final step.
          if self.trainer.global_step.numpy() >= params.trainer.train_steps:
            return True
          return False

        self.controller.evaluate_continuously(
            steps=params.trainer.validation_steps,
            timeout=params.trainer.continuous_eval_timeout,
            timeout_fn=timeout_fn)
      else:
        raise NotImplementedError('The mode is not implemented: %s' % mode)

    num_params = train_utils.try_count_params(self.trainer.model)
    if num_params is not None:
      logging.info('Number of trainable params in model: %f Millions.',
                   num_params / 10.**6)

    flops = train_utils.try_count_flops(self.trainer.model)
    if flops is not None:
      logging.info('FLOPs (multi-adds) in model: %f Billions.',
                   flops / 10.**9 / 2)

    if self._run_post_eval or mode == 'train_and_post_eval':
      with self.strategy.scope():
        return self.trainer.model, self.controller.evaluate(
            steps=params.trainer.validation_steps)
    else:
      return self.trainer.model, {}
def run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: base_task.Task,
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    train_actions: Optional[List[orbit.Action]] = None,
    eval_actions: Optional[List[orbit.Action]] = None,
    trainer: Optional[base_trainer.Trainer] = None,
    controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
  """Runs train/eval configured by the experiment params.

  This is a thin functional wrapper around `OrbitExperimentRunner`, kept for
  backward compatibility; prefer subclassing the runner for customization.

  Args:
    distribution_strategy: A distribution strategy.
    task: A Task instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval',
      'train_and_eval', 'continuous_eval' or 'train_and_post_eval'.
    params: ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run post eval once after training, metrics logs
      are returned.
    save_summary: Whether to save train and validation summary.
    train_actions: Optional list of Orbit train actions.
    eval_actions: Optional list of Orbit eval actions.
    trainer: the base_trainer.Trainer instance. It should be created within the
      strategy.scope().
    controller_cls: The controller class to manage the train and eval process.
      Must be an orbit.Controller subclass.

  Returns:
    A 2-tuple of (model, eval_logs).
      model: `tf.keras.Model` instance.
      eval_logs: returns eval metrics logs when run_post_eval is set to True,
        otherwise, returns {}.
  """
  runner = OrbitExperimentRunner(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=mode,
      params=params,
      model_dir=model_dir,
      run_post_eval=run_post_eval,
      save_summary=save_summary,
      train_actions=train_actions,
      eval_actions=eval_actions,
      trainer=trainer,
      controller_cls=controller_cls,
  )
  return runner.run()
...@@ -117,6 +117,61 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase): ...@@ -117,6 +117,61 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
model_dir=model_dir, model_dir=model_dir,
run_post_eval=run_post_eval) run_post_eval=run_post_eval)
@combinations.generate(
    combinations.combine(
        distribution_strategy=[
            strategy_combinations.default_strategy,
            strategy_combinations.cloud_tpu_strategy,
            strategy_combinations.one_device_strategy_gpu,
        ],
        flag_mode=['train', 'eval', 'train_and_eval'],
        run_post_eval=[True, False]))
def test_end_to_end_class(self, distribution_strategy, flag_mode,
                          run_post_eval):
  """End-to-end test of OrbitExperimentRunner across strategies and modes."""
  model_dir = self.get_temp_dir()
  flags_dict = dict(
      experiment='mock',
      mode=flag_mode,
      model_dir=model_dir,
      params_override=json.dumps(self._test_config))
  with flagsaver.flagsaver(**flags_dict):
    params = train_utils.parse_configuration(flags.FLAGS)
    train_utils.serialize_config(params, model_dir)
    # The task must be created under the same strategy scope used for running.
    with distribution_strategy.scope():
      task = task_factory.get_task(params.task, logging_dir=model_dir)

    _, logs = train_lib.OrbitExperimentRunner(
        distribution_strategy=distribution_strategy,
        task=task,
        mode=flag_mode,
        params=params,
        model_dir=model_dir,
        run_post_eval=run_post_eval).run()

    # Eval modes should have produced a validation summary directory.
    if 'eval' in flag_mode:
      self.assertTrue(
          tf.io.gfile.exists(
              os.path.join(model_dir,
                           params.trainer.validation_summary_subdir)))
    # run_post_eval controls whether metrics logs are returned.
    if run_post_eval:
      self.assertNotEmpty(logs)
    else:
      self.assertEmpty(logs)
    # serialize_config above should have written params.yaml.
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))

    # Checkpoints are only written by training modes.
    if flag_mode == 'eval':
      return
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
    # Tests continuous evaluation.
    _, logs = train_lib.OrbitExperimentRunner(
        distribution_strategy=distribution_strategy,
        task=task,
        mode='continuous_eval',
        params=params,
        model_dir=model_dir,
        run_post_eval=run_post_eval).run()
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
distribution_strategy=[ distribution_strategy=[
...@@ -148,12 +203,12 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase): ...@@ -148,12 +203,12 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
task.build_losses = build_losses task.build_losses = build_losses
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
train_lib.run_experiment( train_lib.OrbitExperimentRunner(
distribution_strategy=distribution_strategy, distribution_strategy=distribution_strategy,
task=task, task=task,
mode=flag_mode, mode=flag_mode,
params=params, params=params,
model_dir=model_dir) model_dir=model_dir).run()
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
...@@ -194,12 +249,12 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase): ...@@ -194,12 +249,12 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
task.build_losses = build_losses task.build_losses = build_losses
model, _ = train_lib.run_experiment( model, _ = train_lib.OrbitExperimentRunner(
distribution_strategy=distribution_strategy, distribution_strategy=distribution_strategy,
task=task, task=task,
mode=flag_mode, mode=flag_mode,
params=params, params=params,
model_dir=model_dir) model_dir=model_dir).run()
after_weights = model.get_weights() after_weights = model.get_weights()
for left, right in zip(before_weights, after_weights): for left, right in zip(before_weights, after_weights):
self.assertAllEqual(left, right) self.assertAllEqual(left, right)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment