"...resnet50_tensorflow.git" did not exist on "db2fab9906b8fcfe44948e6a300ff4ecc25853f1"
Commit 68218034 authored by Marianne Linhares Monteiro's avatar Marianne Linhares Monteiro Committed by GitHub
Browse files

Cleaning up

* Removing unused import
* Removing trailing spaces
* Using global variables instead of magic numbers
parent 34495bde
...@@ -16,10 +16,8 @@ ...@@ -16,10 +16,8 @@
See http://www.cs.toronto.edu/~kriz/cifar.html. See http://www.cs.toronto.edu/~kriz/cifar.html.
""" """
import cPickle
import os import os
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
...@@ -28,6 +26,7 @@ HEIGHT = 32 ...@@ -28,6 +26,7 @@ HEIGHT = 32
WIDTH = 32 WIDTH = 32
DEPTH = 3 DEPTH = 3
class Cifar10DataSet(object): class Cifar10DataSet(object):
"""Cifar10 data set. """Cifar10 data set.
...@@ -60,26 +59,24 @@ class Cifar10DataSet(object): ...@@ -60,26 +59,24 @@ class Cifar10DataSet(object):
features = tf.parse_single_example( features = tf.parse_single_example(
serialized_example, serialized_example,
features={ features={
"image": tf.FixedLenFeature([], tf.string), 'image': tf.FixedLenFeature([], tf.string),
"label": tf.FixedLenFeature([], tf.int64), 'label': tf.FixedLenFeature([], tf.int64),
}) })
image = tf.decode_raw(features["image"], tf.uint8) image = tf.decode_raw(features['image'], tf.uint8)
image.set_shape([3*32*32]) image.set_shape([DEPTH * HEIGHT * WIDTH])
# Reshape from [depth * height * width] to [depth, height, width]. # Reshape from [depth * height * width] to [depth, height, width].
image = tf.transpose(tf.reshape(image, [3, 32, 32]), [1, 2, 0]) image = tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0])
label = tf.cast(features["label"], tf.int32) label = tf.cast(features['label'], tf.int32)
# Custom preprocessing . # Custom preprocessing .
image = self.preprocess(image) image = self.preprocess(image)
print(image, label)
return image, label return image, label
def make_batch(self, batch_size): def make_batch(self, batch_size):
"""Read the images and labels from 'filenames'.""" """Read the images and labels from 'filenames'."""
filenames = self.get_filenames() filenames = self.get_filenames()
record_bytes = (32 * 32 * 3) + 1
# Repeat infinitely. # Repeat infinitely.
dataset = tf.contrib.data.TFRecordDataset(filenames).repeat() dataset = tf.contrib.data.TFRecordDataset(filenames).repeat()
...@@ -94,11 +91,12 @@ class Cifar10DataSet(object): ...@@ -94,11 +91,12 @@ class Cifar10DataSet(object):
# Ensure that the capacity is sufficiently large to provide good random # Ensure that the capacity is sufficiently large to provide good random
# shuffling. # shuffling.
dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size) dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)
# Batch it up. # Batch it up.
dataset = dataset.batch(batch_size) dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator() iterator = dataset.make_one_shot_iterator()
image_batch, label_batch = iterator.get_next() image_batch, label_batch = iterator.get_next()
print(image_batch, label_batch)
return image_batch, label_batch return image_batch, label_batch
def preprocess(self, image): def preprocess(self, image):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment