Commit 56b5d037 authored by pierrot0, committed by Toby Boyd

TF cifar10 cnn tutorial: use tensorflow-datasets to load the data. (#5906)

* TF cifar10 cnn tutorial: use tensorflow-datasets to load the data.

* Load cifar10 in memory.

* Fix imports

* More import fixes
parent d11aa330
cifar10.py

@@ -35,12 +35,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import os
 import re
-import sys
-import tarfile

-from six.moves import urllib
 import tensorflow as tf

 import cifar10_input
@@ -50,8 +46,6 @@ FLAGS = tf.app.flags.FLAGS
 # Basic model parameters.
 tf.app.flags.DEFINE_integer('batch_size', 128,
                             """Number of images to process in a batch.""")
-tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
-                           """Path to the CIFAR-10 data directory.""")
 tf.app.flags.DEFINE_boolean('use_fp16', False,
                             """Train the model using fp16.""")
@@ -73,8 +67,6 @@ INITIAL_LEARNING_RATE = 0.1  # Initial learning rate.
 # names of the summaries when visualizing a model.
 TOWER_NAME = 'tower'

-DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
-

 def _activation_summary(x):
   """Helper to create summaries for activations.
@@ -91,8 +83,7 @@ def _activation_summary(x):
   # session. This helps the clarity of presentation on tensorboard.
   tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
   tf.summary.histogram(tensor_name + '/activations', x)
-  tf.summary.scalar(tensor_name + '/sparsity',
-                    tf.nn.zero_fraction(x))
+  tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))


 def _variable_on_cpu(name, shape, initializer):
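For reference, the joined line still records the same two summaries: an activation histogram and a scalar "sparsity" value, where tf.nn.zero_fraction measures the fraction of zero entries (i.e. inactive ReLU units). A tiny illustration, not part of this commit:

import tensorflow as tf

x = tf.constant([[0.0, 1.0], [0.0, 2.0]])
with tf.Session() as sess:
  # Two of the four entries are zero, so the reported sparsity is 0.5.
  print(sess.run(tf.nn.zero_fraction(x)))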
@@ -145,15 +136,8 @@ def distorted_inputs():
   Returns:
     images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
     labels: Labels. 1D tensor of [batch_size] size.
-
-  Raises:
-    ValueError: If no data_dir
   """
-  if not FLAGS.data_dir:
-    raise ValueError('Please supply a data_dir')
-  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
-  images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
-                                                  batch_size=FLAGS.batch_size)
+  images, labels = cifar10_input.distorted_inputs(batch_size=FLAGS.batch_size)
   if FLAGS.use_fp16:
     images = tf.cast(images, tf.float16)
     labels = tf.cast(labels, tf.float16)
@@ -169,15 +153,8 @@ def inputs(eval_data):
   Returns:
     images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
     labels: Labels. 1D tensor of [batch_size] size.
-
-  Raises:
-    ValueError: If no data_dir
   """
-  if not FLAGS.data_dir:
-    raise ValueError('Please supply a data_dir')
-  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
   images, labels = cifar10_input.inputs(eval_data=eval_data,
-                                        data_dir=data_dir,
                                         batch_size=FLAGS.batch_size)
   if FLAGS.use_fp16:
     images = tf.cast(images, tf.float16)
@@ -240,7 +217,7 @@ def inference(images):
   # local3
   with tf.variable_scope('local3') as scope:
     # Move everything into depth so we can perform a single matrix multiply.
-    reshape = tf.reshape(pool2, [images.get_shape().as_list()[0], -1])
+    reshape = tf.keras.layers.Flatten()(pool2)
     dim = reshape.get_shape()[1].value
     weights = _variable_with_weight_decay('weights', shape=[dim, 384],
                                           stddev=0.04, wd=0.004)
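The old reshape needed a statically known batch size, which is no longer available once the input comes from a tf.data iterator; tf.keras.layers.Flatten collapses everything except the batch dimension while keeping the feature size static, so `dim` can still be read from the shape. A minimal sketch, not part of the commit; the [None, 6, 6, 64] shape is an assumed stand-in for the pool2 activation:

import tensorflow as tf

# Assumed stand-in for pool2: unknown batch size, static spatial dims.
pool2 = tf.placeholder(tf.float32, shape=[None, 6, 6, 64])

# Flatten keeps the (unknown) batch dimension and folds the rest together.
reshape = tf.keras.layers.Flatten()(pool2)
print(reshape.get_shape())           # (?, 2304)
print(reshape.get_shape()[1].value)  # 2304, still usable as 'dim'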
@@ -374,24 +351,3 @@ def train(total_loss, global_step):
   variables_averages_op = variable_averages.apply(tf.trainable_variables())

   return variables_averages_op
-
-
-def maybe_download_and_extract():
-  """Download and extract the tarball from Alex's website."""
-  dest_directory = FLAGS.data_dir
-  if not os.path.exists(dest_directory):
-    os.makedirs(dest_directory)
-  filename = DATA_URL.split('/')[-1]
-  filepath = os.path.join(dest_directory, filename)
-  if not os.path.exists(filepath):
-    def _progress(count, block_size, total_size):
-      sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
-          float(count * block_size) / float(total_size) * 100.0))
-      sys.stdout.flush()
-    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
-    print()
-    statinfo = os.stat(filepath)
-    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
-  extracted_dir_path = os.path.join(dest_directory, 'cifar-10-batches-bin')
-  if not os.path.exists(extracted_dir_path):
-    tarfile.open(filepath, 'r:gz').extractall(dest_directory)
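With the manual download helper removed, tensorflow_datasets takes over fetching, extracting and caching CIFAR-10 (by default under ~/tensorflow_datasets). A hedged sketch of the equivalent one-time preparation step, for illustration only and not part of the commit:

import tensorflow_datasets as tfds

# Download and prepare CIFAR-10 once; later tfds.load(name='cifar10', ...)
# calls reuse the cached copy instead of re-downloading the archive.
builder = tfds.builder('cifar10')
builder.download_and_prepare()
print(builder.info.splits['train'].num_examples)  # 50000
print(builder.info.splits['test'].num_examples)   # 10000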
cifar10_eval.py

@@ -51,12 +51,12 @@ tf.app.flags.DEFINE_string('eval_data', 'test',
                            """Either 'test' or 'train_eval'.""")
 tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train',
                            """Directory where to read model checkpoints.""")
-tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5,
+tf.app.flags.DEFINE_integer('eval_interval_secs', 5,
                             """How often to run the eval.""")
-tf.app.flags.DEFINE_integer('num_examples', 10000,
+tf.app.flags.DEFINE_integer('num_examples', 1000,
                             """Number of examples to run.""")
 tf.app.flags.DEFINE_boolean('run_once', False,
                             """Whether to run eval only once.""")


 def eval_once(saver, summary_writer, top_k_op, summary_op):
@@ -89,7 +89,7 @@ def eval_once(saver, summary_writer, top_k_op, summary_op):
       threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                        start=True))

-      num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
+      num_iter = int(math.ceil(float(FLAGS.num_examples) / FLAGS.batch_size))
       true_count = 0  # Counts the number of correct predictions.
       total_sample_count = num_iter * FLAGS.batch_size
       step = 0
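The added float() cast matters under Python 2, where dividing two ints floors the result, so the ceiling has nothing left to round up and the evaluation can under-count batches. A quick illustration (Python 2 division emulated with //), not part of the commit:

import math

num_examples, batch_size = 1000, 128

# Python 2 style integer division: 1000 / 128 == 7 and ceil(7) == 7,
# so the final partial batch would never be evaluated.
print(int(math.ceil(num_examples // batch_size)))        # 7

# With the float cast the division is exact and ceil rounds up.
print(int(math.ceil(float(num_examples) / batch_size)))  # 8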
@@ -146,7 +146,6 @@ def evaluate():


 def main(argv=None):  # pylint: disable=unused-argument
-  cifar10.maybe_download_and_extract()
   if tf.gfile.Exists(FLAGS.eval_dir):
     tf.gfile.DeleteRecursively(FLAGS.eval_dir)
   tf.gfile.MakeDirs(FLAGS.eval_dir)
cifar10_input.py

@@ -19,10 +19,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import os
-
-from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
+import tensorflow_datasets as tfds

 # Process images of this size. Note that this differs from the original CIFAR
 # image size of 32 x 32. If one alters this number, then the entire model
@@ -35,227 +33,74 @@ NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
 NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000


-def read_cifar10(filename_queue):
-  """Reads and parses examples from CIFAR10 data files.
-
-  Recommendation: if you want N-way read parallelism, call this function
-  N times. This will give you N independent Readers reading different
-  files & positions within those files, which will give better mixing of
-  examples.
-
-  Args:
-    filename_queue: A queue of strings with the filenames to read from.
-
-  Returns:
-    An object representing a single example, with the following fields:
-      height: number of rows in the result (32)
-      width: number of columns in the result (32)
-      depth: number of color channels in the result (3)
-      key: a scalar string Tensor describing the filename & record number
-        for this example.
-      label: an int32 Tensor with the label in the range 0..9.
-      uint8image: a [height, width, depth] uint8 Tensor with the image data
-  """
-
-  class CIFAR10Record(object):
-    pass
-  result = CIFAR10Record()
-
-  # Dimensions of the images in the CIFAR-10 dataset.
-  # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
-  # input format.
-  label_bytes = 1  # 2 for CIFAR-100
-  result.height = 32
-  result.width = 32
-  result.depth = 3
-  image_bytes = result.height * result.width * result.depth
-  # Every record consists of a label followed by the image, with a
-  # fixed number of bytes for each.
-  record_bytes = label_bytes + image_bytes
-
-  # Read a record, getting filenames from the filename_queue.  No
-  # header or footer in the CIFAR-10 format, so we leave header_bytes
-  # and footer_bytes at their default of 0.
-  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
-  result.key, value = reader.read(filename_queue)
-
-  # Convert from a string to a vector of uint8 that is record_bytes long.
-  record_bytes = tf.decode_raw(value, tf.uint8)
-
-  # The first bytes represent the label, which we convert from uint8->int32.
-  result.label = tf.cast(
-      tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
-
-  # The remaining bytes after the label represent the image, which we reshape
-  # from [depth * height * width] to [depth, height, width].
-  depth_major = tf.reshape(
-      tf.strided_slice(record_bytes, [label_bytes],
-                       [label_bytes + image_bytes]),
-      [result.depth, result.height, result.width])
-  # Convert from [depth, height, width] to [height, width, depth].
-  result.uint8image = tf.transpose(depth_major, [1, 2, 0])
-
-  return result
-
-
-def _generate_image_and_label_batch(image, label, min_queue_examples,
-                                    batch_size, shuffle):
-  """Construct a queued batch of images and labels.
-
-  Args:
-    image: 3-D Tensor of [height, width, 3] of type.float32.
-    label: 1-D Tensor of type.int32
-    min_queue_examples: int32, minimum number of samples to retain
-      in the queue that provides of batches of examples.
-    batch_size: Number of images per batch.
-    shuffle: boolean indicating whether to use a shuffling queue.
-
-  Returns:
-    images: Images. 4D tensor of [batch_size, height, width, 3] size.
-    labels: Labels. 1D tensor of [batch_size] size.
-  """
-  # Create a queue that shuffles the examples, and then
-  # read 'batch_size' images + labels from the example queue.
-  num_preprocess_threads = 16
-  if shuffle:
-    images, label_batch = tf.train.shuffle_batch(
-        [image, label],
-        batch_size=batch_size,
-        num_threads=num_preprocess_threads,
-        capacity=min_queue_examples + 3 * batch_size,
-        min_after_dequeue=min_queue_examples)
-  else:
-    images, label_batch = tf.train.batch(
-        [image, label],
-        batch_size=batch_size,
-        num_threads=num_preprocess_threads,
-        capacity=min_queue_examples + 3 * batch_size)
-
-  # Display the training images in the visualizer.
-  tf.summary.image('images', images)
-
-  return images, tf.reshape(label_batch, [batch_size])
+def _get_images_labels(batch_size, split, distords=False):
+  """Returns Dataset for given split."""
+  dataset = tfds.load(name='cifar10', split=split)
+  scope = 'data_augmentation' if distords else 'input'
+  with tf.name_scope(scope):
+    dataset = dataset.map(DataPreprocessor(distords), num_parallel_calls=10)
+  # Dataset is small enough to be fully loaded on memory:
+  dataset = dataset.prefetch(-1)
+  dataset = dataset.repeat().batch(batch_size)
+  iterator = dataset.make_one_shot_iterator()
+  images_labels = iterator.get_next()
+  images, labels = images_labels['input'], images_labels['target']
+  tf.summary.image('images', images)
+  return images, labels
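For orientation: tfds.load(name='cifar10', split=...) returns a tf.data.Dataset of feature dictionaries with 'image' and 'label' keys, which the function above maps into the {'input': ..., 'target': ...} pairs the model consumes. A minimal graph-mode sketch of pulling one raw batch out of TFDS, for illustration only (the batch size of 128 matches the tutorial's flag):

import tensorflow as tf
import tensorflow_datasets as tfds

# Each element is a dict; for cifar10 the features are
# 'image' (uint8, 32x32x3) and 'label' (int64 scalar).
dataset = tfds.load(name='cifar10', split=tfds.Split.TRAIN)
dataset = dataset.repeat().batch(128)
features = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  batch = sess.run(features)
  print(batch['image'].shape, batch['image'].dtype)  # (128, 32, 32, 3) uint8
  print(batch['label'].shape, batch['label'].dtype)  # (128,) int64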
+class DataPreprocessor(object):
+  """Applies transformations to dataset record."""
+
+  def __init__(self, distords):
+    self._distords = distords
+
+  def __call__(self, record):
+    """Process img for training or eval."""
+    img = record['image']
+    img = tf.cast(img, tf.float32)
+    if self._distords:  # training
+      # Randomly crop a [height, width] section of the image.
+      img = tf.random_crop(img, [IMAGE_SIZE, IMAGE_SIZE, 3])
+      # Randomly flip the image horizontally.
+      img = tf.image.random_flip_left_right(img)
+      # Because these operations are not commutative, consider randomizing
+      # the order their operation.
+      # NOTE: since per_image_standardization zeros the mean and makes
+      # the stddev unit, this likely has no effect see tensorflow#1458.
+      img = tf.image.random_brightness(img, max_delta=63)
+      img = tf.image.random_contrast(img, lower=0.2, upper=1.8)
+    else:  # Image processing for evaluation.
+      # Crop the central [height, width] of the image.
+      img = tf.image.resize_image_with_crop_or_pad(img, IMAGE_SIZE, IMAGE_SIZE)
+    # Subtract off the mean and divide by the variance of the pixels.
+    img = tf.image.per_image_standardization(img)
+    return dict(input=img, target=record['label'])
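DataPreprocessor is the map function of the new pipeline: it turns a TFDS record dict into the {'input', 'target'} dict used downstream, with the distortion branch reserved for training. A small sketch applying it to a dummy record, for illustration only; the (24, 24, 3) output shape assumes the tutorial's 24-pixel IMAGE_SIZE crop:

import numpy as np
import tensorflow as tf
import cifar10_input

# Hypothetical stand-in for one TFDS CIFAR-10 record: a 32x32x3 uint8 image
# plus an integer label, wrapped in a single-element tf.data.Dataset.
record = {'image': np.zeros((32, 32, 3), np.uint8), 'label': np.int64(3)}
dataset = tf.data.Dataset.from_tensors(record)

# The eval path (distords=False) center-crops to IMAGE_SIZE and standardizes.
dataset = dataset.map(cifar10_input.DataPreprocessor(distords=False))
example = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  out = sess.run(example)
  print(out['input'].shape, out['target'])  # e.g. (24, 24, 3) 3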
-def distorted_inputs(data_dir, batch_size):
+def distorted_inputs(batch_size):
   """Construct distorted input for CIFAR training using the Reader ops.

   Args:
-    data_dir: Path to the CIFAR-10 data directory.
     batch_size: Number of images per batch.

   Returns:
     images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
     labels: Labels. 1D tensor of [batch_size] size.
   """
-  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
-               for i in xrange(1, 6)]
-  for f in filenames:
-    if not tf.gfile.Exists(f):
-      raise ValueError('Failed to find file: ' + f)
-
-  # Create a queue that produces the filenames to read.
-  filename_queue = tf.train.string_input_producer(filenames)
-
-  with tf.name_scope('data_augmentation'):
-    # Read examples from files in the filename queue.
-    read_input = read_cifar10(filename_queue)
-    reshaped_image = tf.cast(read_input.uint8image, tf.float32)
-
-    height = IMAGE_SIZE
-    width = IMAGE_SIZE
-
-    # Image processing for training the network. Note the many random
-    # distortions applied to the image.
-
-    # Randomly crop a [height, width] section of the image.
-    distorted_image = tf.random_crop(reshaped_image, [height, width, 3])
-
-    # Randomly flip the image horizontally.
-    distorted_image = tf.image.random_flip_left_right(distorted_image)
-
-    # Because these operations are not commutative, consider randomizing
-    # the order their operation.
-    # NOTE: since per_image_standardization zeros the mean and makes
-    # the stddev unit, this likely has no effect see tensorflow#1458.
-    distorted_image = tf.image.random_brightness(distorted_image,
-                                                 max_delta=63)
-    distorted_image = tf.image.random_contrast(distorted_image,
-                                               lower=0.2, upper=1.8)
-
-    # Subtract off the mean and divide by the variance of the pixels.
-    float_image = tf.image.per_image_standardization(distorted_image)
-
-    # Set the shapes of tensors.
-    float_image.set_shape([height, width, 3])
-    read_input.label.set_shape([1])
-
-    # Ensure that the random shuffling has good mixing properties.
-    min_fraction_of_examples_in_queue = 0.4
-    min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
-                             min_fraction_of_examples_in_queue)
-    print ('Filling queue with %d CIFAR images before starting to train. '
-           'This will take a few minutes.' % min_queue_examples)
-
-  # Generate a batch of images and labels by building up a queue of examples.
-  return _generate_image_and_label_batch(float_image, read_input.label,
-                                         min_queue_examples, batch_size,
-                                         shuffle=True)
+  return _get_images_labels(batch_size, tfds.Split.TRAIN, distords=True)


-def inputs(eval_data, data_dir, batch_size):
+def inputs(eval_data, batch_size):
   """Construct input for CIFAR evaluation using the Reader ops.

   Args:
     eval_data: bool, indicating if one should use the train or eval data set.
-    data_dir: Path to the CIFAR-10 data directory.
     batch_size: Number of images per batch.

   Returns:
     images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
     labels: Labels. 1D tensor of [batch_size] size.
   """
-  if not eval_data:
-    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
-                 for i in xrange(1, 6)]
-    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
-  else:
-    filenames = [os.path.join(data_dir, 'test_batch.bin')]
-    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
-
-  for f in filenames:
-    if not tf.gfile.Exists(f):
-      raise ValueError('Failed to find file: ' + f)
-
-  with tf.name_scope('input'):
-    # Create a queue that produces the filenames to read.
-    filename_queue = tf.train.string_input_producer(filenames)
-
-    # Read examples from files in the filename queue.
-    read_input = read_cifar10(filename_queue)
-    reshaped_image = tf.cast(read_input.uint8image, tf.float32)
-
-    height = IMAGE_SIZE
-    width = IMAGE_SIZE
-
-    # Image processing for evaluation.
-    # Crop the central [height, width] of the image.
-    resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
-                                                           height, width)
-
-    # Subtract off the mean and divide by the variance of the pixels.
-    float_image = tf.image.per_image_standardization(resized_image)
-
-    # Set the shapes of tensors.
-    float_image.set_shape([height, width, 3])
-    read_input.label.set_shape([1])
-
-    # Ensure that the random shuffling has good mixing properties.
-    min_fraction_of_examples_in_queue = 0.4
-    min_queue_examples = int(num_examples_per_epoch *
-                             min_fraction_of_examples_in_queue)
-
-  # Generate a batch of images and labels by building up a queue of examples.
-  return _generate_image_and_label_batch(float_image, read_input.label,
-                                         min_queue_examples, batch_size,
-                                         shuffle=False)
+  split = tfds.Split.TEST if eval_data == 'test' else tfds.Split.TRAIN
+  return _get_images_labels(batch_size, split)
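One behavioural detail worth noting: the new inputs() compares eval_data against the string 'test' (the same values the --eval_data flag documents, 'test' or 'train_eval'), even though the unchanged docstring still describes it as a bool; any other value falls back to the training split. A short sketch of how the two public entry points are called after this change, for illustration only:

import cifar10_input

# Training pipeline: TFDS train split with the random distortions applied.
images, labels = cifar10_input.distorted_inputs(batch_size=128)

# Evaluation pipeline: 'test' maps to tfds.Split.TEST; anything else
# (e.g. 'train_eval') evaluates on the training split instead.
eval_images, eval_labels = cifar10_input.inputs(eval_data='test', batch_size=128)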
cifar10_train.py

@@ -116,7 +116,6 @@ def train():


 def main(argv=None):  # pylint: disable=unused-argument
-  cifar10.maybe_download_and_extract()
   if tf.gfile.Exists(FLAGS.train_dir):
     tf.gfile.DeleteRecursively(FLAGS.train_dir)
   tf.gfile.MakeDirs(FLAGS.train_dir)