# imagenet_main.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
22
import sys
23
24
25
26

import tensorflow as tf

import resnet_model
27
import resnet_shared
28
29
import vgg_preprocessing

30
# Height and width of the images produced by preprocessing.
_DEFAULT_IMAGE_SIZE = 224
# Channels requested from the image decoder.
_NUM_CHANNELS = 3
# Width of the one-hot label vector and of the model's output layer.
# NOTE(review): 1001 rather than 1000 — presumably label 0 is reserved
# (e.g. background); confirm against the TFRecords' label mapping.
_NUM_CLASSES = 1001

# Number of examples in each split; the training count also drives the
# learning-rate schedule.
_NUM_IMAGES = {
    'train': 1281167,
    'validation': 50000,
}

# Buffer used to shuffle the order of the TFRecord shard files.
_FILE_SHUFFLE_BUFFER = 1024
# Buffer used to shuffle individual parsed records.
_SHUFFLE_BUFFER = 1500
41

42

43
44
45
###############################################################################
# Data processing
###############################################################################
46
def filenames(is_training, data_dir):
  """Return the TFRecord shard paths for one ImageNet split.

  Args:
    is_training: If True, list the 1024 training shards; otherwise list the
      128 validation shards.
    data_dir: Directory that contains the shard files.

  Returns:
    A list of paths named `<split>-NNNNN-of-MMMMM` joined onto `data_dir`,
    ordered by shard index.
  """
  if is_training:
    return [
        os.path.join(data_dir, 'train-%05d-of-01024' % i)
        for i in range(1024)]
  else:
    return [
        os.path.join(data_dir, 'validation-%05d-of-00128' % i)
        for i in range(128)]
56
57


58
def parse_record(raw_record, is_training):
  """Parse one serialized ImageNet example into an (image, label) pair.

  Args:
    raw_record: String tensor holding one serialized `tf.train.Example`
      (one element of a `TFRecordDataset`).
    is_training: Forwarded to `vgg_preprocessing.preprocess_image` to select
      training-time or evaluation-time preprocessing.

  Returns:
    A tuple `(image, label)`:
      image: float32 image tensor after preprocessing, with
        `_DEFAULT_IMAGE_SIZE` passed as the output height and width.
      label: one-hot float vector of length `_NUM_CLASSES`.
  """
  # Only 'image/encoded' and 'image/class/label' are consumed below; the
  # other features are parsed but currently unused.
  keys_to_features = {
      'image/encoded':
          tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format':
          tf.FixedLenFeature((), tf.string, default_value='jpeg'),
      'image/class/label':
          tf.FixedLenFeature([], dtype=tf.int64, default_value=-1),
      'image/class/text':
          tf.FixedLenFeature([], dtype=tf.string, default_value=''),
      'image/object/bbox/xmin':
          tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/ymin':
          tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/xmax':
          tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/ymax':
          tf.VarLenFeature(dtype=tf.float32),
      'image/object/class/label':
          tf.VarLenFeature(dtype=tf.int64),
  }

  parsed = tf.parse_single_example(raw_record, keys_to_features)

  # decode_image takes a scalar string, hence the reshape to shape [].
  image = tf.image.decode_image(
      tf.reshape(parsed['image/encoded'], shape=[]),
      _NUM_CHANNELS)
  # convert_image_dtype rescales integer pixel values into [0, 1] floats.
  image = tf.image.convert_image_dtype(image, dtype=tf.float32)

  image = vgg_preprocessing.preprocess_image(
      image=image,
      output_height=_DEFAULT_IMAGE_SIZE,
      output_width=_DEFAULT_IMAGE_SIZE,
      is_training=is_training)

  label = tf.cast(
      tf.reshape(parsed['image/class/label'], shape=[]),
      dtype=tf.int32)

  # A missing label falls back to -1 (see keys_to_features); tf.one_hot maps
  # such an out-of-range index to an all-zero vector.
  return image, tf.one_hot(label, _NUM_CLASSES)
99
100


101
def input_fn(is_training, data_dir, batch_size, num_epochs=1):
  """Build the ImageNet `tf.data` pipeline and return one batch of tensors.

  Args:
    is_training: If True, read the training shards and shuffle both the shard
      order and the records; otherwise read the validation shards unshuffled.
    data_dir: Directory containing the TFRecord shards.
    batch_size: Number of examples per batch (also used as the prefetch size).
    num_epochs: Number of passes over the data, forwarded to `repeat`.

  Returns:
    A tuple `(images, labels)` from `get_next()` on a one-shot iterator.
  """
  dataset = tf.data.Dataset.from_tensor_slices(
      filenames(is_training, data_dir))

  if is_training:
    # Randomize the order in which shard files are read.
    dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)

  # Turn the stream of file names into one flat stream of serialized records.
  dataset = dataset.flat_map(tf.data.TFRecordDataset)
  dataset = dataset.map(lambda value: parse_record(value, is_training),
                        num_parallel_calls=5)
  dataset = dataset.prefetch(batch_size)

  if is_training:
    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes have better performance.
    dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)

  iterator = dataset.make_one_shot_iterator()
  images, labels = iterator.get_next()
  return images, labels


129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
###############################################################################
# Running the model
###############################################################################
class ImagenetModel(resnet_model.Model):
  """`resnet_model.Model` preset with the hyperparameters used for ImageNet."""

  def __init__(self, resnet_size, data_format=None):
    """Configure the ResNet for ImageNet at the requested depth.

    Args:
      resnet_size: Network depth. Chooses the block type here and, through
        `_get_block_sizes`, the number of blocks per stage.
      data_format: Passed through unchanged to `resnet_model.Model`.
    """
    # Depths below 50 use plain building blocks with a final size of 512;
    # deeper networks use bottleneck blocks with a final size of 2048.
    if resnet_size < 50:
      block_fn, final_size = resnet_model.building_block, 512
    else:
      block_fn, final_size = resnet_model.bottleneck_block, 2048

    model_config = dict(
        resnet_size=resnet_size,
        num_classes=_NUM_CLASSES,
        num_filters=64,
        kernel_size=7,
        conv_stride=2,
        first_pool_size=3,
        first_pool_stride=2,
        second_pool_size=7,
        second_pool_stride=1,
        block_fn=block_fn,
        block_sizes=_get_block_sizes(resnet_size),
        block_strides=[1, 2, 2, 2],
        final_size=final_size,
        data_format=data_format)
    super(ImagenetModel, self).__init__(**model_config)


def _get_block_sizes(resnet_size):
  """Return the number of blocks in each of the four ResNet stages.

  The per-stage block counts depend on the overall depth of the model; only
  the standard depths listed below are supported.

  Args:
    resnet_size: Total ResNet depth (18, 34, 50, 101, 152 or 200).

  Returns:
    A new list of four ints, one block count per stage.

  Raises:
    ValueError: If `resnet_size` is not one of the supported depths.
  """
  choices = {
      18: [2, 2, 2, 2],
      34: [3, 4, 6, 3],
      50: [3, 4, 6, 3],
      101: [3, 4, 23, 3],
      152: [3, 8, 36, 3],
      200: [3, 24, 36, 3]
  }

  # Test membership up front instead of catching KeyError, so on Python 3 the
  # ValueError is not chained to an internal KeyError traceback. sorted()
  # gives a stable, readable list instead of a `dict_keys(...)` repr.
  if resnet_size not in choices:
    err = ('Could not find layers for selected Resnet size.\n'
           'Size received: {}; sizes allowed: {}.'.format(
               resnet_size, sorted(choices)))
    raise ValueError(err)
  return choices[resnet_size]
183
184


185
186
187
188
189
190
def imagenet_model_fn(features, labels, mode, params):
  """Our model_fn for ResNet to be used with our Estimator.

  Args:
    features: Batched input images (presumably those produced by `input_fn`).
    labels: Batched one-hot labels (presumably those produced by `input_fn`).
    mode: Estimator mode, forwarded unchanged.
    params: Dict read for the keys 'batch_size', 'resnet_size' and
      'data_format'.

  Returns:
    Whatever `resnet_shared.resnet_model_fn` returns (presumably a
    `tf.estimator.EstimatorSpec` — confirm in resnet_shared).
  """
  # Piecewise schedule: multipliers 1, 0.1, 0.01, 0.001, 1e-4 for the
  # intervals separated by epochs 30, 60, 80 and 90. batch_denom=256
  # presumably scales the base rate by batch_size / 256 — confirm in
  # resnet_shared.learning_rate_with_decay.
  learning_rate_fn = resnet_shared.learning_rate_with_decay(
      batch_size=params['batch_size'], batch_denom=256,
      num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])

  return resnet_shared.resnet_model_fn(features, labels, mode, ImagenetModel,
                                       resnet_size=params['resnet_size'],
                                       weight_decay=1e-4,
                                       learning_rate_fn=learning_rate_fn,
                                       momentum=0.9,
                                       data_format=params['data_format'],
                                       loss_filter_fn=None)
199
200
201


def main(unused_argv):
  """Entry point for `tf.app.run`: delegate to `resnet_shared.resnet_main`.

  Args:
    unused_argv: Remaining command-line arguments; ignored.
  """
  # FLAGS is the module-level global assigned in the `__main__` guard before
  # tf.app.run calls this function.
  resnet_shared.resnet_main(FLAGS, imagenet_model_fn, input_fn)
203
204
205
206


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)

  parser = resnet_shared.ResnetArgParser(
      resnet_size_choices=[18, 34, 50, 101, 152, 200])
  # Known flags go into the module-level FLAGS read by `main`; anything the
  # parser does not recognize is forwarded to tf.app.run.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(argv=[sys.argv[0]] + unparsed)