Commit 5d7612c6 authored by Jianmin Chen, committed by Derek Murray

Improve image processing (#45)

* improve image processing performance for Inception.
parent 84b58a60
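
In outline, the commit replaces the old design, in which each preprocessing thread owned its own file reader, with a pool of parallel readers that feed one shared queue of serialized Example protos, from which all preprocessing threads dequeue. Below is a minimal, self-contained sketch of that pattern, assuming the TF 0.x queue APIs this codebase targets; sketch_batch_inputs, data_files, parse_fn, and the queue capacities are illustrative stand-ins, not names from the commit:

import tensorflow as tf

def sketch_batch_inputs(data_files, parse_fn, batch_size,
                        train=True, num_readers=4, num_preprocess_threads=4):
  # Shuffle shards between epochs when training; read them in a fixed
  # order (capacity=1) for deterministic evaluation.
  filename_queue = tf.train.string_input_producer(
      data_files, shuffle=train, capacity=16 if train else 1)

  # One shared queue of serialized Example protos decouples the readers
  # from the preprocessing threads. Capacities here are illustrative.
  if train:
    examples_queue = tf.RandomShuffleQueue(
        capacity=10000, min_after_dequeue=1000, dtypes=[tf.string])
  else:
    examples_queue = tf.FIFOQueue(capacity=10000, dtypes=[tf.string])

  # A pool of readers feeds the shared queue; a QueueRunner keeps their
  # enqueue ops running on background threads.
  enqueue_ops = []
  for _ in range(num_readers):
    reader = tf.TFRecordReader()
    _, value = reader.read(filename_queue)
    enqueue_ops.append(examples_queue.enqueue([value]))
  tf.train.queue_runner.add_queue_runner(
      tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops))
  example_serialized = examples_queue.dequeue()

  # Each preprocessing thread parses its own dequeued record;
  # batch_join merges the per-thread results into one batch.
  images_and_labels = [parse_fn(example_serialized, thread_id)
                       for thread_id in range(num_preprocess_threads)]
  return tf.train.batch_join(
      images_and_labels, batch_size=batch_size,
      capacity=2 * num_preprocess_threads * batch_size)

At run time the reader threads are launched with tf.train.start_queue_runners(sess=sess), which is what actually starts filling the examples queue in the background.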
@@ -40,7 +40,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import tensorflow as tf
 
 FLAGS = tf.app.flags.FLAGS
@@ -52,6 +51,8 @@ tf.app.flags.DEFINE_integer('image_size', 299,
 tf.app.flags.DEFINE_integer('num_preprocess_threads', 4,
                             """Number of preprocessing threads per tower. """
                             """Please make this a multiple of 4.""")
+tf.app.flags.DEFINE_integer('num_readers', 4,
+                            """Number of parallel readers during train.""")
 
 # Images are preprocessed asynchronously using multiple threads specified by
 # --num_preprocess_threads and the resulting processed images are stored in a
@@ -97,7 +98,8 @@ def inputs(dataset, batch_size=None, num_preprocess_threads=None):
   with tf.device('/cpu:0'):
     images, labels = batch_inputs(
         dataset, batch_size, train=False,
-        num_preprocess_threads=num_preprocess_threads)
+        num_preprocess_threads=num_preprocess_threads,
+        num_readers=1)
   return images, labels
@@ -130,7 +132,8 @@ def distorted_inputs(dataset, batch_size=None, num_preprocess_threads=None):
   with tf.device('/cpu:0'):
     images, labels = batch_inputs(
         dataset, batch_size, train=True,
-        num_preprocess_threads=num_preprocess_threads)
+        num_preprocess_threads=num_preprocess_threads,
+        num_readers=FLAGS.num_readers)
   return images, labels
@@ -401,7 +404,8 @@ def parse_example_proto(example_serialized):
   return features['image/encoded'], label, bbox, features['image/class/text']
 
 
-def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
+def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None,
+                 num_readers=1):
   """Construct batches of training or evaluation examples from the image dataset.
 
   Args:
@@ -410,6 +414,7 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
     batch_size: integer
     train: boolean
     num_preprocess_threads: integer, total number of preprocessing threads
+    num_readers: integer, number of parallel readers
 
   Returns:
     images: 4-D float Tensor of a batch of images
@@ -422,26 +427,28 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
     data_files = dataset.data_files()
     if data_files is None:
       raise ValueError('No data files found for this dataset')
-    filename_queue = tf.train.string_input_producer(data_files, capacity=16)
+
+    # Create filename_queue
+    if train:
+      filename_queue = tf.train.string_input_producer(data_files,
+                                                      shuffle=True,
+                                                      capacity=16)
+    else:
+      filename_queue = tf.train.string_input_producer(data_files,
+                                                      shuffle=False,
+                                                      capacity=1)
 
     if num_preprocess_threads is None:
       num_preprocess_threads = FLAGS.num_preprocess_threads
 
     if num_preprocess_threads % 4:
       raise ValueError('Please make num_preprocess_threads a multiple '
                        'of 4 (%d % 4 != 0).', num_preprocess_threads)
-    # Create a subgraph with its own reader (but sharing the
-    # filename_queue) for each preprocessing thread.
-    images_and_labels = []
-    for thread_id in range(num_preprocess_threads):
-      reader = dataset.reader()
-      _, example_serialized = reader.read(filename_queue)
-      # Parse a serialized Example proto to extract the image and metadata.
-      image_buffer, label_index, bbox, _ = parse_example_proto(
-          example_serialized)
-      image = image_preprocessing(image_buffer, bbox, train, thread_id)
-      images_and_labels.append([image, label_index])
+    if num_readers is None:
+      num_readers = FLAGS.num_readers
+
+    if num_readers < 1:
+      raise ValueError('Please make num_readers at least 1')
 
     # Approximate number of examples per shard.
     examples_per_shard = 1024
@@ -451,19 +458,43 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
     # The default input_queue_memory_factor is 16 implying a shuffling queue
     # size: examples_per_shard * 16 * 1MB = 17.6GB
     min_queue_examples = examples_per_shard * FLAGS.input_queue_memory_factor
-    # Create a queue that produces the examples in batches after shuffling.
     if train:
-      images, label_index_batch = tf.train.shuffle_batch_join(
-          images_and_labels,
-          batch_size=batch_size,
+      examples_queue = tf.RandomShuffleQueue(
           capacity=min_queue_examples + 3 * batch_size,
-          min_after_dequeue=min_queue_examples)
+          min_after_dequeue=min_queue_examples,
+          dtypes=[tf.string])
+    else:
+      examples_queue = tf.FIFOQueue(
+          capacity=examples_per_shard + 3 * batch_size,
+          dtypes=[tf.string])
+
+    # Create multiple readers to populate the queue of examples.
+    if num_readers > 1:
+      enqueue_ops = []
+      for _ in range(num_readers):
+        reader = dataset.reader()
+        _, value = reader.read(filename_queue)
+        enqueue_ops.append(examples_queue.enqueue([value]))
+
+      tf.train.queue_runner.add_queue_runner(
+          tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops))
+      example_serialized = examples_queue.dequeue()
     else:
-      images, label_index_batch = tf.train.batch_join(
-          images_and_labels,
-          batch_size=batch_size,
-          capacity=min_queue_examples + 3 * batch_size)
+      reader = dataset.reader()
+      _, example_serialized = reader.read(filename_queue)
+
+    images_and_labels = []
+    for thread_id in range(num_preprocess_threads):
+      # Parse a serialized Example proto to extract the image and metadata.
+      image_buffer, label_index, bbox, _ = parse_example_proto(
+          example_serialized)
+      image = image_preprocessing(image_buffer, bbox, train, thread_id)
+      images_and_labels.append([image, label_index])
+
+    images, label_index_batch = tf.train.batch_join(
+        images_and_labels,
+        batch_size=batch_size,
+        capacity=2 * num_preprocess_threads * batch_size)
 
     # Reshape images into these desired dimensions.
     height = FLAGS.image_size
...
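
For scale, the shuffling-queue sizing referenced in the comment inside the last hunk works out roughly as follows. This is a back-of-the-envelope sketch: the ~1MB-per-example figure is the code comment's own assumption, and batch_size=32 is a hypothetical value:

examples_per_shard = 1024
input_queue_memory_factor = 16  # the flag's default
batch_size = 32                 # hypothetical
min_queue_examples = examples_per_shard * input_queue_memory_factor  # 16384
train_capacity = min_queue_examples + 3 * batch_size                 # 16480
eval_capacity = examples_per_shard + 3 * batch_size                  # 1120
# At roughly 1MB per serialized example, the training queue buffers on
# the order of 16GB, while the eval queue stays near 1GB -- which is
# why the commit gives evaluation its own, much smaller FIFOQueue.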