"python/vscode:/vscode.git/clone" did not exist on "ce6b17c0f94e6bf53633c8f324176a891e67fa7f"
Unverified Commit c3b4ffc5 authored by moneypi's avatar moneypi Committed by GitHub
Browse files

fix no attribute 'per_replica_batch_size' (#8406)

parent a174bf5b
......@@ -209,6 +209,35 @@ def generate_dataset(data_dir):
speech_dataset = dataset.DeepSpeechDataset(train_data_conf)
return speech_dataset
def per_device_batch_size(batch_size, num_gpus):
  """Divide a global batch size evenly among the available GPUs.

  For multi-gpu, batch-size must be a multiple of the number of GPUs.
  Note that distribution strategy handles this automatically when used with
  Keras. For using with Estimator, we need to get per GPU batch.

  Args:
    batch_size: Global batch size to be divided among devices. This should be
      equal to num_gpus times the single-GPU batch_size for multi-gpu training.
    num_gpus: How many GPUs are used with DistributionStrategies.

  Returns:
    Batch size per device.

  Raises:
    ValueError: if batch_size is not divisible by number of devices
  """
  # Single-device (or CPU-only) training: the global batch is the device batch.
  if num_gpus <= 1:
    return batch_size

  remainder = batch_size % num_gpus
  if remainder:
    # Suggest the nearest smaller batch size that divides evenly.
    err = ('When running with multiple GPUs, batch size '
           'must be a multiple of the number of available GPUs. Found {} '
           'GPUs with a batch size of {}; try --batch_size={} instead.'
           ).format(num_gpus, batch_size, batch_size - remainder)
    raise ValueError(err)
  # Floor division is exact here (remainder is 0) and avoids the
  # float round-trip of int(batch_size / num_gpus).
  return batch_size // num_gpus
def run_deep_speech(_):
"""Run deep speech training and eval loop."""
......@@ -257,8 +286,7 @@ def run_deep_speech(_):
model_dir=flags_obj.model_dir,
batch_size=flags_obj.batch_size)
per_replica_batch_size = distribution_utils.per_replica_batch_size(
flags_obj.batch_size, num_gpus)
per_replica_batch_size = per_device_batch_size(flags_obj.batch_size, num_gpus)
def input_fn_train():
return dataset.input_fn(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment