cifar10_main.py 7.74 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
"""Runs a ResNet model on the CIFAR-10 dataset."""
16
17
18
19
20
21

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
22
import sys
23
24
25

import tensorflow as tf

Karmel Allison's avatar
Karmel Allison committed
26
import resnet
27

28
29
# Geometry of a single CIFAR-10 image.
_HEIGHT = 32
_WIDTH = 32
_NUM_CHANNELS = 3

# Size in bytes of one image payload (one byte per pixel per channel).
_DEFAULT_IMAGE_BYTES = _HEIGHT * _WIDTH * _NUM_CHANNELS

# Number of label classes.
_NUM_CLASSES = 10

# Number of training batch files (data_batch_1.bin .. data_batch_<N>.bin).
_NUM_DATA_FILES = 5

# Number of examples in each split.
_NUM_IMAGES = {'train': 50000, 'validation': 10000}
39
40


41
42
43
###############################################################################
# Data processing
###############################################################################
44
45
def record_dataset(filenames):
  """Builds a fixed-length-record Dataset over the given files.

  Args:
    filenames: File paths forwarded to `tf.data.FixedLengthRecordDataset`.

  Returns:
    A `tf.data.FixedLengthRecordDataset` whose records are
    `_DEFAULT_IMAGE_BYTES + 1` bytes long.
  """
  # Each record holds one extra byte in addition to the image payload.
  extra_bytes = 1
  return tf.data.FixedLengthRecordDataset(
      filenames, _DEFAULT_IMAGE_BYTES + extra_bytes)
48
49


50
def get_filenames(is_training, data_dir):
  """Returns the CIFAR-10 binary file paths for training or evaluation.

  Args:
    is_training: If truthy, return the training batch files; otherwise return
      the single test batch file.
    data_dir: Directory that contains the `cifar-10-batches-bin` folder.

  Returns:
    A list of file paths located inside `<data_dir>/cifar-10-batches-bin`.

  Raises:
    AssertionError: If `<data_dir>/cifar-10-batches-bin` does not exist.
  """
  data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')

  # Raise explicitly instead of using an `assert` statement: asserts are
  # removed when Python runs with -O, which would silently skip this check.
  # The exception type stays AssertionError so existing callers see the same
  # error as before.
  if not os.path.exists(data_dir):
    raise AssertionError(
        'Run cifar10_download_and_extract.py first to download and extract the '
        'CIFAR-10 data.')

  if is_training:
    return [
        os.path.join(data_dir, 'data_batch_%d.bin' % i)
        for i in range(1, _NUM_DATA_FILES + 1)
    ]
  return [os.path.join(data_dir, 'test_batch.bin')]
65
66


Kathy Wu's avatar
Kathy Wu committed
67
68
def parse_record(raw_record):
  """Decodes one raw CIFAR-10 record into an image and a one-hot label.

  Args:
    raw_record: A raw record handed to `tf.decode_raw`; it is read as one
      label byte followed by `_DEFAULT_IMAGE_BYTES` image bytes.

  Returns:
    A tuple `(image, label)`. `image` is a float32 tensor laid out as
    [_HEIGHT, _WIDTH, _NUM_CHANNELS]; `label` is a one-hot tensor of depth
    `_NUM_CLASSES`.
  """
  # Record layout: a single label byte, then the image bytes.
  num_label_bytes = 1
  image_end = num_label_bytes + _DEFAULT_IMAGE_BYTES

  # View the raw bytes as a flat uint8 vector.
  raw_bytes = tf.decode_raw(raw_record, tf.uint8)

  # Label: take the first byte, widen it to int32, then expand to one-hot.
  class_index = tf.cast(raw_bytes[0], tf.int32)
  one_hot_label = tf.one_hot(class_index, _NUM_CLASSES)

  # Image: the bytes after the label are reshaped from a flat vector to
  # [depth, height, width].
  image_bytes = raw_bytes[num_label_bytes:image_end]
  channels_first = tf.reshape(image_bytes, [_NUM_CHANNELS, _HEIGHT, _WIDTH])

  # Move the channel axis last ([height, width, depth]) and cast to float32.
  channels_last = tf.transpose(channels_first, [1, 2, 0])
  float_image = tf.cast(channels_last, tf.float32)

  return float_image, one_hot_label
92
93


94
95
96
97
def preprocess_image(image, is_training):
  """Standardizes one image, with random augmentation when training.

  Args:
    image: A single image tensor of layout [height, width, depth].
    is_training: If truthy, the image is padded, randomly cropped back to
      [_HEIGHT, _WIDTH, _NUM_CHANNELS] and randomly flipped left-right before
      it is standardized.

  Returns:
    The output of `tf.image.per_image_standardization` applied to the
    (possibly augmented) image.
  """
  if is_training:
    # Grow the canvas to (_HEIGHT + 8) x (_WIDTH + 8): four pixels per side.
    padded = tf.image.resize_image_with_crop_or_pad(
        image, _HEIGHT + 8, _WIDTH + 8)

    # Cut a random window of the original size out of the padded image.
    cropped = tf.random_crop(padded, [_HEIGHT, _WIDTH, _NUM_CHANNELS])

    # Mirror the image horizontally at random.
    image = tf.image.random_flip_left_right(cropped)

  # Per-image normalization (TensorFlow documents the result as having zero
  # mean and unit variance).
  return tf.image.per_image_standardization(image)
Kathy Wu's avatar
Kathy Wu committed
110
111


112
def input_fn(is_training, data_dir, batch_size, num_epochs=1):
  """Builds the tf.data input pipeline for the CIFAR-10 dataset.

  Args:
    is_training: A boolean denoting whether the input is for training. When
      true the records are shuffled and the images are augmented.
    data_dir: The directory containing the input data.
    batch_size: The number of samples per batch.
    num_epochs: The number of epochs to repeat the dataset.

  Returns:
    A tuple of images and labels taken from a one-shot iterator over the
    batched dataset.
  """
  def _preprocess(image, label):
    # Only the image is transformed; the label passes through untouched.
    return preprocess_image(image, is_training), label

  dataset = record_dataset(get_filenames(is_training, data_dir))

  if is_training:
    # Larger shuffle buffers give better randomness and smaller ones better
    # performance. CIFAR-10 is relatively small, so the buffer covers a full
    # training epoch.
    dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])

  dataset = dataset.map(parse_record).map(_preprocess)
  dataset = dataset.prefetch(2 * batch_size)

  # Repeat only after shuffling so that separate epochs do not blend
  # together, then group the examples into batches of up to batch_size.
  dataset = dataset.repeat(num_epochs).batch(batch_size)

  images, labels = dataset.make_one_shot_iterator().get_next()
  return images, labels


151
152
153
###############################################################################
# Running the model
###############################################################################
Karmel Allison's avatar
Karmel Allison committed
154
class Cifar10Model(resnet.Model):
  """`resnet.Model` subclass preconfigured with CIFAR-10 hyperparameters."""

  def __init__(self, resnet_size, data_format=None):
    """These are the parameters that work for CIFAR-10 data.

    Args:
      resnet_size: Size of the ResNet; must satisfy `resnet_size % 6 == 2`
        (i.e. be of the form 6n + 2).
      data_format: Forwarded unchanged to `resnet.Model.__init__`.

    Raises:
      ValueError: If `resnet_size` is not of the form 6n + 2.
    """
    if resnet_size % 6 != 2:
      # Build one message string. Passing the value as a second positional
      # argument to ValueError makes the error print as a tuple.
      raise ValueError(
          'resnet_size must be 6n + 2: {}'.format(resnet_size))

    # The same block count is used for each of the three entries in
    # `block_sizes` below.
    num_blocks = (resnet_size - 2) // 6

    super(Cifar10Model, self).__init__(
        resnet_size=resnet_size,
        num_classes=_NUM_CLASSES,
        num_filters=16,
        kernel_size=3,
        conv_stride=1,
        first_pool_size=None,
        first_pool_stride=None,
        second_pool_size=8,
        second_pool_stride=1,
        block_fn=resnet.building_block,
        block_sizes=[num_blocks] * 3,
        block_strides=[1, 2, 2],
        final_size=64,
        data_format=data_format)
178
179


180
181
182
183
def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10.

  Args:
    features: Input tensor; reshaped to [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS]
      before being passed on.
    labels: Forwarded unchanged to `resnet.resnet_model_fn`.
    mode: Forwarded unchanged to `resnet.resnet_model_fn`.
    params: Mapping that provides the 'batch_size', 'resnet_size' and
      'data_format' entries.

  Returns:
    Whatever `resnet.resnet_model_fn` returns for this configuration.
  """
  input_shape = [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS]
  features = tf.reshape(features, input_shape)

  learning_rate_fn = resnet.learning_rate_with_decay(
      batch_size=params['batch_size'],
      batch_denom=128,
      num_images=_NUM_IMAGES['train'],
      boundary_epochs=[100, 150, 200],
      decay_rates=[1, 0.1, 0.01, 0.001])

  # A weight decay of 0.0002 is used because it performs better than the
  # originally suggested 0.0001.
  weight_decay = 2e-4

  def loss_filter_fn(name):
    # Every variable is kept, batch_normalization variables included:
    # empirical testing showed that regularizing them helped validation
    # accuracy on CIFAR-10, perhaps because the regularization prevents
    # overfitting on the small data set.
    return True

  return resnet.resnet_model_fn(
      features, labels, mode, Cifar10Model,
      resnet_size=params['resnet_size'],
      weight_decay=weight_decay,
      learning_rate_fn=learning_rate_fn,
      momentum=0.9,
      data_format=params['data_format'],
      loss_filter_fn=loss_filter_fn)
208
209
210


def main(unused_argv):
  """Runs `resnet.resnet_main` using the module-level `FLAGS`.

  Args:
    unused_argv: Ignored.
  """
  del unused_argv  # Not needed; everything is read from FLAGS.
  resnet.resnet_main(FLAGS, cifar10_model_fn, input_fn)
212
213
214
215


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)

  # Defaults that are reasonable for this model.
  model_defaults = {
      'data_dir': '/tmp/cifar10_data',
      'model_dir': '/tmp/cifar10_model',
      'resnet_size': 32,
      'train_epochs': 250,
      'epochs_per_eval': 10,
      'batch_size': 128,
  }

  parser = resnet.ResnetArgParser()
  parser.set_defaults(**model_defaults)

  # Arguments the parser does not recognize are handed on to tf.app.run
  # together with the program name.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(argv=[sys.argv[0]] + unparsed)