"docs/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "c7dac10bfd6b20ed9945b2810e0b5bfc4728fab3"
Unverified Commit b00783d7 authored by Igor's avatar Igor Committed by GitHub
Browse files

Replace per_device with per_replica and PerDevice with PerReplica, because the...

Replace per_device with per_replica and PerDevice with PerReplica, because the PerDevice concept was renamed and doesn't exist anymore. (#6693)

* Replace per_device with per_replica and PerDevice with PerReplica, because the PerReplica concept was renamed and doesn't exist anymore.
parent 294660bd
...@@ -246,7 +246,7 @@ class DatasetManager(object): ...@@ -246,7 +246,7 @@ class DatasetManager(object):
to the TPU through a StreamingFilesDataset. to the TPU through a StreamingFilesDataset.
Args: Args:
batch_size: The per-device batch size of the dataset. batch_size: The per-replica batch size of the dataset.
epochs_between_evals: How many epochs worth of data to yield. epochs_between_evals: How many epochs worth of data to yield.
(Generator mode only.) (Generator mode only.)
""" """
......
...@@ -608,7 +608,7 @@ def resnet_main( ...@@ -608,7 +608,7 @@ def resnet_main(
return input_function( return input_function(
is_training=True, is_training=True,
data_dir=flags_obj.data_dir, data_dir=flags_obj.data_dir,
batch_size=distribution_utils.per_device_batch_size( batch_size=distribution_utils.per_replica_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)), flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
num_epochs=num_epochs, num_epochs=num_epochs,
dtype=flags_core.get_tf_dtype(flags_obj), dtype=flags_core.get_tf_dtype(flags_obj),
...@@ -620,7 +620,7 @@ def resnet_main( ...@@ -620,7 +620,7 @@ def resnet_main(
return input_function( return input_function(
is_training=False, is_training=False,
data_dir=flags_obj.data_dir, data_dir=flags_obj.data_dir,
batch_size=distribution_utils.per_device_batch_size( batch_size=distribution_utils.per_replica_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)), flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
num_epochs=1, num_epochs=1,
dtype=flags_core.get_tf_dtype(flags_obj)) dtype=flags_core.get_tf_dtype(flags_obj))
......
...@@ -564,7 +564,7 @@ def run_transformer(flags_obj): ...@@ -564,7 +564,7 @@ def run_transformer(flags_obj):
else params["default_batch_size"])) else params["default_batch_size"]))
if not params["use_tpu"]: if not params["use_tpu"]:
params["batch_size"] = distribution_utils.per_device_batch_size( params["batch_size"] = distribution_utils.per_replica_batch_size(
params["batch_size"], num_gpus) params["batch_size"], num_gpus)
schedule_manager = schedule.Manager( schedule_manager = schedule.Manager(
......
...@@ -151,7 +151,7 @@ def get_distribution_strategy(distribution_strategy="default", ...@@ -151,7 +151,7 @@ def get_distribution_strategy(distribution_strategy="default",
"Unrecognized Distribution Strategy: %r" % distribution_strategy) "Unrecognized Distribution Strategy: %r" % distribution_strategy)
def per_device_batch_size(batch_size, num_gpus): def per_replica_batch_size(batch_size, num_gpus):
"""For multi-gpu, batch-size must be a multiple of the number of GPUs. """For multi-gpu, batch-size must be a multiple of the number of GPUs.
......
...@@ -45,20 +45,20 @@ class GetDistributionStrategyTest(tf.test.TestCase): ...@@ -45,20 +45,20 @@ class GetDistributionStrategyTest(tf.test.TestCase):
self.assertIn('GPU', device) self.assertIn('GPU', device)
class PerDeviceBatchSizeTest(tf.test.TestCase): class PerReplicaBatchSizeTest(tf.test.TestCase):
"""Tests for per_device_batch_size.""" """Tests for per_replica_batch_size."""
def test_batch_size(self): def test_batch_size(self):
self.assertEquals( self.assertEquals(
distribution_utils.per_device_batch_size(147, num_gpus=0), 147) distribution_utils.per_replica_batch_size(147, num_gpus=0), 147)
self.assertEquals( self.assertEquals(
distribution_utils.per_device_batch_size(147, num_gpus=1), 147) distribution_utils.per_replica_batch_size(147, num_gpus=1), 147)
self.assertEquals( self.assertEquals(
distribution_utils.per_device_batch_size(147, num_gpus=7), 21) distribution_utils.per_replica_batch_size(147, num_gpus=7), 21)
def test_batch_size_with_remainder(self): def test_batch_size_with_remainder(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
distribution_utils.per_device_batch_size(147, num_gpus=5) distribution_utils.per_replica_batch_size(147, num_gpus=5)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -257,16 +257,16 @@ def run_deep_speech(_): ...@@ -257,16 +257,16 @@ def run_deep_speech(_):
model_dir=flags_obj.model_dir, model_dir=flags_obj.model_dir,
batch_size=flags_obj.batch_size) batch_size=flags_obj.batch_size)
per_device_batch_size = distribution_utils.per_device_batch_size( per_replica_batch_size = distribution_utils.per_replica_batch_size(
flags_obj.batch_size, num_gpus) flags_obj.batch_size, num_gpus)
def input_fn_train(): def input_fn_train():
return dataset.input_fn( return dataset.input_fn(
per_device_batch_size, train_speech_dataset) per_replica_batch_size, train_speech_dataset)
def input_fn_eval(): def input_fn_eval():
return dataset.input_fn( return dataset.input_fn(
per_device_batch_size, eval_speech_dataset) per_replica_batch_size, eval_speech_dataset)
total_training_cycle = (flags_obj.train_epochs // total_training_cycle = (flags_obj.train_epochs //
flags_obj.epochs_between_evals) flags_obj.epochs_between_evals)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment