Unverified Commit f2e90945 authored by Ayush Dubey, committed by GitHub

Multi-worker support for Resnet. (#6206)

* Update official resnet for multi worker training with distribution strategies.

* Fixes for multi worker training.

* Fix call to `get_distribution_strategy`.

* Undo test change.

* Fix spacing.

* Move cluster configuration to distribution_utils.

* Move train_and_evaluate out of loop.  Also, update docstrings for multi-worker flags and add use_train_and_evaluate flag.

* Update distribution_strategy flag to match exported name for collective strategy.
parent 9bf1fd02
@@ -467,6 +467,10 @@ def resnet_main(
   if flags_obj.tf_gpu_thread_mode:
     override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)
 
+  # Configures cluster spec for distribution strategy.
+  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
+                                                     flags_obj.task_index)
+
   # Creates session config. allow_soft_placement = True, is required for
   # multi-GPU and is not harmful for other modes.
   session_config = tf.compat.v1.ConfigProto(
@@ -477,6 +481,7 @@ def resnet_main(
   distribution_strategy = distribution_utils.get_distribution_strategy(
       distribution_strategy=flags_obj.distribution_strategy,
       num_gpus=flags_core.get_num_gpus(flags_obj),
+      num_workers=num_workers,
       all_reduce_alg=flags_obj.all_reduce_alg)
 
   # Creates a `RunConfig` that checkpoints every 24 hours which essentially
@@ -546,46 +551,61 @@ def resnet_main(
       num_epochs=1,
       dtype=flags_core.get_tf_dtype(flags_obj))
 
-  if flags_obj.eval_only or not flags_obj.train_epochs:
-    # If --eval_only is set, perform a single loop with zero train epochs.
-    schedule, n_loops = [0], 1
-  else:
-    # Compute the number of times to loop while training. All but the last
-    # pass will train for `epochs_between_evals` epochs, while the last will
-    # train for the number needed to reach `training_epochs`. For instance if
-    #   train_epochs = 25 and epochs_between_evals = 10
-    # schedule will be set to [10, 10, 5]. That is to say, the loop will:
-    #   Train for 10 epochs and then evaluate.
-    #   Train for another 10 epochs and then evaluate.
-    #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
-    n_loops = math.ceil(flags_obj.train_epochs / flags_obj.epochs_between_evals)
-    schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
-    schedule[-1] = flags_obj.train_epochs - sum(schedule[:-1])  # over counting.
-
-  for cycle_index, num_train_epochs in enumerate(schedule):
-    tf.compat.v1.logging.info('Starting cycle: %d/%d', cycle_index,
-                              int(n_loops))
-
-    if num_train_epochs:
-      classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
-                       hooks=train_hooks, max_steps=flags_obj.max_train_steps)
-
-    tf.compat.v1.logging.info('Starting to evaluate.')
-
-    # flags_obj.max_train_steps is generally associated with testing and
-    # profiling. As a result it is frequently called with synthetic data, which
-    # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
-    # eval (which is generally unimportant in those circumstances) to terminate.
-    # Note that eval will run for max_train_steps each loop, regardless of the
-    # global_step count.
-    eval_results = classifier.evaluate(input_fn=input_fn_eval,
-                                       steps=flags_obj.max_train_steps)
-
-    benchmark_logger.log_evaluation_result(eval_results)
-
-    if model_helpers.past_stop_threshold(
-        flags_obj.stop_threshold, eval_results['accuracy']):
-      break
+  train_epochs = (0 if flags_obj.eval_only or not flags_obj.train_epochs else
+                  flags_obj.train_epochs)
+
+  use_train_and_evaluate = flags_obj.use_train_and_evaluate or isinstance(
+      distribution_strategy, tf.contrib.distribute.CollectiveAllReduceStrategy)
+  if use_train_and_evaluate:
+    train_spec = tf.estimator.TrainSpec(
+        input_fn=lambda: input_fn_train(train_epochs), hooks=train_hooks,
+        max_steps=flags_obj.max_train_steps)
+    eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_eval,
+                                      steps=flags_obj.max_train_steps)
+    tf.compat.v1.logging.info('Starting to train and evaluate.')
+    eval_results, _ = tf.estimator.train_and_evaluate(classifier, train_spec,
+                                                      eval_spec)
+    benchmark_logger.log_evaluation_result(eval_results)
+  else:
+    if train_epochs == 0:
+      # If --eval_only is set, perform a single loop with zero train epochs.
+      schedule, n_loops = [0], 1
+    else:
+      # Compute the number of times to loop while training. All but the last
+      # pass will train for `epochs_between_evals` epochs, while the last will
+      # train for the number needed to reach `training_epochs`. For instance if
+      #   train_epochs = 25 and epochs_between_evals = 10
+      # schedule will be set to [10, 10, 5]. That is to say, the loop will:
+      #   Train for 10 epochs and then evaluate.
+      #   Train for another 10 epochs and then evaluate.
+      #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
+      n_loops = math.ceil(train_epochs / flags_obj.epochs_between_evals)
+      schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
+      schedule[-1] = train_epochs - sum(schedule[:-1])  # over counting.
+
+    for cycle_index, num_train_epochs in enumerate(schedule):
+      tf.compat.v1.logging.info('Starting cycle: %d/%d', cycle_index,
+                                int(n_loops))
+
+      if num_train_epochs:
+        classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
+                         hooks=train_hooks, max_steps=flags_obj.max_train_steps)
+
+      # flags_obj.max_train_steps is generally associated with testing and
+      # profiling. As a result it is frequently called with synthetic data,
+      # which will iterate forever. Passing steps=flags_obj.max_train_steps
+      # allows the eval (which is generally unimportant in those circumstances)
+      # to terminate. Note that eval will run for max_train_steps each loop,
+      # regardless of the global_step count.
+      tf.compat.v1.logging.info('Starting to evaluate.')
+
+      eval_results = classifier.evaluate(input_fn=input_fn_eval,
+                                         steps=flags_obj.max_train_steps)
+
+      benchmark_logger.log_evaluation_result(eval_results)
+
+      if model_helpers.past_stop_threshold(
+          flags_obj.stop_threshold, eval_results['accuracy']):
+        break
 
   if flags_obj.export_dir is not None:
     # Exports a saved model for the given classifier.
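For reference, here is a minimal sketch of the TrainSpec/EvalSpec pattern the new branch relies on, outside the diff. The estimator, input functions, and hooks (my_estimator, my_train_input_fn, my_eval_input_fn, my_hooks) are hypothetical stand-ins, not names from this commit:

import tensorflow as tf

def run_train_and_evaluate(my_estimator, my_train_input_fn, my_eval_input_fn,
                           my_hooks=None, max_steps=None):
  # Build the two specs consumed by tf.estimator.train_and_evaluate.
  train_spec = tf.estimator.TrainSpec(
      input_fn=my_train_input_fn, hooks=my_hooks, max_steps=max_steps)
  eval_spec = tf.estimator.EvalSpec(input_fn=my_eval_input_fn)
  # In TF 1.x this typically returns a (metrics dict, export results) tuple;
  # the diff above only keeps the metrics for benchmark logging.
  eval_results, _ = tf.estimator.train_and_evaluate(
      my_estimator, train_spec, eval_spec)
  return eval_results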
@@ -644,6 +664,22 @@ def define_resnet_flags(resnet_size_choices=None):
           'the expense of image resize/cropping being done as part of model '
           'inference. Note, this flag only applies to ImageNet and cannot '
           'be used for CIFAR.'))
+  flags.DEFINE_boolean(
+      name='use_train_and_evaluate', default=False,
+      help=flags_core.help_wrap(
+          'If True, uses `tf.estimator.train_and_evaluate` for the training '
+          'and evaluation loop, instead of separate calls to `classifier.train` '
+          'and `classifier.evaluate`, which is the default behavior.'))
+  flags.DEFINE_string(
+      name='worker_hosts', default=None,
+      help=flags_core.help_wrap(
+          'Comma-separated list of worker ip:port pairs for running '
+          'multi-worker models with DistributionStrategy. The user would '
+          'start the program on each host with identical value for this flag.'))
+  flags.DEFINE_integer(
+      name='task_index', default=-1,
+      help=flags_core.help_wrap('If multi-worker training, the task_index of '
+                                'this worker.'))
 
   choice_kwargs = dict(
       name='resnet_size', short_name='rs', default='50',
       help=flags_core.help_wrap('The size of the ResNet model to use.'))
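To illustrate how these flags are meant to be used together, here is a hedged sketch of launching the same training script as two workers; the hosts, ports, and entry point are hypothetical and not part of the commit. In a real deployment each command runs on its own host with the same --worker_hosts value and its own --task_index:

import subprocess

WORKER_HOSTS = 'host1:5000,host2:5000'  # hypothetical addresses

processes = []
for task_index in (0, 1):
  # Each worker gets the identical host list and its own index.
  processes.append(subprocess.Popen([
      'python', 'imagenet_main.py',  # hypothetical entry point
      '--distribution_strategy=multi_worker_mirrored',
      '--worker_hosts=' + WORKER_HOSTS,
      '--task_index=%d' % task_index,
      '--num_gpus=1',
  ]))

for p in processes:
  p.wait()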
@@ -103,8 +103,9 @@ def define_base(data_dir=True, model_dir=True, clean=True, train_epochs=True,
         name="num_gpus", short_name="ng",
         default=1 if tf.test.is_gpu_available() else 0,
         help=help_wrap(
-            "How many GPUs to use with the DistributionStrategies API. The "
-            "default is 1 if TensorFlow can detect a GPU, and 0 otherwise."))
+            "How many GPUs to use at each worker with the "
+            "DistributionStrategies API. The default is 1 if TensorFlow can "
+            "detect a GPU, and 0 otherwise."))
 
   if hooks:
     # Construct a pretty summary of hooks.
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import json
+import os
 import random
 import string
 
 import tensorflow as tf
@@ -25,16 +27,19 @@ import tensorflow as tf
 
 def get_distribution_strategy(distribution_strategy="default",
                               num_gpus=0,
+                              num_workers=1,
                               all_reduce_alg=None):
   """Return a DistributionStrategy for running the model.
 
   Args:
     distribution_strategy: a string specify which distribution strategy to use.
       Accepted values are 'off', 'default', 'one_device', 'mirrored',
-      'parameter_server', 'collective', case insensitive. 'off' means not to use
-      Distribution Strategy; 'default' means to choose from `MirroredStrategy`
-      or `OneDeviceStrategy` according to the number of GPUs."
+      'parameter_server', 'multi_worker_mirrored', case insensitive. 'off' means
+      not to use Distribution Strategy; 'default' means to choose from
+      `MirroredStrategy`, `MultiWorkerMirroredStrategy`, or `OneDeviceStrategy`
+      according to the number of GPUs and number of workers.
     num_gpus: Number of GPUs to run this model.
+    num_workers: Number of workers to run this model.
     all_reduce_alg: Optional. Specify which algorithm to use when performing
       all-reduce. See tf.contrib.distribute.AllReduceCrossDeviceOps for
       available algorithms. If None, DistributionStrategy will choose based on
@@ -51,11 +56,16 @@ def get_distribution_strategy(distribution_strategy="default",
   distribution_strategy = distribution_strategy.lower()
 
   if distribution_strategy == "off":
-    if num_gpus > 1:
-      raise ValueError("When {} GPUs are specified, distribution_strategy flag "
-                       "cannot be set to 'off'.".format(num_gpus))
+    if num_gpus > 1 or num_workers > 1:
+      raise ValueError(
+          "When {} GPUs and {} workers are specified, distribution_strategy "
+          "flag cannot be set to 'off'.".format(num_gpus, num_workers))
     return None
 
+  if distribution_strategy == "multi_worker_mirrored" or num_workers > 1:
+    return tf.contrib.distribute.CollectiveAllReduceStrategy(
+        num_gpus_per_worker=num_gpus)
+
   if (distribution_strategy == "one_device" or
       (distribution_strategy == "default" and num_gpus <= 1)):
     if num_gpus == 0:
@@ -80,10 +90,6 @@ def get_distribution_strategy(distribution_strategy="default",
     else:
       return tf.distribute.MirroredStrategy(devices=devices)
 
-  if distribution_strategy == "collective":
-    return tf.contrib.distribute.CollectiveAllReduceStrategy(
-        num_gpus_per_worker=num_gpus)
-
   if distribution_strategy == "parameter_server":
     return tf.contrib.distribute.ParameterServerStrategy(
         num_gpus_per_worker=num_gpus)
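As a quick illustration of the renamed option, here is a minimal sketch of how the collective strategy is now selected, assuming the module lives at official/utils/misc/distribution_utils as in the official models repo; with the old code the same call would have required distribution_strategy='collective':

from official.utils.misc import distribution_utils

# Either passing the new flag value or simply having more than one worker
# yields the collective multi-worker strategy.
strategy = distribution_utils.get_distribution_strategy(
    distribution_strategy='multi_worker_mirrored',
    num_gpus=1,
    num_workers=2)
# With the code in this diff, strategy is a
# tf.contrib.distribute.CollectiveAllReduceStrategy built with
# num_gpus_per_worker=1.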
@@ -197,3 +203,32 @@ def undo_set_up_synthetic_data():
     _undo_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
   else:
     print('Contrib missing: Skip remove monkey patch tf.contrib.distribute.*')
+
+
+def configure_cluster(worker_hosts=None, task_index=-1):
+  """Set multi-worker cluster spec in TF_CONFIG environment variable.
+
+  Args:
+    worker_hosts: comma-separated list of worker ip:port pairs.
+    task_index: index of this worker in the worker_hosts list.
+
+  Returns:
+    Number of workers in the cluster.
+  """
+  tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
+  if tf_config:
+    num_workers = len(tf_config['cluster']['worker'])
+  elif worker_hosts:
+    workers = worker_hosts.split(',')
+    num_workers = len(workers)
+    if num_workers > 1 and task_index < 0:
+      raise ValueError('Must specify task_index when number of workers > 1')
+    task_index = 0 if num_workers == 1 else task_index
+    os.environ['TF_CONFIG'] = json.dumps({
+        'cluster': {
+            'worker': workers
+        },
+        'task': {'type': 'worker', 'index': task_index}
+    })
+  else:
+    num_workers = 1
+  return num_workers
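A minimal sketch of what configure_cluster writes to TF_CONFIG for a hypothetical two-worker cluster; the addresses are placeholders and the structure follows directly from the function above. It assumes TF_CONFIG is not already set in the environment and that the module is importable from official/utils/misc:

import json
import os

from official.utils.misc import distribution_utils

# Worker 0's view of a two-worker cluster (hypothetical addresses).
num_workers = distribution_utils.configure_cluster(
    worker_hosts='host1:5000,host2:5000', task_index=0)
assert num_workers == 2

# TF_CONFIG now holds the cluster spec consumed by the collective strategy:
# {"cluster": {"worker": ["host1:5000", "host2:5000"]},
#  "task": {"type": "worker", "index": 0}}
print(json.loads(os.environ['TF_CONFIG']))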