# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the CIFAR-10 dataset."""
16
17
18
19
20
21
22

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

23
from absl import logging
24
25
from absl import app as absl_app
from absl import flags
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
26
from six.moves import range
Karmel Allison's avatar
Karmel Allison committed
27
import tensorflow as tf  # pylint: disable=g-bad-import-order
28

29
30
from official.r1.resnet import resnet_model
from official.r1.resnet import resnet_run_loop
31
32
from official.utils.flags import core as flags_core
from official.utils.logs import logger
33

34
35
36
37
# Dimensions of the raw CIFAR-10 images: 32x32 pixels, 3 color channels.
HEIGHT = 32
WIDTH = 32
NUM_CHANNELS = 3
_DEFAULT_IMAGE_BYTES = HEIGHT * WIDTH * NUM_CHANNELS

# The record is the image plus a one-byte label.
_RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1

NUM_CLASSES = 10

# CIFAR-10 training data ships as five binary batch files.
_NUM_DATA_FILES = 5

# TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.
NUM_IMAGES = {
    'train': 50000,
    'validation': 10000,
}

DATASET_NAME = 'CIFAR-10'


###############################################################################
# Data processing
###############################################################################
def get_filenames(is_training, data_dir):
  """Returns a list of CIFAR-10 binary filenames.

  Args:
    is_training: A boolean denoting whether the input is for training.
    data_dir: Directory containing the extracted CIFAR-10 binaries.

  Returns:
    The five `data_batch_*.bin` paths when training, otherwise the single
    `test_batch.bin` path.
  """
  # NOTE: `assert` is stripped under `python -O`; kept for consistency with
  # the original error-reporting behavior (AssertionError for callers).
  assert tf.io.gfile.exists(data_dir), (
      'Run cifar10_download_and_extract.py first to download and extract the '
      'CIFAR-10 data.')

  if is_training:
    return [
        os.path.join(data_dir, 'data_batch_%d.bin' % i)
        for i in range(1, _NUM_DATA_FILES + 1)
    ]
  else:
    return [os.path.join(data_dir, 'test_batch.bin')]
def parse_record(raw_record, is_training, dtype):
  """Parse CIFAR-10 image and label from a raw record.

  Args:
    raw_record: A scalar string tensor holding one `_RECORD_BYTES`-long
      record (1 label byte followed by the image bytes).
    is_training: A boolean denoting whether the input is for training;
      controls whether random augmentation is applied.
    dtype: TensorFlow dtype to cast the final image to.

  Returns:
    A `(image, label)` tuple: the preprocessed [HEIGHT, WIDTH, NUM_CHANNELS]
    image cast to `dtype`, and the int32 label.
  """
  # Convert bytes to a vector of uint8 that is record_bytes long.
  record_vector = tf.io.decode_raw(raw_record, tf.uint8)

  # The first byte represents the label, which we convert from uint8 to int32.
  label = tf.cast(record_vector[0], tf.int32)

  # The remaining bytes after the label represent the image, which we reshape
  # from [depth * height * width] to [depth, height, width].
  depth_major = tf.reshape(record_vector[1:_RECORD_BYTES],
                           [NUM_CHANNELS, HEIGHT, WIDTH])

  # Convert from [depth, height, width] to [height, width, depth], and cast as
  # float32.
  image = tf.cast(tf.transpose(a=depth_major, perm=[1, 2, 0]), tf.float32)

  image = preprocess_image(image, is_training)
  image = tf.cast(image, dtype)

  return image, label
def preprocess_image(image, is_training):
  """Preprocess a single image of layout [height, width, depth].

  Training applies standard CIFAR augmentation (pad-and-crop plus random
  horizontal flip); both training and eval are per-image standardized.

  Args:
    image: A float32 image tensor of shape [HEIGHT, WIDTH, NUM_CHANNELS].
    is_training: A boolean denoting whether the input is for training.

  Returns:
    The preprocessed image tensor.
  """
  if is_training:
    # Resize the image to add four extra pixels on each side.
    image = tf.image.resize_with_crop_or_pad(
        image, HEIGHT + 8, WIDTH + 8)

    # Randomly crop a [HEIGHT, WIDTH] section of the image.
    image = tf.image.random_crop(image, [HEIGHT, WIDTH, NUM_CHANNELS])

    # Randomly flip the image horizontally.
    image = tf.image.random_flip_left_right(image)

  # Subtract off the mean and divide by the variance of the pixels.
  image = tf.image.per_image_standardization(image)
  return image
def input_fn(is_training,
             data_dir,
             batch_size,
             num_epochs=1,
             dtype=tf.float32,
             datasets_num_private_threads=None,
             parse_record_fn=parse_record,
             input_context=None,
             drop_remainder=False):
  """Input function which provides batches for train or eval.

  Args:
    is_training: A boolean denoting whether the input is for training.
    data_dir: The directory containing the input data.
    batch_size: The number of samples per batch.
    num_epochs: The number of epochs to repeat the dataset.
    dtype: Data type to use for images/features.
    datasets_num_private_threads: Number of private threads for tf.data.
    parse_record_fn: Function to use for parsing the records.
    input_context: A `tf.distribute.InputContext` object passed in by
      `tf.distribute.Strategy`.
    drop_remainder: A boolean indicates whether to drop the remainder of the
      batches. If True, the batch dimension will be static.

  Returns:
    A dataset that can be used for iteration.
  """
  filenames = get_filenames(is_training, data_dir)
  dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)

  if input_context:
    # Lazy %-style logging args: the message is only formatted when emitted.
    logging.info(
        'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
        input_context.input_pipeline_id, input_context.num_input_pipelines)
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)

  return resnet_run_loop.process_record_dataset(
      dataset=dataset,
      is_training=is_training,
      batch_size=batch_size,
      shuffle_buffer=NUM_IMAGES['train'],
      parse_record_fn=parse_record_fn,
      num_epochs=num_epochs,
      dtype=dtype,
      datasets_num_private_threads=datasets_num_private_threads,
      drop_remainder=drop_remainder
  )
def get_synth_input_fn(dtype):
  """Returns an input function that yields synthetic data of CIFAR-10 shape."""
  return resnet_run_loop.get_synth_input_fn(
      HEIGHT, WIDTH, NUM_CHANNELS, NUM_CLASSES, dtype=dtype)


###############################################################################
# Running the model
###############################################################################
class Cifar10Model(resnet_model.Model):
  """Model class with appropriate defaults for CIFAR-10 data."""

  def __init__(self, resnet_size, data_format=None, num_classes=NUM_CLASSES,
               resnet_version=resnet_model.DEFAULT_VERSION,
               dtype=resnet_model.DEFAULT_DTYPE):
    """These are the parameters that work for CIFAR-10 data.

    Args:
      resnet_size: The number of convolutional layers needed in the model.
      data_format: Either 'channels_first' or 'channels_last', specifying which
        data format to use when setting up the model.
      num_classes: The number of output classes needed from the model. This
        enables users to extend the same model to their own datasets.
      resnet_version: Integer representing which version of the ResNet network
        to use. See README for details. Valid values: [1, 2]
      dtype: The TensorFlow dtype to use for calculations.

    Raises:
      ValueError: if invalid resnet_size is chosen
    """
    # CIFAR ResNets have 3 stages of `n` blocks with 2 convs each, plus the
    # initial conv and the final dense layer, hence the 6n + 2 constraint.
    if resnet_size % 6 != 2:
      # Single formatted message (the original passed two args to ValueError,
      # which renders as a tuple).
      raise ValueError('resnet_size must be 6n + 2: %s' % resnet_size)

    num_blocks = (resnet_size - 2) // 6

    super(Cifar10Model, self).__init__(
        resnet_size=resnet_size,
        bottleneck=False,
        num_classes=num_classes,
        num_filters=16,
        kernel_size=3,
        conv_stride=1,
        first_pool_size=None,
        first_pool_stride=None,
        block_sizes=[num_blocks] * 3,
        block_strides=[1, 2, 2],
        resnet_version=resnet_version,
        data_format=data_format,
        dtype=dtype
    )
def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10.

  Args:
    features: Batched image tensor; reshaped here to
      [-1, HEIGHT, WIDTH, NUM_CHANNELS].
    labels: Batched label tensor.
    mode: A `tf.estimator.ModeKeys` value (TRAIN/EVAL/PREDICT).
    params: Dict of hyperparameters; reads 'batch_size', 'resnet_size',
      'data_format', 'resnet_version', 'loss_scale', 'dtype', 'fine_tune',
      and optionally 'num_workers'.

  Returns:
    The `EstimatorSpec` produced by `resnet_run_loop.resnet_model_fn`.
  """
  features = tf.reshape(features, [-1, HEIGHT, WIDTH, NUM_CHANNELS])
  # Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under.
  learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
      batch_size=params['batch_size'] * params.get('num_workers', 1),
      batch_denom=128, num_images=NUM_IMAGES['train'],
      boundary_epochs=[91, 136, 182], decay_rates=[1, 0.1, 0.01, 0.001])

  # Weight decay of 2e-4 diverges from 1e-4 decay used in the ResNet paper
  # and seems more stable in testing. The difference was nominal for ResNet-56.
  weight_decay = 2e-4

  # Empirical testing showed that including batch_normalization variables
  # in the calculation of regularized loss helped validation accuracy
  # for the CIFAR-10 dataset, perhaps because the regularization prevents
  # overfitting on the small data set. We therefore include all vars when
  # regularizing and computing loss during training.
  def loss_filter_fn(_):
    return True

  return resnet_run_loop.resnet_model_fn(
      features=features,
      labels=labels,
      mode=mode,
      model_class=Cifar10Model,
      resnet_size=params['resnet_size'],
      weight_decay=weight_decay,
      learning_rate_fn=learning_rate_fn,
      momentum=0.9,
      data_format=params['data_format'],
      resnet_version=params['resnet_version'],
      loss_scale=params['loss_scale'],
      loss_filter_fn=loss_filter_fn,
      dtype=params['dtype'],
      fine_tune=params['fine_tune']
  )
def define_cifar_flags():
  """Defines command-line flags, with defaults tuned for CIFAR-10."""
  resnet_run_loop.define_resnet_flags()
  flags.adopt_module_key_flags(resnet_run_loop)
  flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
                          model_dir='/tmp/cifar10_model',
                          resnet_size='56',
                          train_epochs=182,
                          epochs_between_evals=10,
                          batch_size=128,
                          image_bytes_as_serving_input=False)
def run_cifar(flags_obj):
  """Run ResNet CIFAR-10 training and eval loop.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of results. Including final accuracy.
  """
  if flags_obj.image_bytes_as_serving_input:
    logging.fatal(
        '--image_bytes_as_serving_input cannot be set to True for CIFAR. '
        'This flag is only applicable to ImageNet.')
    return

  # Conditional expression instead of the fragile `cond and a or b` idiom.
  input_function = (
      get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
      if flags_obj.use_synthetic_data else input_fn)
  result = resnet_run_loop.resnet_main(
      flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
      shape=[HEIGHT, WIDTH, NUM_CHANNELS])

  return result
def main(_):
  """absl.app entry point: runs CIFAR-10 inside a benchmark logging context."""
  with logger.benchmark_context(flags.FLAGS):
    run_cifar(flags.FLAGS)
if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  define_cifar_flags()
  absl_app.run(main)