Unverified commit b00783d7, authored by Igor, committed by GitHub

Replace per_device with per_replica and PerDevice with PerReplica, because the PerDevice concept was renamed and doesn't exist anymore. (#6693)

* Replace per_device with per_replica and PerDevice with PerReplica, because the PerDevice concept was renamed and doesn't exist anymore.
parent 294660bd
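
For anyone updating downstream code, the change is a pure rename of the helper. A minimal before/after sketch of a typical call site, using the flags_obj and flags_core names that appear in the diffs below:

```python
# Old call site (the helper no longer exists under this name):
#   batch_size = distribution_utils.per_device_batch_size(
#       flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))

# New call site after this commit:
batch_size = distribution_utils.per_replica_batch_size(
    flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
```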
@@ -246,7 +246,7 @@ class DatasetManager(object):
       to the TPU through a StreamingFilesDataset.
 
     Args:
-      batch_size: The per-device batch size of the dataset.
+      batch_size: The per-replica batch size of the dataset.
       epochs_between_evals: How many epochs worth of data to yield.
         (Generator mode only.)
     """
@@ -608,7 +608,7 @@ def resnet_main(
     return input_function(
         is_training=True,
         data_dir=flags_obj.data_dir,
-        batch_size=distribution_utils.per_device_batch_size(
+        batch_size=distribution_utils.per_replica_batch_size(
             flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
         num_epochs=num_epochs,
         dtype=flags_core.get_tf_dtype(flags_obj),
@@ -620,7 +620,7 @@ def resnet_main(
     return input_function(
         is_training=False,
         data_dir=flags_obj.data_dir,
-        batch_size=distribution_utils.per_device_batch_size(
+        batch_size=distribution_utils.per_replica_batch_size(
             flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
         num_epochs=1,
         dtype=flags_core.get_tf_dtype(flags_obj))
@@ -564,7 +564,7 @@ def run_transformer(flags_obj):
       else params["default_batch_size"]))
 
   if not params["use_tpu"]:
-    params["batch_size"] = distribution_utils.per_device_batch_size(
+    params["batch_size"] = distribution_utils.per_replica_batch_size(
         params["batch_size"], num_gpus)
 
   schedule_manager = schedule.Manager(
@@ -151,7 +151,7 @@ def get_distribution_strategy(distribution_strategy="default",
         "Unrecognized Distribution Strategy: %r" % distribution_strategy)
 
 
-def per_device_batch_size(batch_size, num_gpus):
+def per_replica_batch_size(batch_size, num_gpus):
   """For multi-gpu, batch-size must be a multiple of the number of GPUs.
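
The rename is purely mechanical; the helper's behavior is unchanged. As a reference, here is a minimal sketch of what per_replica_batch_size computes, inferred from the docstring above and the unit tests below; the real implementation in distribution_utils may differ in its exact error message:

```python
def per_replica_batch_size(batch_size, num_gpus):
  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

  Sketch inferred from the unit tests, not the verbatim implementation.
  """
  if num_gpus <= 1:
    # With zero or one GPU, the global batch size is used unchanged.
    return batch_size
  if batch_size % num_gpus:
    # Non-divisible batch sizes raise, as test_batch_size_with_remainder
    # expects: per_replica_batch_size(147, num_gpus=5) -> ValueError.
    raise ValueError(
        "Batch size must be a multiple of the number of GPUs: "
        "batch_size=%d, num_gpus=%d" % (batch_size, num_gpus))
  # Even division, e.g. per_replica_batch_size(147, num_gpus=7) -> 21.
  return batch_size // num_gpus
```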
@@ -45,20 +45,20 @@ class GetDistributionStrategyTest(tf.test.TestCase):
       self.assertIn('GPU', device)
 
 
-class PerDeviceBatchSizeTest(tf.test.TestCase):
-  """Tests for per_device_batch_size."""
+class PerReplicaBatchSizeTest(tf.test.TestCase):
+  """Tests for per_replica_batch_size."""
 
   def test_batch_size(self):
     self.assertEquals(
-        distribution_utils.per_device_batch_size(147, num_gpus=0), 147)
+        distribution_utils.per_replica_batch_size(147, num_gpus=0), 147)
     self.assertEquals(
-        distribution_utils.per_device_batch_size(147, num_gpus=1), 147)
+        distribution_utils.per_replica_batch_size(147, num_gpus=1), 147)
     self.assertEquals(
-        distribution_utils.per_device_batch_size(147, num_gpus=7), 21)
+        distribution_utils.per_replica_batch_size(147, num_gpus=7), 21)
 
   def test_batch_size_with_remainder(self):
     with self.assertRaises(ValueError):
-      distribution_utils.per_device_batch_size(147, num_gpus=5)
+      distribution_utils.per_replica_batch_size(147, num_gpus=5)
 
 
 if __name__ == "__main__":
@@ -257,16 +257,16 @@ def run_deep_speech(_):
       model_dir=flags_obj.model_dir,
       batch_size=flags_obj.batch_size)
 
-  per_device_batch_size = distribution_utils.per_device_batch_size(
+  per_replica_batch_size = distribution_utils.per_replica_batch_size(
      flags_obj.batch_size, num_gpus)
 
   def input_fn_train():
     return dataset.input_fn(
-        per_device_batch_size, train_speech_dataset)
+        per_replica_batch_size, train_speech_dataset)
 
   def input_fn_eval():
     return dataset.input_fn(
-        per_device_batch_size, eval_speech_dataset)
+        per_replica_batch_size, eval_speech_dataset)
 
   total_training_cycle = (flags_obj.train_epochs //
                           flags_obj.epochs_between_evals)