Commit 5d7612c6 authored by Jianmin Chen's avatar Jianmin Chen Committed by Derek Murray
Browse files

Improve image processing (#45)

* improve image processing performance for Inception.
parent 84b58a60
......@@ -40,7 +40,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
......@@ -52,6 +51,8 @@ tf.app.flags.DEFINE_integer('image_size', 299,
tf.app.flags.DEFINE_integer('num_preprocess_threads', 4,
"""Number of preprocessing threads per tower. """
"""Please make this a multiple of 4.""")
tf.app.flags.DEFINE_integer('num_readers', 4,
"""Number of parallel readers during train.""")
# Images are preprocessed asynchronously using multiple threads specified by
# --num_preprocess_threads and the resulting processed images are stored in a
......@@ -97,7 +98,8 @@ def inputs(dataset, batch_size=None, num_preprocess_threads=None):
with tf.device('/cpu:0'):
images, labels = batch_inputs(
dataset, batch_size, train=False,
num_preprocess_threads=num_preprocess_threads)
num_preprocess_threads=num_preprocess_threads,
num_readers=1)
return images, labels
......@@ -130,7 +132,8 @@ def distorted_inputs(dataset, batch_size=None, num_preprocess_threads=None):
with tf.device('/cpu:0'):
images, labels = batch_inputs(
dataset, batch_size, train=True,
num_preprocess_threads=num_preprocess_threads)
num_preprocess_threads=num_preprocess_threads,
num_readers=FLAGS.num_readers)
return images, labels
......@@ -401,7 +404,8 @@ def parse_example_proto(example_serialized):
return features['image/encoded'], label, bbox, features['image/class/text']
def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None,
num_readers=1):
"""Construct batches of training or evaluation examples from the image dataset.
Args:
......@@ -410,6 +414,7 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
batch_size: integer
train: boolean
num_preprocess_threads: integer, total number of preprocessing threads
num_readers: integer, number of parallel readers
Returns:
images: 4-D float Tensor of a batch of images
......@@ -422,26 +427,28 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
data_files = dataset.data_files()
if data_files is None:
raise ValueError('No data files found for this dataset')
filename_queue = tf.train.string_input_producer(data_files, capacity=16)
# Create filename_queue
if train:
filename_queue = tf.train.string_input_producer(data_files,
shuffle=True,
capacity=16)
else:
filename_queue = tf.train.string_input_producer(data_files,
shuffle=False,
capacity=1)
if num_preprocess_threads is None:
num_preprocess_threads = FLAGS.num_preprocess_threads
if num_preprocess_threads % 4:
raise ValueError('Please make num_preprocess_threads a multiple '
'of 4 (%d % 4 != 0).', num_preprocess_threads)
# Create a subgraph with its own reader (but sharing the
# filename_queue) for each preprocessing thread.
images_and_labels = []
for thread_id in range(num_preprocess_threads):
reader = dataset.reader()
_, example_serialized = reader.read(filename_queue)
# Parse a serialized Example proto to extract the image and metadata.
image_buffer, label_index, bbox, _ = parse_example_proto(
example_serialized)
image = image_preprocessing(image_buffer, bbox, train, thread_id)
images_and_labels.append([image, label_index])
if num_readers is None:
num_readers = FLAGS.num_readers
if num_readers < 1:
raise ValueError('Please make num_readers at least 1')
# Approximate number of examples per shard.
examples_per_shard = 1024
......@@ -451,19 +458,43 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
# The default input_queue_memory_factor is 16 implying a shuffling queue
# size: examples_per_shard * 16 * 1MB = 17.6GB
min_queue_examples = examples_per_shard * FLAGS.input_queue_memory_factor
# Create a queue that produces the examples in batches after shuffling.
if train:
images, label_index_batch = tf.train.shuffle_batch_join(
images_and_labels,
batch_size=batch_size,
examples_queue = tf.RandomShuffleQueue(
capacity=min_queue_examples + 3 * batch_size,
min_after_dequeue=min_queue_examples)
min_after_dequeue=min_queue_examples,
dtypes=[tf.string])
else:
examples_queue = tf.FIFOQueue(
capacity=examples_per_shard + 3 * batch_size,
dtypes=[tf.string])
# Create multiple readers to populate the queue of examples.
if num_readers > 1:
enqueue_ops = []
for _ in range(num_readers):
reader = dataset.reader()
_, value = reader.read(filename_queue)
enqueue_ops.append(examples_queue.enqueue([value]))
tf.train.queue_runner.add_queue_runner(
tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops))
example_serialized = examples_queue.dequeue()
else:
images, label_index_batch = tf.train.batch_join(
images_and_labels,
batch_size=batch_size,
capacity=min_queue_examples + 3 * batch_size)
reader = dataset.reader()
_, example_serialized = reader.read(filename_queue)
images_and_labels = []
for thread_id in range(num_preprocess_threads):
# Parse a serialized Example proto to extract the image and metadata.
image_buffer, label_index, bbox, _ = parse_example_proto(
example_serialized)
image = image_preprocessing(image_buffer, bbox, train, thread_id)
images_and_labels.append([image, label_index])
images, label_index_batch = tf.train.batch_join(
images_and_labels,
batch_size=batch_size,
capacity=2 * num_preprocess_threads * batch_size)
# Reshape images into these desired dimensions.
height = FLAGS.image_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment