Commit cf45af50 authored by Yeqing Li, committed by A. Unique TensorFlower

Adds tf.distribute.experimental.ParameterServerStrategy support to Orbit.

PiperOrigin-RevId: 362729857
parent 8624cb23
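For context, the parameter-server training model this commit targets is the single-client setup introduced in TF 2.4: one coordinator program schedules `tf.function`s onto remote workers, while variables live on parameter servers. Below is a minimal sketch of that pattern (illustrative only, not code from this commit; it assumes the cluster is described via the `TF_CONFIG` environment variable, and `worker_fn` is a made-up step function):

```python
import tensorflow as tf

# Assumes TF_CONFIG describes a cluster with 'worker', 'ps', and 'chief' jobs.
resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy)

with strategy.scope():
  # Variables created under the strategy scope are placed on parameter servers.
  step_counter = tf.Variable(0, dtype=tf.int64)

@tf.function
def worker_fn():
  # Each scheduled call executes on whichever worker is free.
  step_counter.assign_add(1)
  return step_counter.read_value()

result = coordinator.schedule(worker_fn)  # returns a RemoteValue immediately
coordinator.join()  # blocks until all scheduled functions have finished
print(result.fetch())
```

`schedule` is asynchronous and `join` blocks until all in-flight functions complete; the new `_AsyncTrainer.join` in the diff below relies on exactly this behavior.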
@@ -18,7 +18,7 @@ The base trainer implements the Orbit `StandardTrainable` and
 `StandardEvaluable` interfaces. Trainers inside this project should be
 interchangeable and independent of model architectures and tasks.
 """
+import functools
 from absl import logging
 import gin
 import orbit
@@ -84,10 +84,85 @@ class Recovery:
         "%f at step %d.", checkpoint_path, loss_value, global_step)
+
+
+class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
+  """Trainer class for both sync and async Strategy."""
+
+  def init_async(self):
+    """Initializes the Async Trainer base class."""
+    assert isinstance(self._strategy, tf.distribute.Strategy)
+    self._is_async = isinstance(
+        self._strategy, tf.distribute.experimental.ParameterServerStrategy)
+    self._coordinator = None
+    if self._is_async:
+      self._coordinator = (
+          tf.distribute.experimental.coordinator.ClusterCoordinator(
+              self._strategy))
+
+  def join(self):
+    """Joins all async steps. Only useful in async training."""
+    if getattr(self, "_is_async", False):
+      self._coordinator.join()
+
+  def create_train_loop_fn(self):
+    """Creates a training loop from the given step function and options."""
+    train_loop_fn = super().create_train_loop_fn()
+    if getattr(self, "_is_async", False):
+
+      def _async_loop_fn(iterator, num_steps):
+        self._coordinator.schedule(train_loop_fn, args=(iterator, num_steps))
+
+      return _async_loop_fn
+    else:
+      return train_loop_fn
+
+  def create_eval_loop_fn(self, has_state: bool):
+    """Creates an eval loop from the given step function and options."""
+    eval_loop_fn = super().create_eval_loop_fn(has_state)
+    if getattr(self, "_is_async", False):
+      if has_state:
+        raise ValueError(
+            "Stateful eval loop is not supported in async training.")
+
+      def _async_loop_fn(iterator, num_steps, state=None, reduce_fn=None):
+        assert state is None
+        assert reduce_fn is None
+        self._coordinator.schedule(eval_loop_fn, args=(iterator, num_steps))
+
+      return _async_loop_fn
+    else:
+      return eval_loop_fn
+
+  def distribute_dataset(self, dataset_or_fn, *args, **kwargs):
+    """A utility function to help create a `tf.distribute.DistributedDataset`.
+
+    Args:
+      dataset_or_fn: An instance of `tf.data.Dataset`, or a "dataset function"
+        returning a `tf.data.Dataset`. If it is a function, it may optionally
+        have an argument named `input_context` which will be passed a
+        `tf.distribute.InputContext` instance.
+      *args: Any positional arguments to pass through to `dataset_or_fn`.
+      **kwargs: Any keyword arguments to pass through to `dataset_or_fn`.
+
+    Returns:
+      A distributed Dataset.
+    """
+    if getattr(self, "_is_async", False):
+      per_worker_dataset_fn = functools.partial(
+          orbit.utils.make_distributed_dataset, self._strategy, dataset_or_fn,
+          *args, **kwargs)
+      per_worker_dataset_fn = tf.function(per_worker_dataset_fn)
+      return self._coordinator.create_per_worker_dataset(per_worker_dataset_fn)
+    else:
+      return orbit.utils.make_distributed_dataset(self._strategy,
+                                                  dataset_or_fn, *args,
+                                                  **kwargs)
+
+
 @gin.configurable
-class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
+class Trainer(_AsyncTrainer):
   """Implements the common trainer shared for TensorFlow models."""
 
+  # pylint: disable=super-init-not-called
   def __init__(self,
                config: ExperimentConfig,
                task: base_task.Task,
@@ -147,9 +222,11 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
       self._validation_metrics = self.task.build_metrics(
           training=False) + self.model.metrics
 
+    self.init_async()
+
     if train:
-      train_dataset = orbit.utils.make_distributed_dataset(
-          self.strategy, self.task.build_inputs, self.config.task.train_data)
+      train_dataset = self.distribute_dataset(
+          self.task.build_inputs, self.config.task.train_data)
       orbit.StandardTrainer.__init__(
           self,
           train_dataset,
@@ -159,9 +236,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
           use_tpu_summary_optimization=config.trainer.allow_tpu_summary))
 
     if evaluate:
-      eval_dataset = orbit.utils.make_distributed_dataset(
-          self.strategy, self.task.build_inputs,
-          self.config.task.validation_data)
+      eval_dataset = self.distribute_dataset(
+          self.task.build_inputs, self.config.task.validation_data)
       orbit.StandardEvaluator.__init__(
           self,
           eval_dataset,
@@ -270,6 +346,7 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
   def train_loop_end(self):
     """See base class."""
+    self.join()
     # Checks if the model numeric status is stable and conducts the checkpoint
     # recovery accordingly.
     if self._recovery:
@@ -324,6 +401,7 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
   def eval_end(self, aggregated_logs=None):
     """Processes evaluation results."""
+    self.join()
    logs = {}
     for metric in self.validation_metrics:
       logs[metric.name] = metric.result()
...
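A note on the `distribute_dataset` helper introduced above: under `ParameterServerStrategy` it wraps the dataset factory in a `tf.function` and hands it to `ClusterCoordinator.create_per_worker_dataset`, so each worker builds its own copy of the dataset; under every other strategy it falls back to `orbit.utils.make_distributed_dataset`. A hedged usage sketch, where `my_dataset_fn` is illustrative and `trainer` is assumed to be any `_AsyncTrainer` subclass that has already run `init_async()`:

```python
import tensorflow as tf

def my_dataset_fn(input_context=None):
  # Optionally shard/scale by the InputContext the runtime may pass in.
  batch_size = 8
  if input_context is not None:
    batch_size = input_context.get_per_replica_batch_size(8)
  return tf.data.Dataset.range(100).batch(batch_size).repeat()

# Sync strategies: returns a tf.distribute.DistributedDataset.
# ParameterServerStrategy: returns a per-worker dataset owned by the
# coordinator, suitable for iter() + coordinator.schedule().
dataset = trainer.distribute_dataset(my_dataset_fn)
iterator = tf.nest.map_structure(iter, dataset)
```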
@@ -14,9 +14,13 @@
 """Tests for tensorflow_models.core.trainers.trainer."""
 # pylint: disable=g-direct-tensorflow-import
+import multiprocessing
 import os
+import sys
 
 from absl.testing import parameterized
 import numpy as np
+import portpicker
 import tensorflow as tf
 
 from tensorflow.python.distribute import combinations
@@ -26,6 +30,9 @@ from official.core import config_definitions as cfg
 from official.core import train_lib
 from official.utils.testing import mock_task
 
+TPU_TEST = 'test_tpu' in sys.argv[0]
+GPU_TEST = 'test_gpu' in sys.argv[0]
+
 
 def all_strategy_combinations():
   return combinations.combine(
@@ -36,6 +43,113 @@ def all_strategy_combinations():
       ],)
+def create_in_process_cluster(num_workers, num_ps):
+  """Creates and starts local servers and returns the cluster_resolver."""
+  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
+  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+
+  cluster_dict = {}
+  cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports]
+  if num_ps > 0:
+    cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
+
+  cluster_spec = tf.train.ClusterSpec(cluster_dict)
+
+  # Workers need some inter_ops threads to work properly.
+  worker_config = tf.compat.v1.ConfigProto()
+  if multiprocessing.cpu_count() < num_workers + 1:
+    worker_config.inter_op_parallelism_threads = num_workers + 1
+
+  for i in range(num_workers):
+    tf.distribute.Server(
+        cluster_spec,
+        job_name='worker',
+        task_index=i,
+        config=worker_config,
+        protocol='grpc')
+
+  for i in range(num_ps):
+    tf.distribute.Server(
+        cluster_spec, job_name='ps', task_index=i, protocol='grpc')
+
+  cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
+      cluster_spec, rpc_layer='grpc')
+  return cluster_resolver
+
+
+def dataset_fn(input_context=None):
+  del input_context
+
+  def dummy_data(_):
+    return tf.zeros((1, 1), dtype=tf.float32)
+
+  dataset = tf.data.Dataset.range(1)
+  dataset = dataset.repeat()
+  dataset = dataset.map(
+      dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+  return dataset
+
+
+class MockAsyncTrainer(trainer_lib._AsyncTrainer):
+  """Mock AsyncTrainer to test the _AsyncTrainer class."""
+
+  def __init__(self):
+    self._strategy = tf.distribute.get_strategy()
+    self.init_async()
+
+    self.global_step = tf.Variable(
+        0,
+        dtype=tf.int64,
+        name='global_step',
+        trainable=False,
+        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
+    self.eval_global_step = tf.Variable(
+        0,
+        dtype=tf.int64,
+        name='eval_global_step',
+        trainable=False,
+        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
+
+    train_dataset = self.distribute_dataset(dataset_fn)
+    trainer_lib.orbit.StandardTrainer.__init__(
+        self, train_dataset,
+        options=trainer_lib.orbit.StandardTrainerOptions())
+
+    eval_dataset = self.distribute_dataset(dataset_fn)
+    trainer_lib.orbit.StandardEvaluator.__init__(
+        self,
+        eval_dataset,
+        options=trainer_lib.orbit.StandardEvaluatorOptions(
+            use_tf_while_loop=True))
+
+  def train_loop_begin(self):
+    self.global_step.assign(0)
+
+  def train_step(self, iterator):
+
+    def replica_step(_):
+      self.global_step.assign_add(1)
+
+    self._strategy.run(replica_step, args=(next(iterator),))
+
+  def train_loop_end(self):
+    self.join()
+    return self.global_step.numpy()
+
+  def eval_begin(self):
+    self.eval_global_step.assign(0)
+
+  def eval_step(self, iterator):
+
+    def replica_step(_):
+      self.eval_global_step.assign_add(1)
+
+    self._strategy.run(replica_step, args=(next(iterator),))
+
+  def eval_end(self):
+    self.join()
+    return self.eval_global_step.numpy()
+
+
 class TrainerTest(tf.test.TestCase, parameterized.TestCase):
 
   def setUp(self):
@@ -71,6 +185,55 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
     self.assertIn('training_loss', logs)
     self.assertIn('learning_rate', logs)
 
+  def test_base_async_trainer(self):
+    if TPU_TEST or GPU_TEST:
+      self.skipTest('Async training is not available on TPU/GPU.')
+    num_workers = 3
+    num_ps = 2
+    cluster_resolver = create_in_process_cluster(num_workers, num_ps)
+    distribution = tf.distribute.experimental.ParameterServerStrategy(
+        cluster_resolver)
+    with distribution.scope():
+      trainer = MockAsyncTrainer()
+      trainer.init_async()
+      self.assertIsInstance(
+          trainer._coordinator,
+          tf.distribute.experimental.coordinator.ClusterCoordinator)
+      self.assertEqual(trainer.train(tf.constant(10)), 10)
+      self.assertEqual(trainer.evaluate(tf.constant(11)), 11)
+
+  def test_async_trainer_train(self):
+    if TPU_TEST or GPU_TEST:
+      self.skipTest('Async training is not available on TPU/GPU.')
+    num_workers = 3
+    num_ps = 2
+    cluster_resolver = create_in_process_cluster(num_workers, num_ps)
+    distribution = tf.distribute.experimental.ParameterServerStrategy(
+        cluster_resolver)
+    with distribution.scope():
+      config = cfg.ExperimentConfig(**self._config.as_dict())
+      config.trainer.eval_tf_while_loop = True
+      trainer = self.create_test_trainer(config)
+      logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
+      self.assertIn('training_loss', logs)
+      self.assertIn('learning_rate', logs)
+
+  def test_async_trainer_validate(self):
+    if TPU_TEST or GPU_TEST:
+      self.skipTest('Async training is not available on TPU/GPU.')
+    num_workers = 3
+    num_ps = 2
+    cluster_resolver = create_in_process_cluster(num_workers, num_ps)
+    distribution = tf.distribute.experimental.ParameterServerStrategy(
+        cluster_resolver)
+    with distribution.scope():
+      config = cfg.ExperimentConfig(**self._config.as_dict())
+      config.trainer.eval_tf_while_loop = True
+      trainer = self.create_test_trainer(config)
+      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
+      self.assertIn('acc', logs)
+      self.assertIn('validation_loss', logs)
+
   @combinations.generate(all_strategy_combinations())
   def test_trainer_validate(self, distribution):
     with distribution.scope():
...
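The test helpers above are also enough to drive the coordinator by hand. A minimal, self-contained sketch tying them together (everything except `create_in_process_cluster` and `dataset_fn`, which are defined in the test file, is standard TF 2.4+ API; the step function and counter are illustrative):

```python
import tensorflow as tf

resolver = create_in_process_cluster(num_workers=3, num_ps=2)
strategy = tf.distribute.experimental.ParameterServerStrategy(resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy)

with strategy.scope():
  counter = tf.Variable(0, dtype=tf.int64)  # lives on a parameter server

# Each worker builds its own dataset from the (tf.function-wrapped) factory.
per_worker_dataset = coordinator.create_per_worker_dataset(
    tf.function(dataset_fn))
per_worker_iterator = iter(per_worker_dataset)

@tf.function
def train_step(iterator):
  batch = next(iterator)  # a (1, 1) zero tensor from dataset_fn
  counter.assign_add(1)
  return batch

for _ in range(5):
  # Scheduled calls receive per-worker slices of the iterator.
  coordinator.schedule(train_step, args=(per_worker_iterator,))
coordinator.join()  # blocks until all scheduled steps finish
```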
@@ -68,21 +68,6 @@ class StandardTrainerOptions:
   use_tpu_summary_optimization: bool = False
 
 
-def _create_train_loop_fn(train_step_fn, options: StandardTrainerOptions):
-  """Creates a training loop from the given step function and options."""
-  if options.use_tf_while_loop:
-    loop_fn = loop_fns.create_tf_while_loop_fn(train_step_fn)
-    if options.use_tpu_summary_optimization:
-      loop_fn = loop_fns.LoopFnWithSummaries(loop_fn)
-    else:
-      loop_fn = tf.function(loop_fn)
-  else:
-    if options.use_tf_function:
-      train_step_fn = tf.function(train_step_fn)
-    loop_fn = loop_fns.create_loop_fn(train_step_fn)
-  return loop_fn
-
-
 class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
   """Implements standard functionality on top of the AbstractTrainer API.
@@ -119,6 +104,25 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
     self._train_iter = None
     self._train_loop_fn = None
 
+  def create_train_loop_fn(self):
+    """Creates a training loop from the current step function and options.
+
+    Returns:
+      The train loop function, i.e. a wrapper of multiple train steps.
+    """
+    train_step_fn = self.train_step
+    if self._train_options.use_tf_while_loop:
+      loop_fn = loop_fns.create_tf_while_loop_fn(train_step_fn)
+      if self._train_options.use_tpu_summary_optimization:
+        loop_fn = loop_fns.LoopFnWithSummaries(loop_fn)
+      else:
+        loop_fn = tf.function(loop_fn)
+    else:
+      if self._train_options.use_tf_function:
+        train_step_fn = tf.function(train_step_fn)
+      loop_fn = loop_fns.create_loop_fn(train_step_fn)
+    return loop_fn
+
   def train(self, num_steps: tf.Tensor) -> Optional[runner.Output]:
     """Implements `num_steps` steps of training.
@@ -132,8 +136,7 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
     self.train_loop_begin()
 
     if self._train_loop_fn is None:
-      self._train_loop_fn = _create_train_loop_fn(
-          self.train_step, options=self._train_options)
+      self._train_loop_fn = self.create_train_loop_fn()
 
     if self._train_iter is None:
       self._train_iter = tf.nest.map_structure(iter, self.train_dataset)
...@@ -222,25 +225,6 @@ class StandardEvaluatorOptions: ...@@ -222,25 +225,6 @@ class StandardEvaluatorOptions:
use_tf_while_loop: bool = False use_tf_while_loop: bool = False
def _create_eval_loop_fn(eval_step_fn, has_state: bool,
options: StandardEvaluatorOptions):
"""Create evaluation loop function."""
if options.use_tf_while_loop:
# TODO(b/176126742): tf.while_loop doesn't support `None` as a loop input
# even when it is not used inside the loop. To workaround this limitation,
# we have to build two tf.functions for it.
if has_state:
loop_fn = loop_fns.create_tf_while_loop_fn_with_state(eval_step_fn)
else:
loop_fn = loop_fns.create_tf_while_loop_fn(eval_step_fn)
loop_fn = tf.function(loop_fn)
else:
if options.use_tf_function:
eval_step_fn = tf.function(eval_step_fn)
loop_fn = loop_fns.create_loop_fn(eval_step_fn)
return loop_fn
class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta): class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractEvaluator APIs. """Implements the standard functionality of AbstractEvaluator APIs.
...@@ -279,6 +263,31 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta): ...@@ -279,6 +263,31 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
self._eval_dataset = eval_dataset self._eval_dataset = eval_dataset
self._eval_loop_fn = None self._eval_loop_fn = None
def create_eval_loop_fn(self, has_state: bool):
"""Creates an eval loop from the current step function and options.
Args:
has_state: If the step function has state, state will be kept in the loop.
Returns:
The eval loop function, i.e. wrapper of multiple eval steps.
"""
eval_step_fn = self.eval_step
if self._eval_options.use_tf_while_loop:
# TODO(b/176126742): tf.while_loop doesn't support `None` as a loop input
# even when it is not used inside the loop. To workaround this limitation,
# we have to build two tf.functions for it.
if has_state:
loop_fn = loop_fns.create_tf_while_loop_fn_with_state(eval_step_fn)
else:
loop_fn = loop_fns.create_tf_while_loop_fn(eval_step_fn)
loop_fn = tf.function(loop_fn)
else:
if self._eval_options.use_tf_function:
eval_step_fn = tf.function(eval_step_fn)
loop_fn = loop_fns.create_loop_fn(eval_step_fn)
return loop_fn
def evaluate(self, num_steps: tf.Tensor) -> Optional[runner.Output]: def evaluate(self, num_steps: tf.Tensor) -> Optional[runner.Output]:
"""Implements `num_steps` steps of evaluation. """Implements `num_steps` steps of evaluation.
...@@ -302,8 +311,7 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta): ...@@ -302,8 +311,7 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
has_state = outputs is not None has_state = outputs is not None
if self._eval_loop_fn is None: if self._eval_loop_fn is None:
self._eval_loop_fn = _create_eval_loop_fn( self._eval_loop_fn = self.create_eval_loop_fn(has_state)
self.eval_step, has_state=has_state, options=self._eval_options)
eval_iter = tf.nest.map_structure(iter, self.eval_dataset) eval_iter = tf.nest.map_structure(iter, self.eval_dataset)
if self._eval_options.use_tf_while_loop and not has_state: if self._eval_options.use_tf_while_loop and not has_state:
......
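The refactor in this file, turning the module-level `_create_train_loop_fn`/`_create_eval_loop_fn` helpers into overridable `create_train_loop_fn`/`create_eval_loop_fn` methods, is what lets `_AsyncTrainer` wrap the generated loops with `coordinator.schedule`. Any subclass can now intercept loop construction; a hypothetical sketch (`LoggingTrainer` is illustrative, not part of this commit):

```python
import orbit

class LoggingTrainer(orbit.StandardTrainer):
  """Hypothetical subclass that logs around the generated train loop."""

  def train_step(self, iterator):
    ...  # consume next(iterator) and run one optimization step

  def create_train_loop_fn(self):
    # Build the loop exactly as the base class would, then wrap it.
    loop_fn = super().create_train_loop_fn()

    def wrapped_loop_fn(iterator, num_steps):
      print('Entering a train loop of', num_steps, 'steps')
      return loop_fn(iterator, num_steps)

    return wrapped_loop_fn
```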