"...data/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "0024a5c66f90c7d3d02f7ef08a773aace6deb155"
Commit 2e77bb3e authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Adds tf.distribute.experimental.ParameterServerStrategy support to Orbit.

PiperOrigin-RevId: 362729857
parent f7ea371e
......@@ -18,7 +18,7 @@ The base trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangeable and independent of model architectures and tasks.
"""
import functools
from absl import logging
import gin
import orbit
......@@ -84,10 +84,85 @@ class Recovery:
"%f at step %d.", checkpoint_path, loss_value, global_step)
class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
  """Trainer class for both sync and async Strategy."""

  def init_async(self):
    """Initializes the Async Trainer base class.

    Detects whether `self._strategy` is a ParameterServerStrategy and, if so,
    creates a `ClusterCoordinator` used to schedule train/eval steps onto the
    remote workers. Must be called before any of the other methods below.
    """
    assert isinstance(self._strategy, tf.distribute.Strategy)
    self._is_async = isinstance(
        self._strategy, tf.distribute.experimental.ParameterServerStrategy)
    self._coordinator = None
    if self._is_async:
      self._coordinator = (
          tf.distribute.experimental.coordinator.ClusterCoordinator(
              self._strategy))

  def join(self):
    """Joins all async steps. Only useful in async training."""
    # getattr guards against subclasses that never called init_async().
    if getattr(self, "_is_async", False):
      self._coordinator.join()

  def create_train_loop_fn(self):
    """Creates a training loop from the given step function and options."""
    train_loop_fn = super().create_train_loop_fn()
    if getattr(self, "_is_async", False):

      def _async_loop_fn(iterator, num_steps):
        # Schedules the loop onto remote workers; does not block for results.
        self._coordinator.schedule(train_loop_fn, args=(iterator, num_steps))

      return _async_loop_fn
    else:
      return train_loop_fn

  def create_eval_loop_fn(self, has_state: bool):
    """Creates an eval loop from the given step function and options.

    Args:
      has_state: Whether the eval loop carries state between steps. Stateful
        loops cannot be scheduled asynchronously.

    Raises:
      ValueError: If `has_state` is True while running in async mode.
    """
    eval_loop_fn = super().create_eval_loop_fn(has_state)
    if getattr(self, "_is_async", False):
      if has_state:
        raise ValueError(
            "Stateful eval loop is not supported in async training.")

      def _async_loop_fn(iterator, num_steps, state=None, reduce_fn=None):
        assert state is None
        assert reduce_fn is None
        self._coordinator.schedule(eval_loop_fn, args=(iterator, num_steps))

      return _async_loop_fn
    else:
      return eval_loop_fn

  def distribute_dataset(self, dataset_or_fn, *args, **kwargs):
    """A utility function to help create a `tf.distribute.DistributedDataset`.

    Args:
      dataset_or_fn: A instance of `tf.data.Dataset`, or a "dataset function"
        returning a `tf.data.Dataset`. If it is a function, it may optionally
        have an argument named `input_context` which will be passed a
        `tf.distribute.InputContext` instance.
      *args: Any positional arguments to pass through to `dataset_or_fn`.
      **kwargs: Any keyword arguments to pass through to `dataset_or_fn`.

    Returns:
      A distributed Dataset.
    """
    if getattr(self, "_is_async", False):
      # In async mode each worker builds its own dataset via the coordinator.
      per_worker_dataset_fn = functools.partial(
          orbit.utils.make_distributed_dataset, self._strategy, dataset_or_fn,
          *args, **kwargs)
      per_worker_dataset_fn = tf.function(per_worker_dataset_fn)
      return self._coordinator.create_per_worker_dataset(per_worker_dataset_fn)
    else:
      return orbit.utils.make_distributed_dataset(self._strategy, dataset_or_fn,
                                                  *args, **kwargs)
@gin.configurable
class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
class Trainer(_AsyncTrainer):
"""Implements the common trainer shared for TensorFlow models."""
# pylint: disable=super-init-not-called
def __init__(self,
config: ExperimentConfig,
task: base_task.Task,
......@@ -147,9 +222,11 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
self._validation_metrics = self.task.build_metrics(
training=False) + self.model.metrics
self.init_async()
if train:
train_dataset = orbit.utils.make_distributed_dataset(
self.strategy, self.task.build_inputs, self.config.task.train_data)
train_dataset = self.distribute_dataset(
self.task.build_inputs, self.config.task.train_data)
orbit.StandardTrainer.__init__(
self,
train_dataset,
......@@ -159,9 +236,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
use_tpu_summary_optimization=config.trainer.allow_tpu_summary))
if evaluate:
eval_dataset = orbit.utils.make_distributed_dataset(
self.strategy, self.task.build_inputs,
self.config.task.validation_data)
eval_dataset = self.distribute_dataset(
self.task.build_inputs, self.config.task.validation_data)
orbit.StandardEvaluator.__init__(
self,
eval_dataset,
......@@ -270,6 +346,7 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
def train_loop_end(self):
"""See base class."""
self.join()
# Checks if the model numeric status is stable and conducts the checkpoint
# recovery accordingly.
if self._recovery:
......@@ -324,6 +401,7 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
def eval_end(self, aggregated_logs=None):
"""Processes evaluation results."""
self.join()
logs = {}
for metric in self.validation_metrics:
logs[metric.name] = metric.result()
......
......@@ -14,9 +14,13 @@
"""Tests for tensorflow_models.core.trainers.trainer."""
# pylint: disable=g-direct-tensorflow-import
import multiprocessing
import os
import sys
from absl.testing import parameterized
import numpy as np
import portpicker
import tensorflow as tf
from tensorflow.python.distribute import combinations
......@@ -26,6 +30,9 @@ from official.core import config_definitions as cfg
from official.core import train_lib
from official.utils.testing import mock_task
TPU_TEST = 'test_tpu' in sys.argv[0]
GPU_TEST = 'test_gpu' in sys.argv[0]
def all_strategy_combinations():
return combinations.combine(
......@@ -36,6 +43,113 @@ def all_strategy_combinations():
],)
def create_in_process_cluster(num_workers, num_ps):
  """Creates and starts local servers and returns the cluster_resolver."""
  # Pick a free local port for every worker and parameter-server task.
  cluster_dict = {
      'worker': [
          'localhost:%s' % portpicker.pick_unused_port()
          for _ in range(num_workers)
      ]
  }
  if num_ps > 0:
    cluster_dict['ps'] = [
        'localhost:%s' % portpicker.pick_unused_port() for _ in range(num_ps)
    ]
  cluster_spec = tf.train.ClusterSpec(cluster_dict)

  # Workers need some inter_ops threads to work properly.
  worker_config = tf.compat.v1.ConfigProto()
  if multiprocessing.cpu_count() < num_workers + 1:
    worker_config.inter_op_parallelism_threads = num_workers + 1

  for task_index in range(num_workers):
    tf.distribute.Server(
        cluster_spec,
        job_name='worker',
        task_index=task_index,
        config=worker_config,
        protocol='grpc')
  for task_index in range(num_ps):
    tf.distribute.Server(
        cluster_spec, job_name='ps', task_index=task_index, protocol='grpc')

  return tf.distribute.cluster_resolver.SimpleClusterResolver(
      cluster_spec, rpc_layer='grpc')
def dataset_fn(input_context=None):
  """Builds an infinite dataset of (1, 1) float32 zero tensors for tests."""
  del input_context  # Unused; accepted for the distribute_dataset contract.

  def dummy_data(_):
    return tf.zeros((1, 1), dtype=tf.float32)

  return tf.data.Dataset.range(1).repeat().map(
      dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
class MockAsyncTrainer(trainer_lib._AsyncTrainer):
  """Mock AsyncTrainer to test the _AsyncTrainer class."""

  def __init__(self):
    self._strategy = tf.distribute.get_strategy()
    self.init_async()

    def _step_counter(name):
      # Step counters aggregate once per replica group, not per replica.
      return tf.Variable(
          0,
          dtype=tf.int64,
          name=name,
          trainable=False,
          aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

    self.global_step = _step_counter('global_step')
    self.eval_global_step = _step_counter('eval_global_step')

    trainer_lib.orbit.StandardTrainer.__init__(
        self,
        self.distribute_dataset(dataset_fn),
        options=trainer_lib.orbit.StandardTrainerOptions())
    trainer_lib.orbit.StandardEvaluator.__init__(
        self,
        self.distribute_dataset(dataset_fn),
        options=trainer_lib.orbit.StandardEvaluatorOptions(
            use_tf_while_loop=True))

  def train_loop_begin(self):
    """Resets the train step counter before each loop."""
    self.global_step.assign(0)

  def train_step(self, iterator):

    def _bump(_):
      self.global_step.assign_add(1)

    self._strategy.run(_bump, args=(next(iterator),))

  def train_loop_end(self):
    """Waits for scheduled steps and returns the steps actually run."""
    self.join()
    return self.global_step.numpy()

  def eval_begin(self):
    """Resets the eval step counter before each evaluation."""
    self.eval_global_step.assign(0)

  def eval_step(self, iterator):

    def _bump(_):
      self.eval_global_step.assign_add(1)

    self._strategy.run(_bump, args=(next(iterator),))

  def eval_end(self):
    """Waits for scheduled steps and returns the eval steps actually run."""
    self.join()
    return self.eval_global_step.numpy()
class TrainerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
......@@ -71,6 +185,55 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
self.assertIn('training_loss', logs)
self.assertIn('learning_rate', logs)
def test_base_async_trainer(self):
  """Verifies _AsyncTrainer wiring against an in-process PS cluster."""
  if TPU_TEST or GPU_TEST:
    # Fixed typos: "Aysnc" -> "Async", "GPU/GPU" -> "TPU/GPU" (the guard
    # above skips on both TPU and GPU test targets).
    self.skipTest('Async training is not available on TPU/GPU.')
  num_workers = 3
  num_ps = 2
  cluster_resolver = create_in_process_cluster(num_workers, num_ps)
  distribution = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver)
  with distribution.scope():
    trainer = MockAsyncTrainer()
    trainer.init_async()
    self.assertIsInstance(
        trainer._coordinator,
        tf.distribute.experimental.coordinator.ClusterCoordinator)
    self.assertEqual(trainer.train(tf.constant(10)), 10)
    self.assertEqual(trainer.evaluate(tf.constant(11)), 11)
def test_async_trainer_train(self):
  """Checks Trainer.train produces training logs under async PS strategy."""
  if TPU_TEST or GPU_TEST:
    # Fixed typos: "Aysnc" -> "Async", "GPU/GPU" -> "TPU/GPU" (the guard
    # above skips on both TPU and GPU test targets).
    self.skipTest('Async training is not available on TPU/GPU.')
  num_workers = 3
  num_ps = 2
  cluster_resolver = create_in_process_cluster(num_workers, num_ps)
  distribution = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver)
  with distribution.scope():
    config = cfg.ExperimentConfig(**self._config.as_dict())
    config.trainer.eval_tf_while_loop = True
    trainer = self.create_test_trainer(config)
    logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', logs)
    self.assertIn('learning_rate', logs)
def test_async_trainer_validate(self):
  """Checks Trainer.evaluate produces eval logs under async PS strategy."""
  if TPU_TEST or GPU_TEST:
    # Fixed typos: "Aysnc" -> "Async", "GPU/GPU" -> "TPU/GPU" (the guard
    # above skips on both TPU and GPU test targets).
    self.skipTest('Async training is not available on TPU/GPU.')
  num_workers = 3
  num_ps = 2
  cluster_resolver = create_in_process_cluster(num_workers, num_ps)
  distribution = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver)
  with distribution.scope():
    config = cfg.ExperimentConfig(**self._config.as_dict())
    config.trainer.eval_tf_while_loop = True
    trainer = self.create_test_trainer(config)
    logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('acc', logs)
    self.assertIn('validation_loss', logs)
@combinations.generate(all_strategy_combinations())
def test_trainer_validate(self, distribution):
with distribution.scope():
......
......@@ -68,21 +68,6 @@ class StandardTrainerOptions:
use_tpu_summary_optimization: bool = False
def _create_train_loop_fn(train_step_fn, options: StandardTrainerOptions):
  """Creates a training loop from the given step function and options."""
  if not options.use_tf_while_loop:
    # Eager-style loop; optionally compile each step into a tf.function.
    if options.use_tf_function:
      train_step_fn = tf.function(train_step_fn)
    return loop_fns.create_loop_fn(train_step_fn)

  loop_fn = loop_fns.create_tf_while_loop_fn(train_step_fn)
  if options.use_tpu_summary_optimization:
    return loop_fns.LoopFnWithSummaries(loop_fn)
  return tf.function(loop_fn)
class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
"""Implements standard functionality on top of the AbstractTrainer API.
......@@ -119,6 +104,25 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
self._train_iter = None
self._train_loop_fn = None
def create_train_loop_fn(self):
  """Creates a training loop from the current step function and options.

  Returns:
    The train loop function, i.e. wrapper of multiple train steps.
  """
  options = self._train_options
  step_fn = self.train_step
  if not options.use_tf_while_loop:
    # Eager-style loop; optionally compile each step into a tf.function.
    if options.use_tf_function:
      step_fn = tf.function(step_fn)
    return loop_fns.create_loop_fn(step_fn)

  loop_fn = loop_fns.create_tf_while_loop_fn(step_fn)
  if options.use_tpu_summary_optimization:
    return loop_fns.LoopFnWithSummaries(loop_fn)
  return tf.function(loop_fn)
def train(self, num_steps: tf.Tensor) -> Optional[runner.Output]:
"""Implements `num_steps` steps of training.
......@@ -132,8 +136,7 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
self.train_loop_begin()
if self._train_loop_fn is None:
self._train_loop_fn = _create_train_loop_fn(
self.train_step, options=self._train_options)
self._train_loop_fn = self.create_train_loop_fn()
if self._train_iter is None:
self._train_iter = tf.nest.map_structure(iter, self.train_dataset)
......@@ -222,25 +225,6 @@ class StandardEvaluatorOptions:
use_tf_while_loop: bool = False
def _create_eval_loop_fn(eval_step_fn, has_state: bool,
                         options: StandardEvaluatorOptions):
  """Create evaluation loop function."""
  if not options.use_tf_while_loop:
    # Eager-style loop; optionally compile each step into a tf.function.
    if options.use_tf_function:
      eval_step_fn = tf.function(eval_step_fn)
    return loop_fns.create_loop_fn(eval_step_fn)

  # TODO(b/176126742): tf.while_loop doesn't support `None` as a loop input
  # even when it is not used inside the loop. To workaround this limitation,
  # we have to build two tf.functions for it.
  if has_state:
    loop_fn = loop_fns.create_tf_while_loop_fn_with_state(eval_step_fn)
  else:
    loop_fn = loop_fns.create_tf_while_loop_fn(eval_step_fn)
  return tf.function(loop_fn)
class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractEvaluator APIs.
......@@ -279,6 +263,31 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
self._eval_dataset = eval_dataset
self._eval_loop_fn = None
def create_eval_loop_fn(self, has_state: bool):
  """Creates an eval loop from the current step function and options.

  Args:
    has_state: If the step function has state, state will be kept in the loop.

  Returns:
    The eval loop function, i.e. wrapper of multiple eval steps.
  """
  options = self._eval_options
  step_fn = self.eval_step
  if not options.use_tf_while_loop:
    # Eager-style loop; optionally compile each step into a tf.function.
    if options.use_tf_function:
      step_fn = tf.function(step_fn)
    return loop_fns.create_loop_fn(step_fn)

  # TODO(b/176126742): tf.while_loop doesn't support `None` as a loop input
  # even when it is not used inside the loop. To workaround this limitation,
  # we have to build two tf.functions for it.
  if has_state:
    loop_fn = loop_fns.create_tf_while_loop_fn_with_state(step_fn)
  else:
    loop_fn = loop_fns.create_tf_while_loop_fn(step_fn)
  return tf.function(loop_fn)
def evaluate(self, num_steps: tf.Tensor) -> Optional[runner.Output]:
"""Implements `num_steps` steps of evaluation.
......@@ -302,8 +311,7 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
has_state = outputs is not None
if self._eval_loop_fn is None:
self._eval_loop_fn = _create_eval_loop_fn(
self.eval_step, has_state=has_state, options=self._eval_options)
self._eval_loop_fn = self.create_eval_loop_fn(has_state)
eval_iter = tf.nest.map_structure(iter, self.eval_dataset)
if self._eval_options.use_tf_while_loop and not has_state:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment