cifar10_main.py 10.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
"""Runs a ResNet model on the CIFAR-10 dataset."""
16
17
18
19
20
21
22

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

23
24
from absl import app as absl_app
from absl import flags
25
from absl import logging
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
26
from six.moves import range
27
import tensorflow as tf
28

29
30
from official.r1.resnet import resnet_model
from official.r1.resnet import resnet_run_loop
31
from official.r1.utils.logs import logger
32
from official.utils.flags import core as flags_core
33

34
35
36
37
# Input geometry: CIFAR-10 images are 32x32 RGB.
HEIGHT = 32
WIDTH = 32
NUM_CHANNELS = 3
_DEFAULT_IMAGE_BYTES = HEIGHT * WIDTH * NUM_CHANNELS

# The record is the image plus a one-byte label
_RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1
NUM_CLASSES = 10

# Number of binary training-data files in the extracted dataset
# (data_batch_1.bin .. data_batch_5.bin).
_NUM_DATA_FILES = 5

# TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.
NUM_IMAGES = {
    'train': 50000,
    'validation': 10000,
}

DATASET_NAME = 'CIFAR-10'

51

52
53
54
###############################################################################
# Data processing
###############################################################################
55
def get_filenames(is_training, data_dir):
  """Returns a list of CIFAR-10 binary data filenames.

  Args:
    is_training: A boolean denoting whether to return the training files or
      the test file.
    data_dir: Directory containing the extracted CIFAR-10 binary files.

  Returns:
    A list of file paths: the five training batches when `is_training` is
    True, otherwise the single test batch.

  Raises:
    ValueError: if `data_dir` does not exist.
  """
  # Validate with an explicit raise rather than `assert`, which is silently
  # stripped when Python runs with the -O flag.
  if not tf.io.gfile.exists(data_dir):
    raise ValueError(
        'Run cifar10_download_and_extract.py first to download and extract the '
        'CIFAR-10 data.')

  if is_training:
    return [
        os.path.join(data_dir, 'data_batch_%d.bin' % i)
        for i in range(1, _NUM_DATA_FILES + 1)
    ]
  return [os.path.join(data_dir, 'test_batch.bin')]
68
69


70
def parse_record(raw_record, is_training, dtype):
  """Parses a single CIFAR-10 record into an (image, label) pair.

  Args:
    raw_record: A scalar string tensor holding one fixed-length record.
    is_training: A boolean denoting whether the input is for training.
    dtype: Data type the returned image is cast to.

  Returns:
    A tuple of (image, label) tensors.
  """
  # Decode the raw bytes into a flat uint8 vector that is record_bytes long.
  record_as_bytes = tf.io.decode_raw(raw_record, tf.uint8)

  # The leading byte is the label; widen it from uint8 to int32.
  label = tf.cast(record_as_bytes[0], tf.int32)

  # The remaining bytes are the image in channels-first layout, i.e.
  # [depth * height * width] reshaped to [depth, height, width].
  image_chw = tf.reshape(record_as_bytes[1:_RECORD_BYTES],
                         [NUM_CHANNELS, HEIGHT, WIDTH])

  # Move to [height, width, depth] layout and convert to float32 for
  # preprocessing.
  image_hwc = tf.cast(tf.transpose(a=image_chw, perm=[1, 2, 0]), tf.float32)

  image = tf.cast(preprocess_image(image_hwc, is_training), dtype)

  return image, label
92
93


94
95
96
97
def preprocess_image(image, is_training):
  """Preprocess a single image of layout [height, width, depth]."""
  if is_training:
    # Pad out to (HEIGHT + 8) x (WIDTH + 8) so the random crop below can
    # shift the image by up to four pixels in any direction.
    padded = tf.image.resize_with_crop_or_pad(
        image, HEIGHT + 8, WIDTH + 8)

    # Randomly crop a [HEIGHT, WIDTH] window out of the padded image.
    cropped = tf.image.random_crop(padded, [HEIGHT, WIDTH, NUM_CHANNELS])

    # Randomly mirror the image horizontally.
    image = tf.image.random_flip_left_right(cropped)

  # Normalize: subtract the per-image mean and divide by the pixel stddev.
  return tf.image.per_image_standardization(image)
Kathy Wu's avatar
Kathy Wu committed
110
111


112
113
114
115
116
117
118
def input_fn(is_training,
             data_dir,
             batch_size,
             num_epochs=1,
             dtype=tf.float32,
             datasets_num_private_threads=None,
             parse_record_fn=parse_record,
             input_context=None,
             drop_remainder=False):
  """Input function which provides batches for train or eval.

  Args:
    is_training: A boolean denoting whether the input is for training.
    data_dir: The directory containing the input data.
    batch_size: The number of samples per batch.
    num_epochs: The number of epochs to repeat the dataset.
    dtype: Data type to use for images/features
    datasets_num_private_threads: Number of private threads for tf.data.
    parse_record_fn: Function to use for parsing the records.
    input_context: A `tf.distribute.InputContext` object passed in by
      `tf.distribute.Strategy`.
    drop_remainder: A boolean indicates whether to drop the remainder of the
      batches. If True, the batch dimension will be static.

  Returns:
    A dataset that can be used for iteration.
  """
  # Each record is a fixed-size (label byte + image bytes) chunk of the
  # binary files, so a FixedLengthRecordDataset reads them directly.
  record_files = get_filenames(is_training, data_dir)
  raw_dataset = tf.data.FixedLengthRecordDataset(record_files, _RECORD_BYTES)

  if input_context:
    # Under tf.distribute, each input pipeline only reads its own shard.
    logging.info(
        'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
        input_context.input_pipeline_id, input_context.num_input_pipelines)
    raw_dataset = raw_dataset.shard(input_context.num_input_pipelines,
                                    input_context.input_pipeline_id)

  # Shared shuffle/parse/batch/prefetch pipeline lives in resnet_run_loop.
  return resnet_run_loop.process_record_dataset(
      dataset=raw_dataset,
      is_training=is_training,
      batch_size=batch_size,
      shuffle_buffer=NUM_IMAGES['train'],
      parse_record_fn=parse_record_fn,
      num_epochs=num_epochs,
      dtype=dtype,
      datasets_num_private_threads=datasets_num_private_threads,
      drop_remainder=drop_remainder
  )
160
161


Toby Boyd's avatar
Toby Boyd committed
162
def get_synth_input_fn(dtype):
  """Returns an input function that yields synthetic CIFAR-10-shaped data."""
  # Delegate to the shared helper, fixing the CIFAR-10 geometry and classes.
  return resnet_run_loop.get_synth_input_fn(
      HEIGHT,
      WIDTH,
      NUM_CHANNELS,
      NUM_CLASSES,
      dtype=dtype)
165
166


167
168
169
###############################################################################
# Running the model
###############################################################################
170
class Cifar10Model(resnet_model.Model):
  """Model class with appropriate defaults for CIFAR-10 data."""

  def __init__(self, resnet_size, data_format=None, num_classes=NUM_CLASSES,
               resnet_version=resnet_model.DEFAULT_VERSION,
               dtype=resnet_model.DEFAULT_DTYPE):
    """These are the parameters that work for CIFAR-10 data.

    Args:
      resnet_size: The number of convolutional layers needed in the model.
      data_format: Either 'channels_first' or 'channels_last', specifying which
        data format to use when setting up the model.
      num_classes: The number of output classes needed from the model. This
        enables users to extend the same model to their own datasets.
      resnet_version: Integer representing which version of the ResNet network
        to use. See README for details. Valid values: [1, 2]
      dtype: The TensorFlow dtype to use for calculations.

    Raises:
      ValueError: if invalid resnet_size is chosen
    """
    # CIFAR-10 ResNets are built from 3 stages of n residual blocks each plus
    # an initial conv and a final dense layer, so depth must satisfy 6n + 2.
    if resnet_size % 6 != 2:
      # Format the offending value into the message instead of passing it as a
      # second exception arg, which renders as an awkward tuple.
      raise ValueError('resnet_size must be 6n + 2: %s' % resnet_size)

    num_blocks = (resnet_size - 2) // 6

    super(Cifar10Model, self).__init__(
        resnet_size=resnet_size,
        bottleneck=False,
        num_classes=num_classes,
        num_filters=16,
        kernel_size=3,
        conv_stride=1,
        first_pool_size=None,
        first_pool_stride=None,
        block_sizes=[num_blocks] * 3,
        block_strides=[1, 2, 2],
        resnet_version=resnet_version,
        data_format=data_format,
        dtype=dtype
    )
211
212


213
214
def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10."""
  # Flatten whatever batch layout arrives into [batch, H, W, C].
  features = tf.reshape(features, [-1, HEIGHT, WIDTH, NUM_CHANNELS])

  # Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under.
  global_batch_size = params['batch_size'] * params.get('num_workers', 1)
  learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
      batch_size=global_batch_size,
      batch_denom=128,
      num_images=NUM_IMAGES['train'],
      boundary_epochs=[91, 136, 182],
      decay_rates=[1, 0.1, 0.01, 0.001])

  # Weight decay of 2e-4 diverges from 1e-4 decay used in the ResNet paper
  # and seems more stable in testing. The difference was nominal for ResNet-56.
  weight_decay = 2e-4

  # Empirical testing showed that including batch_normalization variables
  # in the calculation of regularized loss helped validation accuracy
  # for the CIFAR-10 dataset, perhaps because the regularization prevents
  # overfitting on the small data set. We therefore include all vars when
  # regularizing and computing loss during training.
  def loss_filter_fn(_):
    return True

  return resnet_run_loop.resnet_model_fn(
      features=features,
      labels=labels,
      mode=mode,
      model_class=Cifar10Model,
      resnet_size=params['resnet_size'],
      weight_decay=weight_decay,
      learning_rate_fn=learning_rate_fn,
      momentum=0.9,
      data_format=params['data_format'],
      resnet_version=params['resnet_version'],
      loss_scale=params['loss_scale'],
      loss_filter_fn=loss_filter_fn,
      dtype=params['dtype'],
      fine_tune=params['fine_tune']
  )
250
251


252
253
254
def define_cifar_flags():
  """Registers ResNet flags and sets CIFAR-10-specific defaults."""
  resnet_run_loop.define_resnet_flags()
  flags.adopt_module_key_flags(resnet_run_loop)
  # Defaults chosen to reproduce the standard ResNet-56 CIFAR-10 recipe.
  flags_core.set_defaults(
      data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
      model_dir='/tmp/cifar10_model',
      resnet_size='56',
      train_epochs=182,
      epochs_between_evals=10,
      batch_size=128,
      image_bytes_as_serving_input=False)
262

263

264
265
266
267
268
def run_cifar(flags_obj):
  """Run ResNet CIFAR-10 training and eval loop.

  Args:
    flags_obj: An object containing parsed flag values.

  Returns:
    Dictionary of results. Including final accuracy.
  """
  if flags_obj.image_bytes_as_serving_input:
    logging.fatal(
        '--image_bytes_as_serving_input cannot be set to True for CIFAR. '
        'This flag is only applicable to ImageNet.')
    return

  # Use a conditional expression instead of the fragile `cond and a or b`
  # idiom, which silently picks `b` whenever `a` is falsy.
  input_function = (
      get_synth_input_fn(flags_core.get_tf_dtype(flags_obj))
      if flags_obj.use_synthetic_data else input_fn)

  result = resnet_run_loop.resnet_main(
      flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
      shape=[HEIGHT, WIDTH, NUM_CHANNELS])

  return result

288

289
def main(_):
  """Entry point: runs CIFAR-10 training/eval inside a benchmark context."""
  parsed_flags = flags.FLAGS
  with logger.benchmark_context(parsed_flags):
    run_cifar(parsed_flags)
292
293


294
if __name__ == '__main__':
  # Flags must be defined before absl_app.run parses argv and calls main.
  logging.set_verbosity(logging.INFO)
  define_cifar_flags()
  absl_app.run(main)