Unverified Commit 10805b06 authored by Jon Shlens's avatar Jon Shlens Committed by GitHub
Browse files

Merge pull request #3153 from StevenHickson/master

Updated cifar10_input.py to make tensorboard graph more meaningful
parents 17f7d552 435ee4da
...@@ -157,44 +157,45 @@ def distorted_inputs(data_dir, batch_size): ...@@ -157,44 +157,45 @@ def distorted_inputs(data_dir, batch_size):
# Create a queue that produces the filenames to read. # Create a queue that produces the filenames to read.
filename_queue = tf.train.string_input_producer(filenames) filename_queue = tf.train.string_input_producer(filenames)
# Read examples from files in the filename queue. with tf.name_scope('data_augmentation'):
read_input = read_cifar10(filename_queue) # Read examples from files in the filename queue.
reshaped_image = tf.cast(read_input.uint8image, tf.float32) read_input = read_cifar10(filename_queue)
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
height = IMAGE_SIZE
width = IMAGE_SIZE height = IMAGE_SIZE
width = IMAGE_SIZE
# Image processing for training the network. Note the many random
# distortions applied to the image. # Image processing for training the network. Note the many random
# distortions applied to the image.
# Randomly crop a [height, width] section of the image.
distorted_image = tf.random_crop(reshaped_image, [height, width, 3]) # Randomly crop a [height, width] section of the image.
distorted_image = tf.random_crop(reshaped_image, [height, width, 3])
# Randomly flip the image horizontally.
distorted_image = tf.image.random_flip_left_right(distorted_image) # Randomly flip the image horizontally.
distorted_image = tf.image.random_flip_left_right(distorted_image)
# Because these operations are not commutative, consider randomizing
# the order their operation. # Because these operations are not commutative, consider randomizing
# NOTE: since per_image_standardization zeros the mean and makes # the order their operation.
# the stddev unit, this likely has no effect see tensorflow#1458. # NOTE: since per_image_standardization zeros the mean and makes
distorted_image = tf.image.random_brightness(distorted_image, # the stddev unit, this likely has no effect see tensorflow#1458.
max_delta=63) distorted_image = tf.image.random_brightness(distorted_image,
distorted_image = tf.image.random_contrast(distorted_image, max_delta=63)
lower=0.2, upper=1.8) distorted_image = tf.image.random_contrast(distorted_image,
lower=0.2, upper=1.8)
# Subtract off the mean and divide by the variance of the pixels.
float_image = tf.image.per_image_standardization(distorted_image) # Subtract off the mean and divide by the variance of the pixels.
float_image = tf.image.per_image_standardization(distorted_image)
# Set the shapes of tensors.
float_image.set_shape([height, width, 3]) # Set the shapes of tensors.
read_input.label.set_shape([1]) float_image.set_shape([height, width, 3])
read_input.label.set_shape([1])
# Ensure that the random shuffling has good mixing properties.
min_fraction_of_examples_in_queue = 0.4 # Ensure that the random shuffling has good mixing properties.
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * min_fraction_of_examples_in_queue = 0.4
min_fraction_of_examples_in_queue) min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
print ('Filling queue with %d CIFAR images before starting to train. ' min_fraction_of_examples_in_queue)
'This will take a few minutes.' % min_queue_examples) print ('Filling queue with %d CIFAR images before starting to train. '
'This will take a few minutes.' % min_queue_examples)
# Generate a batch of images and labels by building up a queue of examples. # Generate a batch of images and labels by building up a queue of examples.
return _generate_image_and_label_batch(float_image, read_input.label, return _generate_image_and_label_batch(float_image, read_input.label,
...@@ -226,32 +227,33 @@ def inputs(eval_data, data_dir, batch_size): ...@@ -226,32 +227,33 @@ def inputs(eval_data, data_dir, batch_size):
if not tf.gfile.Exists(f): if not tf.gfile.Exists(f):
raise ValueError('Failed to find file: ' + f) raise ValueError('Failed to find file: ' + f)
# Create a queue that produces the filenames to read. with tf.name_scope('input'):
filename_queue = tf.train.string_input_producer(filenames) # Create a queue that produces the filenames to read.
filename_queue = tf.train.string_input_producer(filenames)
# Read examples from files in the filename queue. # Read examples from files in the filename queue.
read_input = read_cifar10(filename_queue) read_input = read_cifar10(filename_queue)
reshaped_image = tf.cast(read_input.uint8image, tf.float32) reshaped_image = tf.cast(read_input.uint8image, tf.float32)
height = IMAGE_SIZE height = IMAGE_SIZE
width = IMAGE_SIZE width = IMAGE_SIZE
# Image processing for evaluation. # Image processing for evaluation.
# Crop the central [height, width] of the image. # Crop the central [height, width] of the image.
resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
height, width) height, width)
# Subtract off the mean and divide by the variance of the pixels. # Subtract off the mean and divide by the variance of the pixels.
float_image = tf.image.per_image_standardization(resized_image) float_image = tf.image.per_image_standardization(resized_image)
# Set the shapes of tensors. # Set the shapes of tensors.
float_image.set_shape([height, width, 3]) float_image.set_shape([height, width, 3])
read_input.label.set_shape([1]) read_input.label.set_shape([1])
# Ensure that the random shuffling has good mixing properties. # Ensure that the random shuffling has good mixing properties.
min_fraction_of_examples_in_queue = 0.4 min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(num_examples_per_epoch * min_queue_examples = int(num_examples_per_epoch *
min_fraction_of_examples_in_queue) min_fraction_of_examples_in_queue)
# Generate a batch of images and labels by building up a queue of examples. # Generate a batch of images and labels by building up a queue of examples.
return _generate_image_and_label_batch(float_image, read_input.label, return _generate_image_and_label_batch(float_image, read_input.label,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment