Commit ece99414 authored by Ayush Dubey, committed by Toby Boyd

Remove contrib cross device ops and update all_reduce_alg options. (#6673)

* Remove contrib AllReduceCrossDeviceOps and update all_reduce_alg options with MirroredStrategy.

* cleanup
parent a7338771
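For reference, a minimal sketch of how a training script would exercise the updated options through this function (the `official.utils.misc.distribution_utils` import path is an assumption; the function name and arguments come from the diff below):

    import tensorflow as tf
    from official.utils.misc import distribution_utils  # assumed module path

    # Single machine, 2 GPUs, using the new "nccl" option for `mirrored`.
    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy="mirrored", num_gpus=2, all_reduce_alg="nccl")
    with strategy.scope():
      model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
      model.compile(optimizer="sgd", loss="mse")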
@@ -30,6 +30,16 @@ _COLLECTIVE_COMMUNICATION_OPTIONS = {
     "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
 }
+
+_MIRRORED_ALL_REDUCE_NUM_PACKS = 2
+_MIRRORED_ALL_REDUCE_OPTIONS = {
+    None: None,
+    "nccl": tf.distribute.NcclAllReduce(
+        num_packs=_MIRRORED_ALL_REDUCE_NUM_PACKS),
+    "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce(
+        num_packs=_MIRRORED_ALL_REDUCE_NUM_PACKS)
+}
 def get_distribution_strategy(distribution_strategy="default",
                               num_gpus=0,
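The new `_MIRRORED_ALL_REDUCE_OPTIONS` table maps flag strings directly to `tf.distribute` cross-device-ops objects, replacing the removed `tf.contrib.distribute.AllReduceCrossDeviceOps` dispatch. A standalone sketch of what each entry constructs:

    import tensorflow as tf

    # "nccl": NCCL-based all-reduce; gradients are packed into 2 buckets
    # before reduction to amortize per-tensor overhead.
    nccl_ops = tf.distribute.NcclAllReduce(num_packs=2)

    # "hierarchical_copy": hierarchical copy all-reduce, also with 2 packs;
    # typically faster on machines with DGX-1-style GPU interconnects.
    hierarchical_ops = tf.distribute.HierarchicalCopyAllReduce(num_packs=2)

    # None maps to None, letting MirroredStrategy pick its default ops.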
@@ -38,20 +48,19 @@ def get_distribution_strategy(distribution_strategy="default",
"""Return a DistributionStrategy for running the model. """Return a DistributionStrategy for running the model.
Args: Args:
distribution_strategy: a string specify which distribution strategy to use. distribution_strategy: a string specifying which distribution strategy to
Accepted values are 'off', 'default', 'one_device', 'mirrored', use. Accepted values are 'off', 'default', 'one_device', 'mirrored',
'parameter_server', 'multi_worker_mirrored', case insensitive. 'off' means 'parameter_server', 'multi_worker_mirrored', case insensitive. 'off' means
not to use Distribution Strategy; 'default' means to choose from not to use Distribution Strategy; 'default' means to choose from
`MirroredStrategy`, `MultiWorkerMirroredStrategy`, or `OneDeviceStrategy` `MirroredStrategy`, `MultiWorkerMirroredStrategy`, or `OneDeviceStrategy`
according to the number of GPUs and number of workers. according to the number of GPUs and number of workers.
num_gpus: Number of GPUs to run this model. num_gpus: Number of GPUs to run this model.
num_workers: Number of workers to run this model. num_workers: Number of workers to run this model.
all_reduce_alg: Optional. Specify which algorithm to use when performing all_reduce_alg: Optional. Specifies which algorithm to use when performing
all-reduce. See tf.contrib.distribute.AllReduceCrossDeviceOps for all-reduce. For `MirroredStrategy`, valid values are "nccl" and
available algorithms when used with `mirrored`, and "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
tf.distribute.experimental.CollectiveCommunication when used with "ring" and "nccl". If None, DistributionStrategy will choose based on
`multi_worker_mirrored`. If None, DistributionStrategy will choose based device topology.
on device topology.
Returns: Returns:
tf.distribute.DistibutionStrategy object. tf.distribute.DistibutionStrategy object.
@@ -74,7 +83,7 @@ def get_distribution_strategy(distribution_strategy="default",
     if all_reduce_alg not in _COLLECTIVE_COMMUNICATION_OPTIONS:
       raise ValueError(
           "When used with `multi_worker_mirrored`, valid values for "
-          "all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
+          "all_reduce_alg are ['ring', 'nccl']. Supplied value: {}".format(
               all_reduce_alg))
     return tf.distribute.experimental.MultiWorkerMirroredStrategy(
         communication=_COLLECTIVE_COMMUNICATION_OPTIONS[all_reduce_alg])
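For the multi-worker branch above, the lookup resolves to a `CollectiveCommunication` enum value. A standalone sketch, assuming `TF_CONFIG` is already set in each worker's environment:

    import tensorflow as tf

    # Equivalent to all_reduce_alg="ring"; CollectiveCommunication.NCCL
    # corresponds to "nccl", and AUTO lets the runtime choose.
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        communication=tf.distribute.experimental.CollectiveCommunication.RING)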
@@ -95,13 +104,14 @@ def get_distribution_strategy(distribution_strategy="default",
devices = ["device:CPU:0"] devices = ["device:CPU:0"]
else: else:
devices = ["device:GPU:%d" % i for i in range(num_gpus)] devices = ["device:GPU:%d" % i for i in range(num_gpus)]
if all_reduce_alg: if all_reduce_alg not in _MIRRORED_ALL_REDUCE_OPTIONS:
raise ValueError(
"When used with `mirrored`, valid values for all_reduce_alg are "
"['nccl', 'hierarchical_copy']. Supplied value: {}".format(
all_reduce_alg))
return tf.distribute.MirroredStrategy( return tf.distribute.MirroredStrategy(
devices=devices, devices=devices,
cross_device_ops=tf.contrib.distribute.AllReduceCrossDeviceOps( cross_device_ops=_MIRRORED_ALL_REDUCE_OPTIONS[all_reduce_alg])
all_reduce_alg, num_packs=2))
else:
return tf.distribute.MirroredStrategy(devices=devices)
if distribution_strategy == "parameter_server": if distribution_strategy == "parameter_server":
return tf.distribute.experimental.ParameterServerStrategy() return tf.distribute.experimental.ParameterServerStrategy()
...
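After this change, passing all_reduce_alg="hierarchical_copy" to the mirrored branch is equivalent to constructing the strategy directly; a sketch assuming two visible GPUs:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy(
        devices=["device:GPU:0", "device:GPU:1"],
        cross_device_ops=tf.distribute.HierarchicalCopyAllReduce(num_packs=2))

Dispatching through an explicit table also makes unsupported values fail fast with the ValueError shown above, rather than being passed through to the removed contrib op.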