Unverified Commit 3ae33b4d authored by Taylor Robie's avatar Taylor Robie Committed by GitHub
Browse files

Fix/resnet take (#4473)

* add .take() to dataset pipeline

* delint

* address PR comments
parent 441c9bca
......@@ -107,7 +107,7 @@ def preprocess_image(image, is_training):
return image
def input_fn(is_training, data_dir, batch_size, num_epochs=1):
def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
Args:
......@@ -115,6 +115,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
data_dir: The directory containing the input data.
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training.
Returns:
A dataset that can be used for iteration.
......@@ -123,8 +124,14 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
return resnet_run_loop.process_record_dataset(
dataset, is_training, batch_size, _NUM_IMAGES['train'],
parse_record, num_epochs,
dataset=dataset,
is_training=is_training,
batch_size=batch_size,
shuffle_buffer=_NUM_IMAGES['train'],
parse_record_fn=parse_record,
num_epochs=num_epochs,
num_gpus=num_gpus,
examples_per_epoch=_NUM_IMAGES['train'] if is_training else None
)
......
......@@ -156,7 +156,7 @@ def parse_record(raw_record, is_training):
return image, label
def input_fn(is_training, data_dir, batch_size, num_epochs=1):
def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
"""Input function which provides batches for train or eval.
Args:
......@@ -164,6 +164,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
data_dir: The directory containing the input data.
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training.
Returns:
A dataset that can be used for iteration.
......@@ -184,8 +185,14 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
tf.data.TFRecordDataset, cycle_length=10))
return resnet_run_loop.process_record_dataset(
dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record,
num_epochs
dataset=dataset,
is_training=is_training,
batch_size=batch_size,
shuffle_buffer=_SHUFFLE_BUFFER,
parse_record_fn=parse_record,
num_epochs=num_epochs,
num_gpus=num_gpus,
examples_per_epoch=_NUM_IMAGES['train'] if is_training else None
)
......
......@@ -42,7 +42,8 @@ from official.utils.misc import model_helpers
# Functions for input processing.
################################################################################
def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
parse_record_fn, num_epochs=1):
parse_record_fn, num_epochs=1, num_gpus=None,
examples_per_epoch=None):
"""Given a Dataset with raw records, return an iterator over the records.
Args:
......@@ -55,6 +56,8 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
parse_record_fn: A function that takes a raw record and returns the
corresponding (image, label) pair.
num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training.
examples_per_epoch: The number of examples in an epoch.
Returns:
Dataset of (image, label) pairs ready for iteration.
......@@ -72,6 +75,16 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
# dataset for the appropriate number of epochs.
dataset = dataset.repeat(num_epochs)
if is_training and num_gpus and examples_per_epoch:
total_examples = num_epochs * examples_per_epoch
# Force the number of batches to be divisible by the number of devices.
# This prevents some devices from receiving batches while others do not,
# which can lead to a lockup. This case will soon be handled directly by
# distribution strategies, at which point this .take() operation will no
# longer be needed.
total_batches = total_examples // batch_size // num_gpus * num_gpus
dataset = dataset.take(total_batches * batch_size)
# Parse the raw records into images and labels. Testing has shown that setting
# num_parallel_batches > 1 produces no improvement in throughput, since
# batch_size is almost always much greater than the number of CPU cores.
......@@ -411,7 +424,8 @@ def resnet_main(
is_training=True, data_dir=flags_obj.data_dir,
batch_size=per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
num_epochs=flags_obj.epochs_between_evals)
num_epochs=flags_obj.epochs_between_evals,
num_gpus=flags_core.get_num_gpus(flags_obj))
def input_fn_eval():
return input_function(
......@@ -419,6 +433,7 @@ def resnet_main(
batch_size=per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
num_epochs=1)
total_training_cycle = (flags_obj.train_epochs //
flags_obj.epochs_between_evals)
for cycle_index in range(total_training_cycle):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment