Commit 8008e72f authored by Toby Boyd's avatar Toby Boyd
Browse files

Generate data downloads and creates files

parent d8588a7e
import collections
import six import six
import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.core.framework import node_def_pb2 from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import device as pydev from tensorflow.python.framework import device as pydev
from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import basic_session_run_hooks
......
...@@ -26,8 +26,18 @@ import argparse ...@@ -26,8 +26,18 @@ import argparse
import cPickle import cPickle
import os import os
import tarfile
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
CIFAR_10_FILE_NAME = 'cifar-10-python.tar.gz'
CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py'
def download_and_extract(data_dir):
# download CIFAR-10 if not already downloaded.
tf.contrib.learn.datasets.base.maybe_download(CIFAR_10_FILE_NAME, data_dir, DATA_URL)
tarfile.open(os.path.join(data_dir,CIFAR_10_FILE_NAME), 'r:gz').extractall(data_dir)
def _int64_feature(value): def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
...@@ -57,6 +67,7 @@ def convert_to_tfrecord(input_files, output_file): ...@@ -57,6 +67,7 @@ def convert_to_tfrecord(input_files, output_file):
print('Generating %s' % output_file) print('Generating %s' % output_file)
with tf.python_io.TFRecordWriter(output_file) as record_writer: with tf.python_io.TFRecordWriter(output_file) as record_writer:
for input_file in input_files: for input_file in input_files:
print(input_file)
data_dict = read_pickle_from_file(input_file) data_dict = read_pickle_from_file(input_file)
data = data_dict['data'] data = data_dict['data']
labels = data_dict['labels'] labels = data_dict['labels']
...@@ -71,12 +82,18 @@ def convert_to_tfrecord(input_files, output_file): ...@@ -71,12 +82,18 @@ def convert_to_tfrecord(input_files, output_file):
record_writer.write(example.SerializeToString()) record_writer.write(example.SerializeToString())
def main(input_dir, output_dir): def main(data_dir):
download_and_extract(data_dir)
file_names = _get_file_names() file_names = _get_file_names()
input_dir = os.path.join(data_dir, CIFAR_LOCAL_FOLDER)
for mode, files in file_names.items(): for mode, files in file_names.items():
input_files = [ input_files = [
os.path.join(input_dir, f) for f in files] os.path.join(input_dir, f) for f in files]
output_file = os.path.join(output_dir, mode + '.tfrecords') output_file = os.path.join(data_dir, mode + '.tfrecords')
try:
os.remove(output_file)
except OSError:
pass
# Convert to Examples and write the result to TFRecords. # Convert to Examples and write the result to TFRecords.
convert_to_tfrecord(input_files, output_file) convert_to_tfrecord(input_files, output_file)
print('Done!') print('Done!')
...@@ -85,19 +102,11 @@ def main(input_dir, output_dir): ...@@ -85,19 +102,11 @@ def main(input_dir, output_dir):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'--input-dir', '--data-dir',
type=str,
default='',
help='Directory where CIFAR10 data is located.'
)
parser.add_argument(
'--output-dir',
type=str, type=str,
default='', default='',
help="""\ help='Directory to download and extract CIFAR-10 to.'
Directory where TFRecords will be saved.The TFRecords will have the same
name as the CIFAR10 inputs + .tfrecords.\
"""
) )
args = parser.parse_args() args = parser.parse_args()
main(args.input_dir, args.output_dir) main(args.data_dir)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment