Commit fd1d1780 authored by guptapriya, committed by Taylor Robie

Add minor performance improvements to resnet input pipeline (#4340)

* Remove one-hot labels, add drop_remainder to batch, use parallel_interleave in the ImageNet dataset.

* minor lint fix

* Don't try to read the files twice...

* Add explanation for cycle_length
parent 419bc6e3
@@ -73,7 +73,6 @@ def parse_record(raw_record, is_training):
   # The first byte represents the label, which we convert from uint8 to int32
   # and then to one-hot.
   label = tf.cast(record_vector[0], tf.int32)
-  label = tf.one_hot(label, _NUM_CLASSES)

   # The remaining bytes after the label represent the image, which we reshape
   # from [depth * height * width] to [depth, height, width].
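A minimal sketch (TF 1.x, toy values) of what this removal changes: the pipeline now emits a scalar int32 label instead of a _NUM_CLASSES-length one-hot vector, and the conversion happens (implicitly) in the loss instead.

```python
import tensorflow as tf

label = tf.cast(tf.constant(7, tf.uint8), tf.int32)  # shape () after this change
one_hot = tf.one_hot(label, 10)                      # shape (10,) before it
```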
@@ -64,13 +64,13 @@ class BaseTest(tf.test.TestCase):
         lambda val: cifar10_main.parse_record(val, False))
     image, label = fake_dataset.make_one_shot_iterator().get_next()
-    self.assertAllEqual(label.shape, (10,))
+    self.assertAllEqual(label.shape, ())
     self.assertAllEqual(image.shape, (_HEIGHT, _WIDTH, _NUM_CHANNELS))
     with self.test_session() as sess:
       image, label = sess.run([image, label])
-    self.assertAllEqual(label, np.array([int(i == 7) for i in range(10)]))
+    self.assertEqual(label, 7)
     for row in image:
       for pixel in row:
@@ -39,7 +39,7 @@ _NUM_IMAGES = {
 }

 _NUM_TRAIN_FILES = 1024
-_SHUFFLE_BUFFER = 1500
+_SHUFFLE_BUFFER = 10000

 DATASET_NAME = 'ImageNet'
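For context on the constant being raised here, a minimal sketch of what _SHUFFLE_BUFFER controls: dataset.shuffle keeps that many records in memory and samples uniformly from them, so a larger buffer gives better mixing at the cost of RAM and startup latency.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100000)
# The new ImageNet value: each output element is drawn uniformly from a
# 10000-record in-memory buffer that is refilled as elements are consumed.
dataset = dataset.shuffle(buffer_size=10000)
```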
@@ -152,8 +152,6 @@ def parse_record(raw_record, is_training):
       num_channels=_NUM_CHANNELS,
       is_training=is_training)
-  label = tf.one_hot(tf.reshape(label, shape=[]), _NUM_CLASSES)

   return image, label
@@ -176,8 +174,13 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
     # Shuffle the input files
     dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)

-  # Convert to individual records
-  dataset = dataset.flat_map(tf.data.TFRecordDataset)
+  # Convert to individual records.
+  # cycle_length = 10 means 10 files will be read and deserialized in parallel.
+  # This number is low enough to not cause too much contention on small systems
+  # but high enough to provide the benefits of parallelization. You may want
+  # to increase this number if you have a large number of CPU cores.
+  dataset = dataset.apply(tf.contrib.data.parallel_interleave(
+      tf.data.TFRecordDataset, cycle_length=10))

   return resnet_run_loop.process_record_dataset(
       dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record,
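A minimal sketch (TF 1.8-era tf.contrib.data APIs, hypothetical file pattern) of the sequential versus parallel ways to turn a list of TFRecord shards into a stream of records, which is the substitution this hunk makes:

```python
import tensorflow as tf

# Hypothetical shard pattern; the real pipeline builds its own file list.
filenames = tf.data.Dataset.list_files('/tmp/train-*-of-01024')

# Sequential: one file is opened and fully drained before the next begins.
sequential = filenames.flat_map(tf.data.TFRecordDataset)

# Parallel: up to cycle_length files are read and deserialized concurrently,
# and their records are interleaved into a single output stream.
parallel = filenames.apply(tf.contrib.data.parallel_interleave(
    tf.data.TFRecordDataset, cycle_length=10))
```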
@@ -79,7 +79,8 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
       tf.contrib.data.map_and_batch(
           lambda value: parse_record_fn(value, is_training),
           batch_size=batch_size,
-          num_parallel_batches=1))
+          num_parallel_batches=1,
+          drop_remainder=True))

   # Operations between the final prefetch and the get_next call to the iterator
   # will happen synchronously during run time. We prefetch here again to
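A minimal sketch (TF 1.x, toy dataset) of why drop_remainder helps: dropping the final partial batch makes the batch dimension static, which some graph optimizations and distribution strategies rely on.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)

# 10 elements batched by 4 would leave a partial batch of 2; dropping it
# makes every batch exactly 4 elements.
batched = dataset.apply(tf.contrib.data.map_and_batch(
    lambda x: x * 2, batch_size=4, num_parallel_batches=1, drop_remainder=True))

print(batched.output_shapes)  # (4,): the batch dimension is now static
```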
@@ -111,7 +112,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
   """
   def input_fn(is_training, data_dir, batch_size, *args, **kwargs):  # pylint: disable=unused-argument
     images = tf.zeros((batch_size, height, width, num_channels), tf.float32)
-    labels = tf.zeros((batch_size, num_classes), tf.int32)
+    labels = tf.zeros((batch_size), tf.int32)
     return tf.data.Dataset.from_tensors((images, labels)).repeat()

   return input_fn
@@ -227,8 +228,8 @@ def resnet_model_fn(features, labels, mode, model_class,
         })

   # Calculate loss, which includes softmax cross entropy and L2 regularization.
-  cross_entropy = tf.losses.softmax_cross_entropy(
-      logits=logits, onehot_labels=labels)
+  cross_entropy = tf.losses.sparse_softmax_cross_entropy(
+      logits=logits, labels=labels)

   # Create a tensor named cross_entropy for logging purposes.
   tf.identity(cross_entropy, name='cross_entropy')
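A minimal sketch (TF 1.x, toy logits) showing the two losses agree on the same data, so dropping the one-hot conversion changes only the input pipeline, not the training objective:

```python
import tensorflow as tf

logits = tf.constant([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])
sparse_labels = tf.constant([0, 1])
onehot_labels = tf.one_hot(sparse_labels, depth=3)

dense = tf.losses.softmax_cross_entropy(
    onehot_labels=onehot_labels, logits=logits)
sparse = tf.losses.sparse_softmax_cross_entropy(
    labels=sparse_labels, logits=logits)

with tf.Session() as sess:
  print(sess.run([dense, sparse]))  # identical values
```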
@@ -282,8 +283,7 @@ def resnet_model_fn(features, labels, mode, model_class,
     train_op = None

   if not tf.contrib.distribute.has_distribution_strategy():
-    accuracy = tf.metrics.accuracy(
-        tf.argmax(labels, axis=1), predictions['classes'])
+    accuracy = tf.metrics.accuracy(labels, predictions['classes'])
   else:
     # Metrics are currently not compatible with distribution strategies during
     # training. This does not affect the overall performance of the model.
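A minimal sketch (TF 1.x, toy values) of the simplified metric: with sparse labels, tf.metrics.accuracy compares them to predicted class ids directly, and the argmax over a one-hot tensor is no longer needed.

```python
import tensorflow as tf

labels = tf.constant([3, 1, 2])
classes = tf.constant([3, 0, 2])
accuracy, update_op = tf.metrics.accuracy(labels, classes)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # metric state lives in local vars
  sess.run(update_op)
  print(sess.run(accuracy))  # 0.666..., two of three predictions match
```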