# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the CIFAR-10 dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app as absl_app
from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.utils.flags import core as flags_core
from official.resnet import resnet_model
from official.resnet import resnet_run_loop

_HEIGHT = 32
_WIDTH = 32
_NUM_CHANNELS = 3
_DEFAULT_IMAGE_BYTES = _HEIGHT * _WIDTH * _NUM_CHANNELS
# The record is the image plus a one-byte label
_RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1
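# For CIFAR-10's 32x32 RGB images this works out to 3 * 32 * 32 + 1 = 3073
# bytes per record in the binary .bin data files.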
_NUM_CLASSES = 10
_NUM_DATA_FILES = 5

_NUM_IMAGES = {
    'train': 50000,
    'validation': 10000,
}


###############################################################################
# Data processing
###############################################################################
def get_filenames(is_training, data_dir):
  """Returns a list of filenames."""
  data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')

  assert os.path.exists(data_dir), (
      'Run cifar10_download_and_extract.py first to download and extract the '
      'CIFAR-10 data.')

  if is_training:
    return [
        os.path.join(data_dir, 'data_batch_%d.bin' % i)
        for i in range(1, _NUM_DATA_FILES + 1)
    ]
  else:
    return [os.path.join(data_dir, 'test_batch.bin')]


def parse_record(raw_record, is_training):
  """Parse CIFAR-10 image and label from a raw record."""
  # Convert bytes to a vector of uint8 that is record_bytes long.
  record_vector = tf.decode_raw(raw_record, tf.uint8)

  # The first byte represents the label, which we convert from uint8 to int32
  # and then to one-hot.
  label = tf.cast(record_vector[0], tf.int32)
  label = tf.one_hot(label, _NUM_CLASSES)

  # The remaining bytes after the label represent the image, which we reshape
  # from [depth * height * width] to [depth, height, width].
  depth_major = tf.reshape(record_vector[1:_RECORD_BYTES],
                           [_NUM_CHANNELS, _HEIGHT, _WIDTH])

  # Convert from [depth, height, width] to [height, width, depth], and cast as
  # float32.
  image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)

  image = preprocess_image(image, is_training)

  return image, label


def preprocess_image(image, is_training):
  """Preprocess a single image of layout [height, width, depth]."""
  if is_training:
    # Resize the image to add four extra pixels on each side.
    image = tf.image.resize_image_with_crop_or_pad(
        image, _HEIGHT + 8, _WIDTH + 8)

    # Randomly crop a [_HEIGHT, _WIDTH] section of the image.
    image = tf.random_crop(image, [_HEIGHT, _WIDTH, _NUM_CHANNELS])

    # Randomly flip the image horizontally.
    image = tf.image.random_flip_left_right(image)

  # Subtract off the mean and divide by the standard deviation of the pixels.
  image = tf.image.per_image_standardization(image)
  return image


def input_fn(is_training, data_dir, batch_size, num_epochs=1):
  """Input_fn using the tf.data input pipeline for the CIFAR-10 dataset.

  Args:
    is_training: A boolean denoting whether the input is for training.
    data_dir: The directory containing the input data.
    batch_size: The number of samples per batch.
    num_epochs: The number of epochs to repeat the dataset.

  Returns:
    A dataset that can be used for iteration.
  """
  filenames = get_filenames(is_training, data_dir)
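  # Each CIFAR-10 .bin file is a flat sequence of fixed-length records, so we
  # read _RECORD_BYTES at a time and parse each one with parse_record.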
  dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)

  return resnet_run_loop.process_record_dataset(
      dataset, is_training, batch_size, _NUM_IMAGES['train'],
      parse_record, num_epochs,
  )


def get_synth_input_fn():
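  """Returns an input function for synthetic data shaped like CIFAR-10."""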
  return resnet_run_loop.get_synth_input_fn(
      _HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES)


###############################################################################
# Running the model
###############################################################################
class Cifar10Model(resnet_model.Model):
  """Model class with appropriate defaults for CIFAR-10 data."""

  def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
               version=resnet_model.DEFAULT_VERSION,
               dtype=resnet_model.DEFAULT_DTYPE):
    """These are the parameters that work for CIFAR-10 data.

    Args:
      resnet_size: The number of convolutional layers needed in the model.
      data_format: Either 'channels_first' or 'channels_last', specifying which
        data format to use when setting up the model.
      num_classes: The number of output classes needed from the model. This
        enables users to extend the same model to their own datasets.
      version: Integer representing which version of the ResNet network to use.
        See README for details. Valid values: [1, 2]
      dtype: The TensorFlow dtype to use for calculations.

    Raises:
      ValueError: if an invalid resnet_size is chosen.
    """
    if resnet_size % 6 != 2:
      raise ValueError('resnet_size must be 6n + 2:', resnet_size)

    num_blocks = (resnet_size - 2) // 6

    super(Cifar10Model, self).__init__(
        resnet_size=resnet_size,
        bottleneck=False,
        num_classes=num_classes,
        num_filters=16,
        kernel_size=3,
        conv_stride=1,
        first_pool_size=None,
        first_pool_stride=None,
        block_sizes=[num_blocks] * 3,
        block_strides=[1, 2, 2],
        final_size=64,
        version=version,
        data_format=data_format,
        dtype=dtype
    )


def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10."""
  features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])

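  # batch_denom=128 scales the base learning rate with the batch size (see
  # resnet_run_loop.learning_rate_with_decay); decay_rates then drop the rate
  # by 10x at epochs 100, 150, and 200.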
  learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
      batch_size=params['batch_size'], batch_denom=128,
      num_images=_NUM_IMAGES['train'], boundary_epochs=[100, 150, 200],
      decay_rates=[1, 0.1, 0.01, 0.001])

  # We use a weight decay of 0.0002, which performs better
  # than the 0.0001 that was originally suggested.
  weight_decay = 2e-4

  # Empirical testing showed that including batch_normalization variables
  # in the calculation of regularized loss helped validation accuracy
  # for the CIFAR-10 dataset, perhaps because the regularization prevents
  # overfitting on the small data set. We therefore include all vars when
  # regularizing and computing loss during training.
  def loss_filter_fn(_):
    return True

  return resnet_run_loop.resnet_model_fn(
      features=features,
      labels=labels,
      mode=mode,
      model_class=Cifar10Model,
      resnet_size=params['resnet_size'],
      weight_decay=weight_decay,
      learning_rate_fn=learning_rate_fn,
      momentum=0.9,
      data_format=params['data_format'],
      version=params['version'],
      loss_scale=params['loss_scale'],
      loss_filter_fn=loss_filter_fn,
      dtype=params['dtype']
  )


def define_cifar_flags():
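  """Defines ResNet command-line flags and sets CIFAR-10-specific defaults."""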
  resnet_run_loop.define_resnet_flags()
  flags.adopt_module_key_flags(resnet_run_loop)
  flags_core.set_defaults(data_dir='/tmp/cifar10_data',
                          model_dir='/tmp/cifar10_model',
                          resnet_size='32',
                          train_epochs=250,
                          epochs_between_evals=10,
                          batch_size=128)


def run_cifar(flags_obj):
  """Run ResNet CIFAR-10 training and eval loop.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
                    or input_fn)

  resnet_run_loop.resnet_main(
      flags_obj, cifar10_model_fn, input_function,
      shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])


def main(_):
  run_cifar(flags.FLAGS)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  define_cifar_flags()
  absl_app.run(main)