Commit fc02382c authored by Hongkun Yu, committed by A. Unique TensorFlower

Move an R1-specific util function from common utils to R1 models.

PiperOrigin-RevId: 303767122
parent 01d1931f
@@ -329,6 +329,37 @@ def learning_rate_with_decay(
  return learning_rate_fn


def per_replica_batch_size(batch_size, num_gpus):
  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

  Note that distribution strategy handles this automatically when used with
  Keras. For using with Estimator, we need to get per GPU batch.

  Args:
    batch_size: Global batch size to be divided among devices. This should be
      equal to num_gpus times the single-GPU batch_size for multi-gpu training.
    num_gpus: How many GPUs are used with DistributionStrategies.

  Returns:
    Batch size per device.

  Raises:
    ValueError: if batch_size is not divisible by number of devices
  """
  if num_gpus <= 1:
    return batch_size

  remainder = batch_size % num_gpus
  if remainder:
    err = ('When running with multiple GPUs, batch size '
           'must be a multiple of the number of available GPUs. Found {} '
           'GPUs with a batch size of {}; try --batch_size={} instead.'
          ).format(num_gpus, batch_size, batch_size - remainder)
    raise ValueError(err)
  return int(batch_size / num_gpus)


def resnet_model_fn(features, labels, mode, model_class,
                    resnet_size, weight_decay, learning_rate_fn, momentum,
                    data_format, resnet_version, loss_scale,
@@ -620,7 +651,7 @@ def resnet_main(
    return input_function(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_replica_batch_size(
        batch_size=per_replica_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs,
        dtype=flags_core.get_tf_dtype(flags_obj),
@@ -631,7 +662,7 @@ def resnet_main(
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_replica_batch_size(
        batch_size=per_replica_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))
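For reference, a minimal usage sketch of the per_replica_batch_size helper added above. This is illustrative only and not part of the commit's diff; the batch size and GPU count are made-up example values.

    # Assumes per_replica_batch_size as defined above; 256 and 4 are
    # made-up example values, not flags from this repository.
    global_batch_size = 256
    num_gpus = 4

    per_device = per_replica_batch_size(global_batch_size, num_gpus)
    assert per_device == 64  # 256 examples split evenly across 4 GPUs

    # A global batch size that is not divisible by num_gpus raises ValueError,
    # e.g. per_replica_batch_size(250, 4) suggests --batch_size=248 instead.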
......
@@ -562,6 +562,36 @@ def construct_estimator(flags_obj, params, schedule_manager):
},
config=run_config)


def per_replica_batch_size(batch_size, num_gpus):
  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

  Note that distribution strategy handles this automatically when used with
  Keras. For using with Estimator, we need to get per GPU batch.

  Args:
    batch_size: Global batch size to be divided among devices. This should be
      equal to num_gpus times the single-GPU batch_size for multi-gpu training.
    num_gpus: How many GPUs are used with DistributionStrategies.

  Returns:
    Batch size per device.

  Raises:
    ValueError: if batch_size is not divisible by number of devices
  """
  if num_gpus <= 1:
    return batch_size

  remainder = batch_size % num_gpus
  if remainder:
    err = ('When running with multiple GPUs, batch size '
           'must be a multiple of the number of available GPUs. Found {} '
           'GPUs with a batch size of {}; try --batch_size={} instead.'
          ).format(num_gpus, batch_size, batch_size - remainder)
    raise ValueError(err)
  return int(batch_size / num_gpus)


def run_transformer(flags_obj):
  """Create tf.Estimator to train and evaluate transformer model.
@@ -605,8 +635,8 @@ def run_transformer(flags_obj):
total_batch_size = params["batch_size"]
if not params["use_tpu"]:
params["batch_size"] = distribution_utils.per_replica_batch_size(
params["batch_size"], num_gpus)
params["batch_size"] = per_replica_batch_size(params["batch_size"],
num_gpus)
schedule_manager = schedule.Manager(
train_steps=flags_obj.train_steps,
......
@@ -157,37 +157,6 @@ def get_distribution_strategy(distribution_strategy="mirrored",
"Unrecognized Distribution Strategy: %r" % distribution_strategy)


def per_replica_batch_size(batch_size, num_gpus):
  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

  Note that distribution strategy handles this automatically when used with
  Keras. For using with Estimator, we need to get per GPU batch.

  Args:
    batch_size: Global batch size to be divided among devices. This should be
      equal to num_gpus times the single-GPU batch_size for multi-gpu training.
    num_gpus: How many GPUs are used with DistributionStrategies.

  Returns:
    Batch size per device.

  Raises:
    ValueError: if batch_size is not divisible by number of devices
  """
  if num_gpus <= 1:
    return batch_size

  remainder = batch_size % num_gpus
  if remainder:
    err = ('When running with multiple GPUs, batch size '
           'must be a multiple of the number of available GPUs. Found {} '
           'GPUs with a batch size of {}; try --batch_size={} instead.'
          ).format(num_gpus, batch_size, batch_size - remainder)
    raise ValueError(err)
  return int(batch_size / num_gpus)
# The `SyntheticDataset` is a temporary solution for generating synthetic data
# directly on devices. It is only useful for Keras with Distribution
# Strategies. We will have better support in `tf.data` or Distribution Strategy
......
@@ -45,21 +45,5 @@ class GetDistributionStrategyTest(tf.test.TestCase):
self.assertIn('GPU', device)
class PerReplicaBatchSizeTest(tf.test.TestCase):
  """Tests for per_replica_batch_size."""

  def test_batch_size(self):
    self.assertEquals(
        distribution_utils.per_replica_batch_size(147, num_gpus=0), 147)
    self.assertEquals(
        distribution_utils.per_replica_batch_size(147, num_gpus=1), 147)
    self.assertEquals(
        distribution_utils.per_replica_batch_size(147, num_gpus=7), 21)

  def test_batch_size_with_remainder(self):
    with self.assertRaises(ValueError):
      distribution_utils.per_replica_batch_size(147, num_gpus=5)


if __name__ == "__main__":
  tf.test.main()