Commit 1f6b3d7d authored by Kathy Wu

Changing tf.contrib.data to tf.data for release of tf 1.4

parent 4cfa0d3b
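The diffs below all make the same two kinds of change: the Dataset classes move from the tf.contrib.data namespace to tf.data, and map() loses its num_threads / output_buffer_size arguments, whose TF 1.4 counterparts are num_parallel_calls and a separate prefetch() transformation. A minimal before/after sketch of the pattern (illustrative only; the names and sizes are placeholders, and some hunks in this commit simply drop the buffering rather than adding prefetch):

import tensorflow as tf

# Before (TF 1.3, tf.contrib.data) -- shown as comments only:
#   dataset = tf.contrib.data.TFRecordDataset(filenames)
#   dataset = dataset.map(parse_fn, num_threads=4, output_buffer_size=2 * batch_size)

# After (TF 1.4, tf.data):
def make_dataset(filenames, parse_fn, batch_size):
  dataset = tf.data.TFRecordDataset(filenames)
  dataset = dataset.map(parse_fn, num_parallel_calls=4)  # replaces num_threads
  dataset = dataset.prefetch(2 * batch_size)             # replaces output_buffer_size
  return dataset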
@@ -53,7 +53,7 @@ _NUM_IMAGES = {
 def input_fn(is_training, filename, batch_size=1, num_epochs=1):
-  """A simple input_fn using the contrib.data input pipeline."""
+  """A simple input_fn using the tf.data input pipeline."""
 
   def example_parser(serialized_example):
     """Parses a single tf.Example into image and label tensors."""

@@ -71,8 +71,12 @@ def input_fn(is_training, filename, batch_size=1, num_epochs=1):
     label = tf.cast(features['label'], tf.int32)
     return image, tf.one_hot(label, 10)
 
-  dataset = tf.contrib.data.TFRecordDataset([filename])
+  dataset = tf.data.TFRecordDataset([filename])
 
+  # Parse each example in the dataset
+  dataset = dataset.map(example_parser)
+
+  # Apply dataset transformations
   if is_training:
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes have better performance. Because MNIST is
...
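The MNIST hunk above is cut off inside input_fn. For orientation only, here is a self-contained sketch of how such a tf.data (TF 1.4) input_fn typically finishes, with shuffle, repeat, batch and a one-shot iterator; the feature keys, buffer size and batch size are assumptions, not lines from this commit:

import tensorflow as tf

def mnist_style_input_fn(filename, is_training, batch_size=100, num_epochs=1):
  """Sketch of a complete tf.data input_fn in the style of the hunk above."""
  def example_parser(serialized_example):
    features = tf.parse_single_example(
        serialized_example,
        features={'image_raw': tf.FixedLenFeature([], tf.string),
                  'label': tf.FixedLenFeature([], tf.int64)})
    image = tf.cast(tf.decode_raw(features['image_raw'], tf.uint8), tf.float32)
    label = tf.cast(features['label'], tf.int32)
    return image, tf.one_hot(label, 10)

  dataset = tf.data.TFRecordDataset([filename])
  dataset = dataset.map(example_parser)
  if is_training:
    # A buffer covering the whole (small) training set gives a full shuffle.
    dataset = dataset.shuffle(buffer_size=50000)  # assumed buffer size
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  images, labels = dataset.make_one_shot_iterator().get_next()
  return images, labels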
@@ -77,7 +77,7 @@ _SHUFFLE_BUFFER = 20000
 def record_dataset(filenames):
   """Returns an input pipeline Dataset from `filenames`."""
   record_bytes = _HEIGHT * _WIDTH * _DEPTH + 1
-  return tf.contrib.data.FixedLengthRecordDataset(filenames, record_bytes)
+  return tf.data.FixedLengthRecordDataset(filenames, record_bytes)
 
 def get_filenames(is_training, data_dir):

@@ -138,7 +138,7 @@ def train_preprocess_fn(image, label):
 def input_fn(is_training, data_dir, batch_size, num_epochs=1):
-  """Input_fn using the contrib.data input pipeline for CIFAR-10 dataset.
+  """Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
 
   Args:
     is_training: A boolean denoting whether the input is for training.

@@ -148,13 +148,11 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
     A tuple of images and labels.
   """
   dataset = record_dataset(get_filenames(is_training, data_dir))
-  dataset = dataset.map(dataset_parser, num_threads=1,
-                        output_buffer_size=2 * batch_size)
+  dataset = dataset.map(dataset_parser)
 
   # For training, preprocess the image and shuffle.
   if is_training:
-    dataset = dataset.map(train_preprocess_fn, num_threads=1,
-                          output_buffer_size=2 * batch_size)
+    dataset = dataset.map(train_preprocess_fn)
 
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes have better performance.

@@ -162,9 +160,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   # Subtract off the mean and divide by the variance of the pixels.
   dataset = dataset.map(
-      lambda image, label: (tf.image.per_image_standardization(image), label),
-      num_threads=1,
-      output_buffer_size=2 * batch_size)
+      lambda image, label: (tf.image.per_image_standardization(image), label))
 
   # We call repeat after shuffling, rather than before, to prevent separate
   # epochs from blending together.
...
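The comment about calling repeat after shuffling, rather than before, is the one ordering constraint these pipelines rely on: shuffle(...).repeat(...) exhausts and reshuffles each epoch separately, while repeat(...).shuffle(...) lets the shuffle buffer mix records from adjacent epochs. A toy sketch of the two orderings (the dataset and sizes are made up for illustration):

import tensorflow as tf

def toy_pipeline(shuffle_then_repeat):
  ds = tf.data.Dataset.range(4)  # pretend 0..3 is one epoch of records
  if shuffle_then_repeat:
    # Epochs stay separate; the buffer is reshuffled on every pass.
    ds = ds.shuffle(buffer_size=4).repeat(2)
  else:
    # The buffer can now hold records from both epochs at once, blending them.
    ds = ds.repeat(2).shuffle(buffer_size=4)
  return ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  next_element = toy_pipeline(shuffle_then_repeat=True)
  print([sess.run(next_element) for _ in range(8)])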
@@ -134,17 +134,16 @@ def dataset_parser(value, is_training):
 def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   """Input function which provides batches for train or eval."""
-  dataset = tf.contrib.data.Dataset.from_tensor_slices(
+  dataset = tf.data.Dataset.from_tensor_slices(
       filenames(is_training, data_dir))
 
   if is_training:
     dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)
 
-  dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)
+  dataset = dataset.flat_map(tf.data.TFRecordDataset)
   dataset = dataset.map(lambda value: dataset_parser(value, is_training),
-                        num_threads=5,
-                        output_buffer_size=batch_size)
+                        num_parallel_calls=5)
 
   if is_training:
     # When choosing shuffle buffer sizes, larger sizes result in better
...
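In the input_fn above, shuffling happens at two levels: the list of input files is shuffled before flat_map(tf.data.TFRecordDataset) concatenates their records, and the records themselves are shuffled further down, outside the hunk. A sketch of that file-level pattern; the helper name and parameters are placeholders, not code from this commit:

import tensorflow as tf

def sharded_tfrecord_dataset(file_list, is_training, parser, batch_size):
  dataset = tf.data.Dataset.from_tensor_slices(file_list)
  if is_training:
    # Shuffle the order in which the shard files are read...
    dataset = dataset.shuffle(buffer_size=len(file_list))
  # ...then read each shard's records, concatenated in that (shuffled) order.
  dataset = dataset.flat_map(tf.data.TFRecordDataset)
  # Records inside a shard still come out sequentially, which is why a
  # record-level shuffle is also applied later in the real input_fn.
  dataset = dataset.map(parser, num_parallel_calls=5)
  return dataset.batch(batch_size)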
@@ -178,8 +178,8 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
     return features, tf.equal(labels, '>50K')
 
   # Extract lines from input files using the Dataset API.
-  dataset = tf.contrib.data.TextLineDataset(data_file)
-  dataset = dataset.map(parse_csv, num_threads=5)
+  dataset = tf.data.TextLineDataset(data_file)
+  dataset = dataset.map(parse_csv, num_parallel_calls=5)
 
   if shuffle:
     dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)

@@ -193,7 +193,6 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
   features, labels = iterator.get_next()
   return features, labels
 
 def main(unused_argv):
   # Clean up the model directory if present
   shutil.rmtree(FLAGS.model_dir, ignore_errors=True)
...
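The last file feeds a CSV through tf.data.TextLineDataset and maps parse_csv over it; parse_csv itself lies outside the hunk, but its final return line is visible above. A minimal sketch of that style of parser using tf.decode_csv; the column names and defaults below are invented placeholders, not the real model's columns:

import tensorflow as tf

_CSV_COLUMNS = ['age', 'hours_per_week', 'income_bracket']  # placeholder columns
_CSV_COLUMN_DEFAULTS = [[0], [0], ['']]                     # placeholder defaults

def parse_csv(value):
  # Decode one CSV line into a tensor per column, then split off the label.
  columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
  features = dict(zip(_CSV_COLUMNS, columns))
  labels = features.pop('income_bracket')
  return features, tf.equal(labels, '>50K')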