Commit 7269b862 authored by Marianne Linhares Monteiro's avatar Marianne Linhares Monteiro Committed by GitHub
Browse files

Updating get_filenames to only use a single file

parent 32533a06
......@@ -39,15 +39,8 @@ class Cifar10DataSet(object):
self.use_distortion = use_distortion
def get_filenames(self):
if self.subset == 'train':
return [
os.path.join(self.data_dir, 'data_batch_%d.tfrecords' % i)
for i in xrange(1, 5)
]
elif self.subset == 'validation':
return [os.path.join(self.data_dir, 'data_batch_5.tfrecords')]
elif self.subset == 'eval':
return [os.path.join(self.data_dir, 'test_batch.tfrecords')]
if self.subset in ['train', 'validation', 'eval']:
return [os.path.join(self.data_dir, self.subset + '.tfrecords')]
else:
raise ValueError('Invalid data subset "%s"' % self.subset)
......@@ -66,7 +59,9 @@ class Cifar10DataSet(object):
image.set_shape([DEPTH * HEIGHT * WIDTH])
# Reshape from [depth * height * width] to [depth, height, width].
image = tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0])
image = tf.cast(
tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0]),
tf.float32)
label = tf.cast(features['label'], tf.int32)
# Custom preprocessing .
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment