"...resnet50_tensorflow.git" did not exist on "5e8a51fdbfd2b3181b65bd07a4ec23e2dff5b5d7"
Unverified commit 4a1fba0b authored by Ayush Dubey, committed by GitHub

Add num_packs flag for MirroredStrategy's cross device ops. (#6676)

* Add num_packs flag for MirroredStrategy's cross device ops.

* fix parens

* Fix lint errors and make all_reduce_alg more robust.

* Set default num_packs to 1
parent 9b17d796
@@ -550,7 +550,8 @@ def resnet_main(
       distribution_strategy=flags_obj.distribution_strategy,
       num_gpus=flags_core.get_num_gpus(flags_obj),
       num_workers=num_workers,
-      all_reduce_alg=flags_obj.all_reduce_alg)
+      all_reduce_alg=flags_obj.all_reduce_alg,
+      num_packs=flags_obj.num_packs)
 
   # Creates a `RunConfig` that checkpoints every 24 hours which essentially
   # results in checkpoints determined only by `epochs_between_evals`.
...
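As a rough usage sketch (not part of the diff), the updated call in `resnet_main` resolves to something like the following once the flags are parsed; the concrete values and the `official.utils.misc.distribution_utils` import path are assumptions for illustration:

# Hypothetical values; only num_packs is new in this commit.
from official.utils.misc import distribution_utils  # assumed module path

strategy = distribution_utils.get_distribution_strategy(
    distribution_strategy="mirrored",   # flags_obj.distribution_strategy
    num_gpus=8,                         # flags_core.get_num_gpus(flags_obj)
    num_workers=1,
    all_reduce_alg="nccl",              # flags_obj.all_reduce_alg
    num_packs=2)                        # flags_obj.num_packs (added here)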
@@ -47,7 +47,8 @@ def get_loss_scale(flags_obj):
 
 def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
                        synthetic_data=True, max_train_steps=True, dtype=True,
-                       all_reduce_alg=True, tf_gpu_thread_mode=False,
+                       all_reduce_alg=True, num_packs=True,
+                       tf_gpu_thread_mode=False,
                        datasets_num_private_threads=False,
                        datasets_num_parallel_batches=False,
                        dynamic_loss_scale=False):
@@ -62,6 +63,8 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
       of training steps
     dtype: Create flags for specifying dtype.
     all_reduce_alg: If set forces a specific algorithm for multi-gpu.
+    num_packs: If set provides number of packs for MirroredStrategy's cross
+      device ops.
     tf_gpu_thread_mode: gpu_private triggers us of private thread pool.
     datasets_num_private_threads: Number of private threads for datasets.
     datasets_num_parallel_batches: Determines how many batches to process in
@@ -176,6 +179,13 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
             "tf.distribute.experimental.CollectiveCommunication; "
             "valid options are `ring` and `nccl`."))
 
+  if num_packs:
+    flags.DEFINE_integer(
+        name="num_packs", default=1,
+        help=help_wrap("Sets `num_packs` in the cross device ops used in "
+                       "MirroredStrategy. For details, see "
+                       "tf.distribute.NcclAllReduce."))
+
   if tf_gpu_thread_mode:
     flags.DEFINE_string(
         name="tf_gpu_thread_mode", short_name="gt_mode", default=None,
...
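For reference, a minimal standalone sketch of the same flag-definition pattern with absl; `help_wrap` from the repo's flag helpers is replaced by a plain string, and the script name is hypothetical:

from absl import app, flags

# Same definition as in the diff above, minus the repo's help_wrap helper.
flags.DEFINE_integer(
    name="num_packs", default=1,
    help="Sets `num_packs` in the cross device ops used in MirroredStrategy. "
         "For details, see tf.distribute.NcclAllReduce.")

FLAGS = flags.FLAGS


def main(_):
  # e.g. `python train.py --num_packs=2` prints 2; the default is 1.
  print("num_packs =", FLAGS.num_packs)


if __name__ == "__main__":
  app.run(main)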
@@ -24,27 +24,66 @@ import random
 import string
 
 import tensorflow as tf
 
-_COLLECTIVE_COMMUNICATION_OPTIONS = {
-    None: tf.distribute.experimental.CollectiveCommunication.AUTO,
-    "ring": tf.distribute.experimental.CollectiveCommunication.RING,
-    "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
-}
-
-_MIRRORED_ALL_REDUCE_NUM_PACKS = 2
-
-_MIRRORED_ALL_REDUCE_OPTIONS = {
-    None: None,
-    "nccl": tf.distribute.NcclAllReduce(
-        num_packs=_MIRRORED_ALL_REDUCE_NUM_PACKS),
-    "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce(
-        num_packs=_MIRRORED_ALL_REDUCE_NUM_PACKS)
-}
+
+def _collective_communication(all_reduce_alg):
+  """Return a CollectiveCommunication based on all_reduce_alg.
+
+  Args:
+    all_reduce_alg: a string specifying which collective communication to pick,
+      or None.
+
+  Returns:
+    tf.distribute.experimental.CollectiveCommunication object
+
+  Raises:
+    ValueError: if `all_reduce_alg` not in [None, 'ring', 'nccl']
+  """
+  collective_communication_options = {
+      None: tf.distribute.experimental.CollectiveCommunication.AUTO,
+      "ring": tf.distribute.experimental.CollectiveCommunication.RING,
+      "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
+  }
+  if all_reduce_alg not in collective_communication_options:
+    raise ValueError(
+        "When used with `multi_worker_mirrored`, valid values for "
+        "all_reduce_alg are ['ring', 'nccl']. Supplied value: {}".format(
+            all_reduce_alg))
+  return collective_communication_options[all_reduce_alg]
+
+
+def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
+  """Return a CrossDeviceOps based on all_reduce_alg and num_packs.
+
+  Args:
+    all_reduce_alg: a string specifying which cross device op to pick, or None.
+    num_packs: an integer specifying number of packs for the cross device op.
+
+  Returns:
+    tf.distribute.CrossDeviceOps object or None.
+
+  Raises:
+    ValueError: if `all_reduce_alg` not in [None, 'nccl', 'hierarchical_copy'].
+  """
+  if all_reduce_alg is None:
+    return None
+  mirrored_all_reduce_options = {
+      "nccl": tf.distribute.NcclAllReduce,
+      "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
+  }
+  if all_reduce_alg not in mirrored_all_reduce_options:
+    raise ValueError(
+        "When used with `mirrored`, valid values for all_reduce_alg are "
+        "['nccl', 'hierarchical_copy']. Supplied value: {}".format(
+            all_reduce_alg))
+  cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
+  return cross_device_ops_class(num_packs=num_packs)
 
 
 def get_distribution_strategy(distribution_strategy="default",
                               num_gpus=0,
                               num_workers=1,
-                              all_reduce_alg=None):
+                              all_reduce_alg=None,
+                              num_packs=1):
   """Return a DistributionStrategy for running the model.
 
   Args:
@@ -61,6 +100,8 @@ def get_distribution_strategy(distribution_strategy="default",
       "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
       "ring" and "nccl". If None, DistributionStrategy will choose based on
       device topology.
+    num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
+      or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
 
   Returns:
     tf.distribute.DistibutionStrategy object.
@@ -80,13 +121,8 @@ def get_distribution_strategy(distribution_strategy="default",
     return None
 
   if distribution_strategy == "multi_worker_mirrored" or num_workers > 1:
-    if all_reduce_alg not in _COLLECTIVE_COMMUNICATION_OPTIONS:
-      raise ValueError(
-          "When used with `multi_worker_mirrored`, valid values for "
-          "all_reduce_alg are ['ring', 'nccl']. Supplied value: {}".format(
-              all_reduce_alg))
     return tf.distribute.experimental.MultiWorkerMirroredStrategy(
-        communication=_COLLECTIVE_COMMUNICATION_OPTIONS[all_reduce_alg])
+        communication=_collective_communication(all_reduce_alg))
 
   if (distribution_strategy == "one_device" or
       (distribution_strategy == "default" and num_gpus <= 1)):
@@ -104,14 +140,9 @@ def get_distribution_strategy(distribution_strategy="default",
       devices = ["device:CPU:0"]
     else:
       devices = ["device:GPU:%d" % i for i in range(num_gpus)]
-    if all_reduce_alg not in _MIRRORED_ALL_REDUCE_OPTIONS:
-      raise ValueError(
-          "When used with `mirrored`, valid values for all_reduce_alg are "
-          "['nccl', 'hierarchical_copy']. Supplied value: {}".format(
-              all_reduce_alg))
     return tf.distribute.MirroredStrategy(
         devices=devices,
-        cross_device_ops=_MIRRORED_ALL_REDUCE_OPTIONS[all_reduce_alg])
+        cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))
 
   if distribution_strategy == "parameter_server":
     return tf.distribute.experimental.ParameterServerStrategy()
...
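As a hedged sketch of what the refactored helpers produce for a single-worker multi-GPU run (the device list and values are hypothetical, and two visible GPUs are assumed):

import tensorflow as tf

# _mirrored_cross_device_ops("nccl", num_packs=2) returns the equivalent of:
cross_device_ops = tf.distribute.NcclAllReduce(num_packs=2)

# ...which get_distribution_strategy then hands to MirroredStrategy:
strategy = tf.distribute.MirroredStrategy(
    devices=["device:GPU:0", "device:GPU:1"],  # assumes two visible GPUs
    cross_device_ops=cross_device_ops)

# With all_reduce_alg=None, _mirrored_cross_device_ops returns None and
# MirroredStrategy falls back to its default cross-device reduction.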