"vscode:/vscode.git/clone" did not exist on "97905301cb3b9109b0503d1bd9339d10589e9b85"
Commit 73def645 authored by Asim Shankar's avatar Asim Shankar
Browse files

[mnist]: Use FixedLengthRecordDatatest

- Prior to this change, the use of tf.data.Dataset essentially embedded
  the entire training/evaluation dataset into the graph as a constant,
  leading to unnecessarily humungous graphs (Fixes #3017)
- Also, use batching on the evaluation dataset to allow
  evaluation on GPUs that cannot fit the entire evaluation dataset in
  memory (Fixes #3046)
parent a3669a93
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""tf.data.Dataset interface to the MNIST dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import gzip
import numpy as np
import tensorflow as tf
def read32(bytestream):
"""Read 4 bytes from bytestream as an unsigned 32-bit integer."""
dt = np.dtype(np.uint32).newbyteorder('>')
return np.frombuffer(bytestream.read(4), dtype=dt)[0]
def check_image_file_header(filename):
"""Validate that filename corresponds to images for the MNIST dataset."""
with open(filename) as f:
magic = read32(f)
num_images = read32(f)
rows = read32(f)
cols = read32(f)
if magic != 2051:
raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
f.name))
if rows != 28 or cols != 28:
raise ValueError(
'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
(f.name, rows, cols))
def check_labels_file_header(filename):
"""Validate that filename corresponds to labels for the MNIST dataset."""
with open(filename) as f:
magic = read32(f)
num_items = read32(f)
if magic != 2049:
raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
f.name))
def maybe_download(directory, filename):
"""Download a file from the MNIST dataset, if it doesn't already exist."""
if not tf.gfile.Exists(directory):
tf.gfile.MakeDirs(directory)
filepath = os.path.join(directory, filename)
if tf.gfile.Exists(filepath):
return filepath
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
zipped_filename = filename + '.gz'
zipped_filepath = os.path.join(directory, zipped_filename)
tf.contrib.learn.datasets.base.maybe_download(zipped_filename, directory, url)
with gzip.open(os.path.join(zipped_filepath), 'rb') as f_in, open(
filepath, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
os.remove(zipped_filepath)
return filepath
def dataset(directory, images_file, labels_file):
images_file = maybe_download(directory, images_file)
labels_file = maybe_download(directory, labels_file)
check_image_file_header(images_file)
check_labels_file_header(labels_file)
def decode_image(image):
# Normalize from [0, 255] to [0.0, 1.0]
image = tf.decode_raw(image, tf.uint8)
image = tf.cast(image, tf.float32)
image = tf.reshape(image, [784])
return image / 255.0
def one_hot_label(label):
label = tf.decode_raw(label, tf.uint8) # tf.string -> tf.uint8
label = tf.reshape(label, []) # label is a scalar
return tf.one_hot(label, 10)
images = tf.data.FixedLengthRecordDataset(
images_file, 28 * 28, header_bytes=16).map(decode_image)
labels = tf.data.FixedLengthRecordDataset(
labels_file, 1, header_bytes=8).map(one_hot_label)
return tf.data.Dataset.zip((images, labels))
def train(directory):
"""tf.data.Dataset object for MNIST training data."""
return dataset(directory, 'train-images-idx3-ubyte',
'train-labels-idx1-ubyte')
def test(directory):
"""tf.data.Dataset object for MNIST test data."""
return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
......@@ -22,19 +22,7 @@ import os
import sys
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
def train_dataset(data_dir):
"""Returns a tf.data.Dataset yielding (image, label) pairs for training."""
data = input_data.read_data_sets(data_dir, one_hot=True).train
return tf.data.Dataset.from_tensor_slices((data.images, data.labels))
def eval_dataset(data_dir):
"""Returns a tf.data.Dataset yielding (image, label) pairs for evaluation."""
data = input_data.read_data_sets(data_dir, one_hot=True).test
return tf.data.Dataset.from_tensors((data.images, data.labels))
import dataset
class Model(object):
......@@ -151,10 +139,10 @@ def main(unused_argv):
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes use less memory. MNIST is a small
# enough dataset that we can easily shuffle the full epoch.
dataset = train_dataset(FLAGS.data_dir)
dataset = dataset.shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
ds = dataset.train(FLAGS.data_dir)
ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
FLAGS.train_epochs)
(images, labels) = dataset.make_one_shot_iterator().get_next()
(images, labels) = ds.make_one_shot_iterator().get_next()
return (images, labels)
# Set up training hook that logs the training accuracy every 100 steps.
......@@ -165,7 +153,8 @@ def main(unused_argv):
# Evaluate the model and print results
def eval_input_fn():
return eval_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next()
return dataset.test(FLAGS.data_dir).batch(
FLAGS.batch_size).make_one_shot_iterator().get_next()
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment