imagenet_main.py 11.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
"""Runs a ResNet model on the ImageNet dataset."""
16
17
18
19
20
21

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
22
import sys
23

24
25
from absl import app as absl_app
from absl import flags
Karmel Allison's avatar
Karmel Allison committed
26
import tensorflow as tf  # pylint: disable=g-bad-import-order
27

28
from official.utils.flags import core as flags_core
29
from official.resnet import imagenet_preprocessing
30
31
from official.resnet import resnet_model
from official.resnet import resnet_run_loop
32

33
# Side length, in pixels, of the square images produced by preprocessing and
# fed to the model.
_DEFAULT_IMAGE_SIZE = 224
# Number of color channels per image.
_NUM_CHANNELS = 3
# Depth of the one-hot label vectors and default number of model outputs.
# NOTE(review): presumably the 1000 ImageNet classes plus one extra
# (background) label -- confirm against the dataset build script.
_NUM_CLASSES = 1001

# Number of examples in each dataset split.
_NUM_IMAGES = {'train': 1281167, 'validation': 50000}

# Number of training shard files; also used as the file-level shuffle buffer.
_NUM_TRAIN_FILES = 1024
# Shuffle buffer size handed to the record-processing pipeline.
_SHUFFLE_BUFFER = 1500
44

45

46
47
48
###############################################################################
# Data processing
###############################################################################
49
def get_filenames(is_training, data_dir):
  """Builds the list of sharded record file paths for one dataset split.

  Args:
    is_training: Truthy to list the training shards, falsy to list the
      validation shards.
    data_dir: Directory that contains the shard files.

  Returns:
    A list of paths, one per shard, in ascending shard order. Training uses
    `_NUM_TRAIN_FILES` shards named `train-XXXXX-of-01024`; validation uses
    128 shards named `validation-XXXXX-of-00128`.
  """
  if is_training:
    name_pattern, shard_count = 'train-%05d-of-01024', _NUM_TRAIN_FILES
  else:
    name_pattern, shard_count = 'validation-%05d-of-00128', 128

  return [os.path.join(data_dir, name_pattern % shard)
          for shard in range(shard_count)]
59
60


61
62
63
def _parse_example_proto(example_serialized):
  """Decodes one serialized Example proto into image bytes, label and boxes.

  The records are expected to come from the build_image_data.py preprocessing
  script. Each Example proto carries fields such as (values are examples):

    image/height: 462
    image/width: 581
    image/colorspace: 'RGB'
    image/channels: 3
    image/class/label: 615
    image/class/synset: 'n03623198'
    image/class/text: 'knee pad'
    image/object/bbox/xmin: 0.1
    image/object/bbox/xmax: 0.9
    image/object/bbox/ymin: 0.2
    image/object/bbox/ymax: 0.6
    image/object/bbox/label: 615
    image/format: 'JPEG'
    image/filename: 'ILSVRC2012_val_00041207.JPEG'
    image/encoded: <JPEG encoded string>

  Only the encoded image, the class label, the class text and the four
  bounding-box coordinate lists are requested from the parser; the class text
  is parsed but not returned.

  Args:
    example_serialized: scalar Tensor tf.string containing a serialized
      Example protocol buffer.

  Returns:
    image_buffer: Tensor tf.string containing the contents of a JPEG file.
    label: Tensor tf.int32 of shape [1] containing the label.
    bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
      where each coordinate is [0, 1) and the coordinates are arranged as
      [ymin, xmin, ymax, xmax].
  """
  # Keys of the variable-length bounding-box coordinate features.
  box_keys = ('image/object/bbox/xmin',
              'image/object/bbox/ymin',
              'image/object/bbox/xmax',
              'image/object/bbox/ymax')

  # Fixed-length (dense) features.
  feature_spec = {
      'image/encoded': tf.FixedLenFeature(
          [], dtype=tf.string, default_value=''),
      'image/class/label': tf.FixedLenFeature(
          [1], dtype=tf.int64, default_value=-1),
      'image/class/text': tf.FixedLenFeature(
          [], dtype=tf.string, default_value=''),
  }
  # Variable-length (sparse) features: one float list per box coordinate.
  coord_feature = tf.VarLenFeature(dtype=tf.float32)
  for key in box_keys:
    feature_spec[key] = coord_feature

  parsed = tf.parse_single_example(example_serialized, feature_spec)
  label = tf.cast(parsed['image/class/label'], dtype=tf.int32)

  # Each coordinate list becomes a [1, num_boxes] row.
  rows = {key: tf.expand_dims(parsed[key].values, 0) for key in box_keys}

  # Stack the rows in (y, x) order: ymin, xmin, ymax, xmax.
  stacked = tf.concat(
      [rows['image/object/bbox/ymin'],
       rows['image/object/bbox/xmin'],
       rows['image/object/bbox/ymax'],
       rows['image/object/bbox/xmax']], 0)

  # Reshape [coords, num_boxes] into [1, num_boxes, coords].
  stacked = tf.expand_dims(stacked, 0)
  bbox = tf.transpose(stacked, [0, 2, 1])

  return parsed['image/encoded'], label, bbox
129
130
131
132
133
134
135
136
137
138
139
140


def parse_record(raw_record, is_training):
  """Turns one serialized record into a preprocessed image and its label.

  The record is decoded with `_parse_example_proto`, the encoded image is run
  through `imagenet_preprocessing.preprocess_image` (cropping, flipping, and
  so on), and the label is converted to a one-hot vector.

  Args:
    raw_record: scalar Tensor tf.string containing a serialized
      Example protocol buffer.
    is_training: A boolean denoting whether the input is for training.

  Returns:
    Tuple with processed image tensor and one-hot-encoded label tensor of
    depth `_NUM_CLASSES`.
  """
  encoded_image, raw_label, boxes = _parse_example_proto(raw_record)

  processed_image = imagenet_preprocessing.preprocess_image(
      image_buffer=encoded_image,
      bbox=boxes,
      output_height=_DEFAULT_IMAGE_SIZE,
      output_width=_DEFAULT_IMAGE_SIZE,
      num_channels=_NUM_CHANNELS,
      is_training=is_training)

  # The parsed label has shape [1]; collapse it to a scalar before encoding.
  scalar_label = tf.reshape(raw_label, shape=[])
  one_hot_label = tf.one_hot(scalar_label, _NUM_CLASSES)

  return processed_image, one_hot_label
158
159


160
161
def input_fn(is_training, data_dir, batch_size, num_epochs=1,
             num_parallel_calls=1, multi_gpu=False):
  """Input function which provides batches for train or eval.

  Args:
    is_training: A boolean denoting whether the input is for training.
    data_dir: The directory containing the input data.
    batch_size: The number of samples per batch.
    num_epochs: The number of epochs to repeat the dataset.
    num_parallel_calls: The number of records that are processed in parallel.
      This can be optimized per data set but for generally homogeneous data
      sets, should be approximately the number of available CPU cores.
    multi_gpu: Whether this is run multi-GPU. Note that this is only required
      currently to handle the batch leftovers, and can be removed
      when that is handled directly by Estimator.

  Returns:
    A dataset that can be used for iteration.
  """
  filenames = get_filenames(is_training, data_dir)
  dataset = tf.data.Dataset.from_tensor_slices(filenames)

  if is_training:
    # Shuffle the input files; the buffer covers every training shard.
    dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)

  # Use a conditional expression rather than `cond and a or b`, which picks
  # the wrong operand whenever `a` happens to be falsy.
  num_images = _NUM_IMAGES['train' if is_training else 'validation']

  # Convert to individual records
  dataset = dataset.flat_map(tf.data.TFRecordDataset)

  return resnet_run_loop.process_record_dataset(
      dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record,
      num_epochs, num_parallel_calls, examples_per_epoch=num_images,
      multi_gpu=multi_gpu)
195
196
197


def get_synth_input_fn():
  """Returns the synthetic-data input function for ImageNet-sized inputs.

  Delegates to `resnet_run_loop.get_synth_input_fn`, passing the default
  image size twice, the channel count and the number of classes.
  """
  image_side = _DEFAULT_IMAGE_SIZE
  return resnet_run_loop.get_synth_input_fn(
      image_side, image_side, _NUM_CHANNELS, _NUM_CLASSES)
200
201


202
203
204
###############################################################################
# Running the model
###############################################################################
205
class ImagenetModel(resnet_model.Model):
  """ResNet model preconfigured with settings suited to ImageNet data."""

  def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
               version=resnet_model.DEFAULT_VERSION,
               dtype=resnet_model.DEFAULT_DTYPE):
    """Configures the base model with the ImageNet hyperparameters.

    Args:
      resnet_size: The number of convolutional layers needed in the model.
      data_format: Either 'channels_first' or 'channels_last', specifying which
        data format to use when setting up the model.
      num_classes: The number of output classes needed from the model. This
        enables users to extend the same model to their own datasets.
      version: Integer representing which version of the ResNet network to use.
        See README for details. Valid values: [1, 2]
      dtype: The TensorFlow dtype to use for calculations.
    """
    # Models below 50 layers use plain blocks with a 512-wide final layer;
    # larger models switch to "bottleneck" blocks with a 2048-wide one.
    if resnet_size < 50:
      bottleneck, final_size = False, 512
    else:
      bottleneck, final_size = True, 2048

    block_sizes = _get_block_sizes(resnet_size)

    super(ImagenetModel, self).__init__(
        resnet_size=resnet_size,
        bottleneck=bottleneck,
        num_classes=num_classes,
        num_filters=64,
        kernel_size=7,
        conv_stride=2,
        first_pool_size=3,
        first_pool_stride=2,
        block_sizes=block_sizes,
        block_strides=[1, 2, 2, 2],
        final_size=final_size,
        version=version,
        data_format=data_format,
        dtype=dtype
    )
248
249
250


def _get_block_sizes(resnet_size):
  """Retrieve the size of each block_layer in the ResNet model.

  The number of block layers used for the Resnet model varies according
  to the size of the model. This helper grabs the layer set we want, throwing
  an error if a non-standard size has been selected.

  Args:
    resnet_size: The number of convolutional layers needed in the model.

  Returns:
    A list of block sizes to use in building the model.

  Raises:
    ValueError: if invalid resnet_size is received.
  """
  choices = {
      18: [2, 2, 2, 2],
      34: [3, 4, 6, 3],
      50: [3, 4, 6, 3],
      101: [3, 4, 23, 3],
      152: [3, 8, 36, 3],
      200: [3, 24, 36, 3]
  }

  try:
    return choices[resnet_size]
  except KeyError:
    # List the allowed sizes as a sorted plain list so the message is stable
    # and readable on both Python 2 and 3 (no `dict_keys(...)` wrapper).
    err = ('Could not find layers for selected Resnet size.\n'
           'Size received: {}; sizes allowed: {}.'.format(
               resnet_size, sorted(choices)))
    raise ValueError(err)
282
283


284
285
def imagenet_model_fn(features, labels, mode, params):
  """Estimator model_fn that builds the ImageNet ResNet.

  Creates a learning-rate function with
  `resnet_run_loop.learning_rate_with_decay` (boundaries at epochs 30, 60, 80
  and 90) and forwards everything to `resnet_run_loop.resnet_model_fn` with
  `ImagenetModel` as the model class.

  Args:
    features: Passed through unchanged to `resnet_model_fn`.
    labels: Passed through unchanged to `resnet_model_fn`.
    mode: Passed through unchanged to `resnet_model_fn`.
    params: Dict that must provide 'batch_size', 'resnet_size', 'data_format',
      'version', 'loss_scale', 'multi_gpu' and 'dtype'.

  Returns:
    Whatever `resnet_run_loop.resnet_model_fn` returns.
  """
  lr_schedule = resnet_run_loop.learning_rate_with_decay(
      batch_size=params['batch_size'],
      batch_denom=256,
      num_images=_NUM_IMAGES['train'],
      boundary_epochs=[30, 60, 80, 90],
      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])

  return resnet_run_loop.resnet_model_fn(
      features=features,
      labels=labels,
      mode=mode,
      model_class=ImagenetModel,
      resnet_size=params['resnet_size'],
      weight_decay=1e-4,
      learning_rate_fn=lr_schedule,
      momentum=0.9,
      data_format=params['data_format'],
      version=params['version'],
      loss_scale=params['loss_scale'],
      loss_filter_fn=None,
      multi_gpu=params['multi_gpu'],
      dtype=params['dtype']
  )
307
308


309
310
311
312
313
def define_imagenet_flags():
  """Registers the command-line flags used by this script.

  Defines the shared ResNet flags with the ImageNet model sizes as the valid
  `resnet_size` choices, adopts the run loop's key flags for this module, and
  sets the default number of training epochs to 100.
  """
  size_choices = ['18', '34', '50', '101', '152', '200']
  resnet_run_loop.define_resnet_flags(resnet_size_choices=size_choices)
  flags.adopt_module_key_flags(resnet_run_loop)
  flags_core.set_defaults(train_epochs=100)
314

315

316
317
318
def main(flags_obj):
  """Runs the ResNet training/evaluation loop on ImageNet.

  Args:
    flags_obj: Parsed flags object. Its `use_synthetic_data` attribute selects
      between the synthetic input function and the real `input_fn`; the whole
      object is forwarded to `resnet_run_loop.resnet_main`.
  """
  # Branch explicitly instead of `cond and f() or g`, which would silently
  # fall back to the real input function if the factory returned a falsy
  # value.
  if flags_obj.use_synthetic_data:
    input_function = get_synth_input_fn()
  else:
    input_function = input_fn

  resnet_run_loop.resnet_main(
      flags_obj, imagenet_model_fn, input_function,
      shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS])
323
324
325
326


if __name__ == '__main__':
  # Script entry point: enable INFO-level logging, register the flags, then
  # let absl parse the command line and invoke `main`.
  tf.logging.set_verbosity(tf.logging.INFO)
  define_imagenet_flags()
  absl_app.run(main)