Commit 6a0dda1f authored by Jaeman's avatar Jaeman Committed by Taylor Robie
Browse files

Fix bug on distributed training in mnist using MirroredStrategy API (#5183)

* Fix bug on distributed training in mnist using MirroredStrategy API

* Remove unnecessary codes and chagne distribution strategy source
- Remove multi-gpu
- Remove TowerOptimizer
- Change from MirroredStrategy to distribution_utils.get_distribution_strategy
parent 0d105c32
......@@ -89,6 +89,7 @@ def create_model(data_format):
def define_mnist_flags():
flags_core.define_base()
flags_core.define_performance(num_parallel_calls=False)
flags_core.define_image()
flags.adopt_module_key_flags(flags_core)
flags_core.set_defaults(data_dir='/tmp/mnist_data',
......@@ -119,10 +120,6 @@ def model_fn(features, labels, mode, params):
if mode == tf.estimator.ModeKeys.TRAIN:
optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
# If we are running multi-GPU, we need to wrap the optimizer.
if params.get('multi_gpu'):
optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
logits = model(image, training=True)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
accuracy = tf.metrics.accuracy(
......@@ -162,21 +159,16 @@ def run_mnist(flags_obj):
model_helpers.apply_clean(flags_obj)
model_function = model_fn
# Get number of GPUs as defined by the --num_gpus flags and the number of
# GPUs available on the machine.
num_gpus = flags_core.get_num_gpus(flags_obj)
multi_gpu = num_gpus > 1
session_config = tf.ConfigProto(
inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
allow_soft_placement=True)
if multi_gpu:
# Validate that the batch size can be split into devices.
distribution_utils.per_device_batch_size(flags_obj.batch_size, num_gpus)
distribution_strategy = distribution_utils.get_distribution_strategy(
flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)
# There are two steps required if using multi-GPU: (1) wrap the model_fn,
# and (2) wrap the optimizer. The first happens here, and (2) happens
# in the model_fn itself when the optimizer is defined.
model_function = tf.contrib.estimator.replicate_model_fn(
model_fn, loss_reduction=tf.losses.Reduction.MEAN,
devices=["/device:GPU:%d" % d for d in range(num_gpus)])
run_config = tf.estimator.RunConfig(
train_distribute=distribution_strategy, session_config=session_config)
data_format = flags_obj.data_format
if data_format is None:
......@@ -185,9 +177,9 @@ def run_mnist(flags_obj):
mnist_classifier = tf.estimator.Estimator(
model_fn=model_function,
model_dir=flags_obj.model_dir,
config=run_config,
params={
'data_format': data_format,
'multi_gpu': multi_gpu
})
# Set up training and evaluation input functions.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment