Unverified Commit 05a79f5a authored by Ayush Dubey, committed by GitHub

Add command line option for multi worker collective implementations, disable checkpointing. (#6317)

* s/CollectiveAllReduceStrategy/MultiWorkerMirroredStrategy

* More s/contrib.distribute/distribute.experimental

* Collective communication options in MultiWorkerMirroredStrategy.

* Minor fixes

* No checkpointing if multi worker.

* turn off checkpointing

* fix lint
parent 258b77cc
@@ -489,7 +489,8 @@ def resnet_main(
   run_config = tf.estimator.RunConfig(
       train_distribute=distribution_strategy,
       session_config=session_config,
-      save_checkpoints_secs=60*60*24)
+      save_checkpoints_secs=None,
+      save_checkpoints_steps=None)
 
   # Initializes model with all but the dense layer from pretrained ResNet.
   if flags_obj.pretrained_model_checkpoint_path is not None:
......
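For context, a minimal sketch (not part of the patch) of what the changed RunConfig does: with both save_checkpoints_secs and save_checkpoints_steps set to None, the Estimator installs no periodic checkpoint saver, which is the behavior this commit relies on for the multi-worker case. The toy model_fn and model_dir below are illustrative placeholders, not code from this repository.

    import tensorflow as tf

    def toy_model_fn(features, labels, mode):
      # Minimal model_fn used only to exercise the RunConfig; not from the patch.
      predictions = tf.layers.dense(features, 1)
      loss = tf.losses.mean_squared_error(labels, predictions)
      train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
          loss, global_step=tf.train.get_or_create_global_step())
      return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    # Both checkpoint knobs set to None: no periodic CheckpointSaverHook is
    # added, mirroring the RunConfig change above.
    run_config = tf.estimator.RunConfig(
        save_checkpoints_secs=None,
        save_checkpoints_steps=None)
    estimator = tf.estimator.Estimator(
        model_fn=toy_model_fn, model_dir="/tmp/resnet_demo", config=run_config)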
@@ -141,8 +141,13 @@ def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
   flags.DEFINE_string(
       name="all_reduce_alg", short_name="ara", default=None,
       help=help_wrap("Defines the algorithm to use for performing all-reduce."
-                     "See tf.contrib.distribute.AllReduceCrossTowerOps for "
-                     "more details and available options."))
+                     "When specified with MirroredStrategy for single "
+                     "worker, this controls "
+                     "tf.contrib.distribute.AllReduceCrossTowerOps. When "
+                     "specified with MultiWorkerMirroredStrategy, this "
+                     "controls "
+                     "tf.distribute.experimental.CollectiveCommunication; "
+                     "valid options are `ring` and `nccl`."))
 
   if tf_gpu_thread_mode:
     flags.DEFINE_string(
......
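As a usage sketch of the flag defined above (assuming absl.flags is the flag library behind these helpers and using a plain string instead of help_wrap), this shows how --all_reduce_alg is declared and read before being handed to the strategy factory; the main() body is illustrative only.

    from absl import app
    from absl import flags

    FLAGS = flags.FLAGS

    # Mirrors the flag added above; `ring` and `nccl` are the values accepted
    # with MultiWorkerMirroredStrategy, while MirroredStrategy has its own
    # cross-device all-reduce options.
    flags.DEFINE_string(
        name="all_reduce_alg", short_name="ara", default=None,
        help="Algorithm used for all-reduce (`ring` or `nccl` when running "
             "with MultiWorkerMirroredStrategy).")

    def main(_):
      # In the real run loop this value is passed to get_distribution_strategy.
      print("all_reduce_alg =", FLAGS.all_reduce_alg)

    if __name__ == "__main__":
      app.run(main)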
@@ -24,6 +24,12 @@ import random
 import string
 
 import tensorflow as tf
 
+_COLLECTIVE_COMMUNICATION_OPTIONS = {
+    None: tf.distribute.experimental.CollectiveCommunication.AUTO,
+    "ring": tf.distribute.experimental.CollectiveCommunication.RING,
+    "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
+}
+
 def get_distribution_strategy(distribution_strategy="default",
                               num_gpus=0,
@@ -42,8 +48,10 @@ def get_distribution_strategy(distribution_strategy="default",
     num_workers: Number of workers to run this model.
     all_reduce_alg: Optional. Specify which algorithm to use when performing
       all-reduce. See tf.contrib.distribute.AllReduceCrossDeviceOps for
-      available algorithms. If None, DistributionStrategy will choose based on
-      device topology.
+      available algorithms when used with `mirrored`, and
+      tf.distribute.experimental.CollectiveCommunication when used with
+      `multi_worker_mirrored`. If None, DistributionStrategy will choose based
+      on device topology.
 
   Returns:
     tf.distribute.DistibutionStrategy object.
@@ -63,7 +71,13 @@ def get_distribution_strategy(distribution_strategy="default",
     return None
 
   if distribution_strategy == "multi_worker_mirrored" or num_workers > 1:
-    return tf.distribute.experimental.MultiWorkerMirroredStrategy()
+    if all_reduce_alg not in _COLLECTIVE_COMMUNICATION_OPTIONS:
+      raise ValueError(
+          "When used with `multi_worker_mirrored`, valid values for "
+          "all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
+              all_reduce_alg))
+    return tf.distribute.experimental.MultiWorkerMirroredStrategy(
+        communication=_COLLECTIVE_COMMUNICATION_OPTIONS[all_reduce_alg])
 
   if (distribution_strategy == "one_device" or
       (distribution_strategy == "default" and num_gpus <= 1)):
......
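A minimal, self-contained sketch of the selection logic added in the hunk above, assuming the tf.distribute.experimental API targeted by this patch (later TensorFlow releases renamed the communication argument and enum); make_multi_worker_strategy is a hypothetical wrapper, not a function from the patch.

    import tensorflow as tf

    # Same mapping as the patch: flag value -> collective communication choice.
    _COLLECTIVE_COMMUNICATION_OPTIONS = {
        None: tf.distribute.experimental.CollectiveCommunication.AUTO,
        "ring": tf.distribute.experimental.CollectiveCommunication.RING,
        "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL,
    }

    def make_multi_worker_strategy(all_reduce_alg=None):
      """Builds a MultiWorkerMirroredStrategy for the requested all-reduce."""
      if all_reduce_alg not in _COLLECTIVE_COMMUNICATION_OPTIONS:
        raise ValueError(
            "When used with `multi_worker_mirrored`, valid values for "
            "all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
                all_reduce_alg))
      return tf.distribute.experimental.MultiWorkerMirroredStrategy(
          communication=_COLLECTIVE_COMMUNICATION_OPTIONS[all_reduce_alg])

    # Example: force NCCL-based collectives across workers; running this for
    # real requires GPUs and a TF_CONFIG describing the worker cluster.
    strategy = make_multi_worker_strategy("nccl")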