"tests/test_datasets/vscode:/vscode.git/clone" did not exist on "139c6f0c7175d7f20bd0632678e52066daa5a246"
Unverified commit a66d4713, authored by guptapriya and committed by GitHub

Use core mirrored strategy in official models (#6126)

parent 2519f29b
@@ -105,9 +105,6 @@ def run(flags_obj):
    raise ValueError('dtype fp16 is not supported in Keras. Use the default '
                     'value(fp32).')
-  per_device_batch_size = distribution_utils.per_device_batch_size(
-      flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
@@ -127,14 +124,14 @@ def run(flags_obj):
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
-      batch_size=per_device_batch_size,
+      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)

  eval_input_dataset = input_fn(
      is_training=False,
      data_dir=flags_obj.data_dir,
-      batch_size=per_device_batch_size,
+      batch_size=flags_obj.batch_size,
      num_epochs=flags_obj.train_epochs,
      parse_record_fn=parse_record_keras)
......
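With the core `tf.distribute.MirroredStrategy`, Keras splits a global batch across replicas on its own, which is why the input functions above now receive `flags_obj.batch_size` directly instead of a pre-divided per-device value. A minimal sketch of that pattern, written in the later TF 2.x idiom with an illustrative toy model and dataset (none of these names come from this commit):

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 256  # illustrative; the batch_size flag plays this role above

# One replica per visible GPU; falls back to a single device without GPUs.
strategy = tf.distribute.MirroredStrategy()

# The dataset is batched with the *global* batch size; the strategy divides
# each batch across replicas during fit().
features = tf.random.uniform([1024, 32])
labels = tf.random.uniform([1024], maxval=10, dtype=tf.int64)
dataset = (tf.data.Dataset.from_tensor_slices((features, labels))
           .shuffle(1024)
           .batch(GLOBAL_BATCH_SIZE))

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)])
  model.compile(
      optimizer='adam',
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=['accuracy'])

model.fit(dataset, epochs=2)
```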
@@ -227,19 +227,20 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
    """Returns dataset filled with random data."""
    # Synthetic input should be within [0, 255].
    inputs = tf.truncated_normal(
-        [batch_size] + [height, width, num_channels],
+        [height, width, num_channels],
        dtype=dtype,
        mean=127,
        stddev=60,
        name='synthetic_inputs')
    labels = tf.random_uniform(
-        [batch_size] + [1],
+        [1],
        minval=0,
        maxval=num_classes - 1,
        dtype=tf.int32,
        name='synthetic_labels')
    data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
+    data = data.batch(batch_size)
    data = data.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
    return data
......
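The synthetic-data hunk above removes the batch dimension from the random tensors and instead batches the repeated single-example dataset, so the batch size stays a `tf.data`-level parameter (and can be the global batch size under a distribution strategy). A rough standalone sketch of the same pattern, using current TF 2.x names (`tf.random.truncated_normal`, `tf.data.AUTOTUNE`) rather than the contrib-era ones in the diff:

```python
import tensorflow as tf

def synthetic_dataset(height, width, num_channels, num_classes, batch_size,
                      dtype=tf.float32):
  """Builds an endless dataset of random image/label pairs (sketch)."""
  # One unbatched example; values roughly in [0, 255], like real image data.
  inputs = tf.random.truncated_normal(
      [height, width, num_channels], mean=127, stddev=60, dtype=dtype,
      name='synthetic_inputs')
  labels = tf.random.uniform(
      [1], minval=0, maxval=num_classes - 1, dtype=tf.int32,
      name='synthetic_labels')

  # Repeat the single example, then batch and prefetch at the dataset level,
  # mirroring the pattern introduced in this change.
  data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
  data = data.batch(batch_size)
  return data.prefetch(buffer_size=tf.data.AUTOTUNE)

ds = synthetic_dataset(32, 32, 3, 10, batch_size=128)
```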
@@ -101,9 +101,6 @@ def run(flags_obj):
                 if tf.test.is_built_with_cuda() else 'channels_last')
  tf.keras.backend.set_image_data_format(data_format)
-  per_device_batch_size = distribution_utils.per_device_batch_size(
-      flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
  # pylint: disable=protected-access
  if flags_obj.use_synthetic_data:
    input_fn = keras_common.get_synth_input_fn(
@@ -117,13 +114,13 @@ def run(flags_obj):
  train_input_dataset = input_fn(is_training=True,
                                 data_dir=flags_obj.data_dir,
-                                 batch_size=per_device_batch_size,
+                                 batch_size=flags_obj.batch_size,
                                 num_epochs=flags_obj.train_epochs,
                                 parse_record_fn=parse_record_keras)

  eval_input_dataset = input_fn(is_training=False,
                                data_dir=flags_obj.data_dir,
-                                batch_size=per_device_batch_size,
+                                batch_size=flags_obj.batch_size,
                                num_epochs=flags_obj.train_epochs,
                                parse_record_fn=parse_record_keras)
......
@@ -57,21 +57,22 @@ def get_distribution_strategy(num_gpus,
          "turn_off_distribution_strategy flag cannot be set to"
          "True.".format(num_gpus))
  else:  # num_gpus > 1 and not turn_off_distribution_strategy
+    devices = ["device:GPU:%d" % i for i in range(num_gpus)]
    if all_reduce_alg:
-      return tf.contrib.distribute.MirroredStrategy(
-          num_gpus=num_gpus,
+      return tf.distribute.MirroredStrategy(
+          devices=devices,
          cross_device_ops=tf.contrib.distribute.AllReduceCrossDeviceOps(
              all_reduce_alg, num_packs=2))
    else:
-      return tf.contrib.distribute.MirroredStrategy(num_gpus=num_gpus)
+      return tf.distribute.MirroredStrategy(devices=devices)

def per_device_batch_size(batch_size, num_gpus):
  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

-  Note that this should eventually be handled by DistributionStrategies
-  directly. Multi-GPU support is currently experimental, however,
-  so doing the work here until that feature is in place.
+  Note that distribution strategy handles this automatically when used with
+  Keras. For use with Estimator, we still need the per-GPU batch size.

  Args:
    batch_size: Global batch size to be divided among devices. This should be
......
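Taken together: the strategy is now the core `tf.distribute.MirroredStrategy` built from an explicit device list, while Estimator-based code still divides the global batch by hand, which is what `per_device_batch_size` is kept for. Below is a small sketch of both pieces; the GPU count and the exact error message are illustrative, not copied from the module:

```python
import tensorflow as tf

num_gpus = 4  # illustrative; must match the GPUs actually available

# Core MirroredStrategy pinned to explicit devices, as in the hunk above.
devices = ["device:GPU:%d" % i for i in range(num_gpus)]
strategy = tf.distribute.MirroredStrategy(devices=devices)

def per_device_batch_size(batch_size, num_gpus):
  """Rough re-implementation of the divisibility check described above."""
  if num_gpus <= 1:
    return batch_size
  if batch_size % num_gpus:
    raise ValueError('Batch size %d is not divisible by the number of GPUs '
                     '(%d).' % (batch_size, num_gpus))
  return batch_size // num_gpus

print(per_device_batch_size(256, num_gpus))  # -> 64 examples per GPU
```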