Unverified commit 05a79f5a authored by Ayush Dubey, committed by GitHub

Add command line option for multi worker collective implementations, disable checkpointing. (#6317)

* s/CollectiveAllReduceStrategy/MultiWorkerMirroredStrategy

* More s/contrib.distribute/distribute.experimental

* Collective communication options in MultiWorkerMirroredStrategy.

* Minor fixes

* No checkpointing if multi worker.

* turn off checkpointing

* fix lint
parent 258b77cc
@@ -489,7 +489,8 @@ def resnet_main(
   run_config = tf.estimator.RunConfig(
       train_distribute=distribution_strategy,
       session_config=session_config,
-      save_checkpoints_secs=60*60*24)
+      save_checkpoints_secs=None,
+      save_checkpoints_steps=None)

   # Initializes model with all but the dense layer from pretrained ResNet.
   if flags_obj.pretrained_model_checkpoint_path is not None:
...
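For context, disabling tf.estimator checkpointing works by passing None for both save_checkpoints_secs and save_checkpoints_steps, which is what the hunk above does. Below is a minimal sketch, not part of this commit, assuming a hypothetical is_multi_worker predicate; the commit itself simply sets both fields to None for this code path.

import tensorflow as tf


def make_run_config(distribution_strategy, session_config, is_multi_worker):
  """Builds a RunConfig; periodic checkpointing is skipped for multi-worker."""
  if is_multi_worker:
    # With both save_checkpoints_secs and save_checkpoints_steps set to None,
    # tf.estimator disables checkpoint saving.
    return tf.estimator.RunConfig(
        train_distribute=distribution_strategy,
        session_config=session_config,
        save_checkpoints_secs=None,
        save_checkpoints_steps=None)
  # Single worker: keep the previous once-a-day checkpoint cadence.
  return tf.estimator.RunConfig(
      train_distribute=distribution_strategy,
      session_config=session_config,
      save_checkpoints_secs=60 * 60 * 24)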
@@ -141,8 +141,13 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
     flags.DEFINE_string(
         name="all_reduce_alg", short_name="ara", default=None,
         help=help_wrap("Defines the algorithm to use for performing all-reduce."
-                       "See tf.contrib.distribute.AllReduceCrossTowerOps for "
-                       "more details and available options."))
+                       "When specified with MirroredStrategy for single "
+                       "worker, this controls "
+                       "tf.contrib.distribute.AllReduceCrossTowerOps. When "
+                       "specified with MultiWorkerMirroredStrategy, this "
+                       "controls "
+                       "tf.distribute.experimental.CollectiveCommunication; "
+                       "valid options are `ring` and `nccl`."))

   if tf_gpu_thread_mode:
     flags.DEFINE_string(
...
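To illustrate the multi-worker path described by the new help text, here is a minimal standalone sketch. It is not part of this commit: the flag name and the value-to-enum mapping mirror the diffs in this commit, while the script structure and the _OPTIONS name are assumptions for the example.

from absl import app
from absl import flags
import tensorflow as tf

FLAGS = flags.FLAGS
flags.DEFINE_string(
    "all_reduce_alg", None,
    "For MultiWorkerMirroredStrategy: `ring` or `nccl`; unset means AUTO.")

# Same mapping the commit adds to the distribution utilities below.
_OPTIONS = {
    None: tf.distribute.experimental.CollectiveCommunication.AUTO,
    "ring": tf.distribute.experimental.CollectiveCommunication.RING,
    "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL,
}


def main(_):
  # e.g. run `python this_script.py --all_reduce_alg=ring` on every worker.
  strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
      communication=_OPTIONS[FLAGS.all_reduce_alg])
  print("num_replicas_in_sync:", strategy.num_replicas_in_sync)


if __name__ == "__main__":
  app.run(main)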
@@ -24,6 +24,12 @@ import random
 import string

 import tensorflow as tf

+_COLLECTIVE_COMMUNICATION_OPTIONS = {
+    None: tf.distribute.experimental.CollectiveCommunication.AUTO,
+    "ring": tf.distribute.experimental.CollectiveCommunication.RING,
+    "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
+}
+

 def get_distribution_strategy(distribution_strategy="default",
                               num_gpus=0,
@@ -42,8 +48,10 @@ def get_distribution_strategy(distribution_strategy="default",
     num_workers: Number of workers to run this model.
     all_reduce_alg: Optional. Specify which algorithm to use when performing
       all-reduce. See tf.contrib.distribute.AllReduceCrossDeviceOps for
-      available algorithms. If None, DistributionStrategy will choose based on
-      device topology.
+      available algorithms when used with `mirrored`, and
+      tf.distribute.experimental.CollectiveCommunication when used with
+      `multi_worker_mirrored`. If None, DistributionStrategy will choose based
+      on device topology.

   Returns:
     tf.distribute.DistibutionStrategy object.
@@ -63,7 +71,13 @@ def get_distribution_strategy(distribution_strategy="default",
     return None

   if distribution_strategy == "multi_worker_mirrored" or num_workers > 1:
-    return tf.distribute.experimental.MultiWorkerMirroredStrategy()
+    if all_reduce_alg not in _COLLECTIVE_COMMUNICATION_OPTIONS:
+      raise ValueError(
+          "When used with `multi_worker_mirrored`, valid values for "
+          "all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
+              all_reduce_alg))
+    return tf.distribute.experimental.MultiWorkerMirroredStrategy(
+        communication=_COLLECTIVE_COMMUNICATION_OPTIONS[all_reduce_alg])

   if (distribution_strategy == "one_device" or
       (distribution_strategy == "default" and num_gpus <= 1)):
...
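A usage sketch for the updated helper, not part of this commit. The function name and argument names come from the diff above; the import path official.utils.misc.distribution_utils is assumed.

from official.utils.misc import distribution_utils  # assumed import path

# "ring" selects tf.distribute.experimental.CollectiveCommunication.RING,
# "nccl" selects NCCL, and leaving all_reduce_alg unset falls back to AUTO.
strategy = distribution_utils.get_distribution_strategy(
    distribution_strategy="multi_worker_mirrored",
    all_reduce_alg="ring")

# Any value other than "ring", "nccl", or None now raises the ValueError
# added in this commit.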