Commit 8008e72f authored by Toby Boyd's avatar Toby Boyd
Browse files

Generate data downloads and creates files

parent d8588a7e
import collections
import six
import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.training import basic_session_run_hooks
......
......@@ -26,8 +26,18 @@ import argparse
import cPickle
import os
import tarfile
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
CIFAR_10_FILE_NAME = 'cifar-10-python.tar.gz'
CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py'
def download_and_extract(data_dir):
# download CIFAR-10 if not already downloaded.
tf.contrib.learn.datasets.base.maybe_download(CIFAR_10_FILE_NAME, data_dir, DATA_URL)
tarfile.open(os.path.join(data_dir,CIFAR_10_FILE_NAME), 'r:gz').extractall(data_dir)
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
......@@ -57,6 +67,7 @@ def convert_to_tfrecord(input_files, output_file):
print('Generating %s' % output_file)
with tf.python_io.TFRecordWriter(output_file) as record_writer:
for input_file in input_files:
print(input_file)
data_dict = read_pickle_from_file(input_file)
data = data_dict['data']
labels = data_dict['labels']
......@@ -71,12 +82,18 @@ def convert_to_tfrecord(input_files, output_file):
record_writer.write(example.SerializeToString())
def main(input_dir, output_dir):
def main(data_dir):
download_and_extract(data_dir)
file_names = _get_file_names()
input_dir = os.path.join(data_dir, CIFAR_LOCAL_FOLDER)
for mode, files in file_names.items():
input_files = [
os.path.join(input_dir, f) for f in files]
output_file = os.path.join(output_dir, mode + '.tfrecords')
output_file = os.path.join(data_dir, mode + '.tfrecords')
try:
os.remove(output_file)
except OSError:
pass
# Convert to Examples and write the result to TFRecords.
convert_to_tfrecord(input_files, output_file)
print('Done!')
......@@ -85,19 +102,11 @@ def main(input_dir, output_dir):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--input-dir',
type=str,
default='',
help='Directory where CIFAR10 data is located.'
)
parser.add_argument(
'--output-dir',
'--data-dir',
type=str,
default='',
help="""\
Directory where TFRecords will be saved.The TFRecords will have the same
name as the CIFAR10 inputs + .tfrecords.\
"""
help='Directory to download and extract CIFAR-10 to.'
)
args = parser.parse_args()
main(args.input_dir, args.output_dir)
main(args.data_dir)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment