Commit b1b4c805 authored by Shining Sun's avatar Shining Sun
Browse files

Inlude the distribution_utils file

parent 424c2045
......@@ -21,7 +21,8 @@ from __future__ import print_function
import tensorflow as tf
def get_distribution_strategy(num_gpus, all_reduce_alg=None):
def get_distribution_strategy(
num_gpus, all_reduce_alg=None, use_one_device_strategy):
"""Return a DistributionStrategy for running the model.
Args:
......@@ -30,15 +31,25 @@ def get_distribution_strategy(num_gpus, all_reduce_alg=None):
See tf.contrib.distribute.AllReduceCrossDeviceOps for available
algorithms. If None, DistributionStrategy will choose based on device
topology.
use_one_device_strategy: Should only be set to Truen when num_gpus is 1.
If True, then use OneDeviceStrategy; otherwise, do not use any
distribution strategy.
Returns:
tf.contrib.distribute.DistibutionStrategy object.
"""
if num_gpus == 0:
if num_gpus == 0 and use_one_device_strategy:
return tf.contrib.distribute.OneDeviceStrategy("device:CPU:0")
elif num_gpus == 1:
elif num_gpus == 0:
return None
elif num_gpus == 1 and use_one_device_strategy:
return tf.contrib.distribute.OneDeviceStrategy("device:GPU:0")
else:
elif num_gpus == 1:
return None
elif use_one_device_strategy:
rase ValueError("When %d GPUs are specified, use_one_device_strategy"
" flag cannot be set to True.".format(num_gpus))
else: # num_gpus > 1 and not use_one_device_strategy
if all_reduce_alg:
return tf.contrib.distribute.MirroredStrategy(
num_gpus=num_gpus,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment