Commit c3b69841 authored by Toby Boyd, committed by GitHub

Merge pull request #2080 from mari-linhares/patch-5

Refactoring: generate a single TFRecord file per data subset (instead of one per CIFAR-10 batch) for efficiency
parents 751bfb3f 6fe50699
@@ -39,15 +39,8 @@ class Cifar10DataSet(object):
     self.use_distortion = use_distortion
 
   def get_filenames(self):
-    if self.subset == 'train':
-      return [
-          os.path.join(self.data_dir, 'data_batch_%d.tfrecords' % i)
-          for i in xrange(1, 5)
-      ]
-    elif self.subset == 'validation':
-      return [os.path.join(self.data_dir, 'data_batch_5.tfrecords')]
-    elif self.subset == 'eval':
-      return [os.path.join(self.data_dir, 'test_batch.tfrecords')]
+    if self.subset in ['train', 'validation', 'eval']:
+      return [os.path.join(self.data_dir, self.subset + '.tfrecords')]
     else:
       raise ValueError('Invalid data subset "%s"' % self.subset)
@@ -66,7 +59,9 @@ class Cifar10DataSet(object):
     image.set_shape([DEPTH * HEIGHT * WIDTH])
 
     # Reshape from [depth * height * width] to [depth, height, width].
-    image = tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0])
+    image = tf.cast(
+        tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0]),
+        tf.float32)
     label = tf.cast(features['label'], tf.int32)
 
     # Custom preprocessing.
...
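For context, a minimal sketch of how one of the new single-subset files could be read back using the simplified get_filenames logic and the float32 cast above. This is not part of the diff: it assumes TF 1.x, a hypothetical data_dir, and that the generator script (second file below) writes an 'image' feature holding raw uint8 bytes alongside the 'label' int64 feature; only 'label' is visible in this change, so the 'image' key and the uint8 decode are assumptions.

import os
import tensorflow as tf

HEIGHT, WIDTH, DEPTH = 32, 32, 3  # CIFAR-10 image dimensions


def parse_example(serialized):
  # Assumed feature keys; 'label' appears in the diff, 'image' is a guess
  # at the key used for the raw pixel bytes.
  features = tf.parse_single_example(
      serialized,
      features={
          'image': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
      })
  image = tf.decode_raw(features['image'], tf.uint8)
  image.set_shape([DEPTH * HEIGHT * WIDTH])
  # Same reshape + cast to float32 as in the patched parser above.
  image = tf.cast(
      tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0]),
      tf.float32)
  label = tf.cast(features['label'], tf.int32)
  return image, label


data_dir = '/tmp/cifar-10-data'  # hypothetical output dir of the generator
# One file per subset, exactly what the new get_filenames returns.
train_files = [os.path.join(data_dir, 'train.tfrecords')]
dataset = tf.data.TFRecordDataset(train_files).map(parse_example)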
@@ -46,9 +46,12 @@ def _bytes_feature(value):
 
 def _get_file_names():
-  """Returns the file names expected to exist in the input_dir."""
-  file_names = ['data_batch_%d' % i for i in xrange(1, 6)]
-  file_names.append('test_batch')
+  """Returns the file names expected to exist for training, validation
+  and evaluation in the input_dir."""
+  file_names = {}
+  file_names['train'] = ['data_batch_%d' % i for i in xrange(1, 5)]
+  file_names['validation'] = ['data_batch_5']
+  file_names['eval'] = ['test_batch']
   return file_names
@@ -58,18 +61,12 @@ def read_pickle_from_file(filename):
   return data_dict
 
 
-def main(argv):
-  del argv  # Unused.
-
-  file_names = _get_file_names()
-
-  for file_name in file_names:
-    input_file = os.path.join(FLAGS.input_dir, file_name)
-    output_file = os.path.join(FLAGS.output_dir, file_name + '.tfrecords')
-
-    print('Generating %s' % output_file)
-    record_writer = tf.python_io.TFRecordWriter(output_file)
-
+def convert_to_tfrecord(input_files, output_file):
+  """Converts a file to tfrecords."""
+  print('Generating %s' % output_file)
+  record_writer = tf.python_io.TFRecordWriter(output_file)
+
+  for input_file in input_files:
     data_dict = read_pickle_from_file(input_file)
     data = data_dict['data']
     labels = data_dict['labels']
@@ -82,8 +79,18 @@ def main(argv):
               'label': _int64_feature(labels[i])
           }))
       record_writer.write(example.SerializeToString())
 
-    record_writer.close()
+  record_writer.close()
+
+
+def main(argv):
+  del argv  # Unused.
+  file_names = _get_file_names()
+
+  for mode, files in file_names.items():
+    input_files = [
+        os.path.join(FLAGS.input_dir, f) for f in files]
+    output_file = os.path.join(FLAGS.output_dir, mode + '.tfrecords')
+    # Convert to Examples and write the result to TFRecords.
+    convert_to_tfrecord(input_files, output_file)
 
   print('Done!')
...
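As a quick sanity check on the refactor, the sketch below counts the records in each generated file (TF 1.x; the /tmp/cifar-10-data path is a placeholder for whatever --output_dir was passed to the generator). With batches 1-4 feeding train, batch 5 feeding validation, and the test batch feeding eval, the expected counts are 40000, 10000 and 10000.

import os
import tensorflow as tf

output_dir = '/tmp/cifar-10-data'  # placeholder for FLAGS.output_dir

for mode in ['train', 'validation', 'eval']:
  path = os.path.join(output_dir, mode + '.tfrecords')
  # tf_record_iterator yields one serialized Example per record.
  count = sum(1 for _ in tf.python_io.tf_record_iterator(path))
  print('%s: %d records' % (mode, count))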