Commit 1f6b3d7d authored by Kathy Wu

Changing tf.contrib.data to tf.data for release of tf 1.4

parent 4cfa0d3b
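
Note on the migration: in TensorFlow 1.4 the Dataset API moved from tf.contrib.data into core as tf.data. Besides the namespace change, Dataset.map() lost its num_threads and output_buffer_size arguments; parallelism is now requested with num_parallel_calls, and buffering is expressed as a separate prefetch() transformation. A minimal before/after sketch (filename, batch_size, and the identity parser stub are placeholders, not from this commit):

    import tensorflow as tf

    filename = 'train.tfrecords'  # hypothetical path, not from this commit
    batch_size = 32

    # TF 1.3 and earlier (contrib API, removed by this commit):
    #   dataset = tf.contrib.data.TFRecordDataset([filename])
    #   dataset = dataset.map(parse_fn, num_threads=5,
    #                         output_buffer_size=batch_size)

    # TF 1.4 (core API):
    dataset = tf.data.TFRecordDataset([filename])
    dataset = dataset.map(lambda record: record, num_parallel_calls=5)  # parser stub
    dataset = dataset.prefetch(batch_size)  # explicit buffering replaces output_buffer_size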
@@ -53,7 +53,7 @@ _NUM_IMAGES = {
 def input_fn(is_training, filename, batch_size=1, num_epochs=1):
-  """A simple input_fn using the contrib.data input pipeline."""
+  """A simple input_fn using the tf.data input pipeline."""
   def example_parser(serialized_example):
     """Parses a single tf.Example into image and label tensors."""
@@ -71,8 +71,12 @@ def input_fn(is_training, filename, batch_size=1, num_epochs=1):
     label = tf.cast(features['label'], tf.int32)
     return image, tf.one_hot(label, 10)

-  dataset = tf.contrib.data.TFRecordDataset([filename])
+  dataset = tf.data.TFRecordDataset([filename])

+  # Parse each example in the dataset
   dataset = dataset.map(example_parser)

+  # Apply dataset transformations
   if is_training:
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes have better performance. Because MNIST is
...
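
The shuffle-buffer comment above reflects a real tradeoff: shuffle(buffer_size=N) only randomizes within a sliding window of N elements, so a buffer as large as the dataset gives a uniform shuffle while a small buffer saves memory. A sketch of how the MNIST pipeline might continue under that comment (the buffer value and ordering are illustrative, not copied from the file):

    # Buffer the whole training set for a full shuffle; smaller
    # buffers trade shuffle quality for memory.
    dataset = dataset.shuffle(buffer_size=50000)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)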
@@ -77,7 +77,7 @@ _SHUFFLE_BUFFER = 20000
 def record_dataset(filenames):
   """Returns an input pipeline Dataset from `filenames`."""
   record_bytes = _HEIGHT * _WIDTH * _DEPTH + 1
-  return tf.contrib.data.FixedLengthRecordDataset(filenames, record_bytes)
+  return tf.data.FixedLengthRecordDataset(filenames, record_bytes)

 def get_filenames(is_training, data_dir):
@@ -138,7 +138,7 @@ def train_preprocess_fn(image, label):
 def input_fn(is_training, data_dir, batch_size, num_epochs=1):
-  """Input_fn using the contrib.data input pipeline for CIFAR-10 dataset.
+  """Input_fn using the tf.data input pipeline for CIFAR-10 dataset.

   Args:
     is_training: A boolean denoting whether the input is for training.
@@ -148,13 +148,11 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
     A tuple of images and labels.
   """
   dataset = record_dataset(get_filenames(is_training, data_dir))
-  dataset = dataset.map(dataset_parser, num_threads=1,
-                        output_buffer_size=2 * batch_size)
+  dataset = dataset.map(dataset_parser)

   # For training, preprocess the image and shuffle.
   if is_training:
-    dataset = dataset.map(train_preprocess_fn, num_threads=1,
-                          output_buffer_size=2 * batch_size)
+    dataset = dataset.map(train_preprocess_fn)

     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes have better performance.
@@ -162,9 +160,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   # Subtract off the mean and divide by the variance of the pixels.
-  dataset = dataset.map(
-      lambda image, label: (tf.image.per_image_standardization(image), label),
-      num_threads=1,
-      output_buffer_size=2 * batch_size)
+  dataset = dataset.map(
+      lambda image, label: (tf.image.per_image_standardization(image), label))

   # We call repeat after shuffling, rather than before, to prevent separate
   # epochs from blending together.
...
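
Note that the commit simply drops num_threads=1 and output_buffer_size=2 * batch_size rather than translating them. If the old buffering behavior were wanted under TF 1.4, an equivalent sketch would be:

    dataset = dataset.map(
        lambda image, label: (tf.image.per_image_standardization(image), label),
        num_parallel_calls=1)
    # prefetch() is the core-API replacement for output_buffer_size.
    dataset = dataset.prefetch(2 * batch_size)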
@@ -134,17 +134,16 @@ def dataset_parser(value, is_training):
 def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   """Input function which provides batches for train or eval."""
-  dataset = tf.contrib.data.Dataset.from_tensor_slices(
+  dataset = tf.data.Dataset.from_tensor_slices(
       filenames(is_training, data_dir))

   if is_training:
     dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)

-  dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)
+  dataset = dataset.flat_map(tf.data.TFRecordDataset)
   dataset = dataset.map(lambda value: dataset_parser(value, is_training),
-                        num_threads=5,
-                        output_buffer_size=batch_size)
+                        num_parallel_calls=5)

   if is_training:
     # When choosing shuffle buffer sizes, larger sizes result in better
...
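
The pattern above shuffles at the file level before reading records: from_tensor_slices() builds a dataset of filename strings, shuffle() permutes them, and flat_map(tf.data.TFRecordDataset) turns each filename into its stream of records and concatenates the streams. A self-contained sketch of the same idea (the shard names are hypothetical):

    import tensorflow as tf

    files = ['train-00000-of-01024', 'train-00001-of-01024']  # hypothetical shards
    dataset = tf.data.Dataset.from_tensor_slices(files)
    dataset = dataset.shuffle(buffer_size=len(files))    # shuffle shard order
    dataset = dataset.flat_map(tf.data.TFRecordDataset)  # read records per shard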
@@ -178,8 +178,8 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
     return features, tf.equal(labels, '>50K')

   # Extract lines from input files using the Dataset API.
-  dataset = tf.contrib.data.TextLineDataset(data_file)
-  dataset = dataset.map(parse_csv, num_threads=5)
+  dataset = tf.data.TextLineDataset(data_file)
+  dataset = dataset.map(parse_csv, num_parallel_calls=5)

   if shuffle:
     dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
@@ -193,7 +193,6 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
   features, labels = iterator.get_next()
   return features, labels

 def main(unused_argv):
   # Clean up the model directory if present
   shutil.rmtree(FLAGS.model_dir, ignore_errors=True)
...
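
Pieced together from the fragments above, the census input_fn after this commit follows the standard TF 1.4 Estimator pattern; parse_csv and _SHUFFLE_BUFFER are defined elsewhere in the file, so this is a sketch under that assumption:

    def input_fn(data_file, num_epochs, shuffle, batch_size):
      # Extract lines from input files using the Dataset API,
      # parse them in parallel, then shuffle/repeat/batch.
      dataset = tf.data.TextLineDataset(data_file)
      dataset = dataset.map(parse_csv, num_parallel_calls=5)
      if shuffle:
        dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
      dataset = dataset.repeat(num_epochs)
      dataset = dataset.batch(batch_size)
      # A one-shot iterator hands tensors to the Estimator.
      iterator = dataset.make_one_shot_iterator()
      features, labels = iterator.get_next()
      return features, labels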