imagenet_main.py 6.72 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
"""Runs a ResNet model on the ImageNet dataset."""
16
17
18
19
20
21

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
22
import sys
23
24
25

import tensorflow as tf

Karmel Allison's avatar
Karmel Allison committed
26
import resnet
27
28
import vgg_preprocessing

29
# Height and width every image is preprocessed to.
_DEFAULT_IMAGE_SIZE = 224
# Channel count requested from the image decoder.
_NUM_CHANNELS = 3
# Length of the one-hot label vectors. NOTE(review): presumably the 1000
# ImageNet classes plus a background class at index 0 -- confirm against the
# label scheme used when the TFRecords were built.
_NUM_CLASSES = 1001

# Images per dataset split; the 'train' count feeds the learning-rate
# schedule built in imagenet_model_fn.
_NUM_IMAGES = {
    'train': 1281167,
    'validation': 50000,
}

# Shuffle buffer sizes used only when training: the first for the list of
# shard files, the second for the stream of parsed records.
_FILE_SHUFFLE_BUFFER = 1024
_SHUFFLE_BUFFER = 1500
40

41

42
43
44
###############################################################################
# Data processing
###############################################################################
45
def filenames(is_training, data_dir):
  """Return the TFRecord shard paths for one ImageNet split.

  Args:
    is_training: if True, list the training shards; otherwise list the
      validation shards.
    data_dir: directory that holds the shard files.

  Returns:
    A list of paths inside `data_dir`: 1024 names of the form
    `train-NNNNN-of-01024` for training, or 128 names of the form
    `validation-NNNNN-of-00128` for validation, in shard order.
  """
  if is_training:
    pattern, num_shards = 'train-%05d-of-01024', 1024
  else:
    pattern, num_shards = 'validation-%05d-of-00128', 128
  return [os.path.join(data_dir, pattern % shard)
          for shard in range(num_shards)]
55
56


57
def parse_record(raw_record, is_training):
  """Parse one serialized ImageNet example into an (image, label) pair.

  Args:
    raw_record: scalar string tensor holding a single serialized
      `tf.train.Example` (the input expected by `tf.parse_single_example`).
    is_training: forwarded to `vgg_preprocessing.preprocess_image` to choose
      between training and evaluation preprocessing.

  Returns:
    A tuple `(image, label)`. `image` is the float32 image returned by
    `vgg_preprocessing.preprocess_image` with output height/width
    `_DEFAULT_IMAGE_SIZE`; `label` is the class id one-hot encoded over
    `_NUM_CLASSES` classes. A record with no label gets the default -1,
    which `tf.one_hot` turns into an all-zeros vector.
  """
  # Only 'image/encoded' and 'image/class/label' are used below; the other
  # entries are parsed but not consumed by this function.
  feature_spec = {
      'image/encoded':
          tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format':
          tf.FixedLenFeature((), tf.string, default_value='jpeg'),
      'image/class/label':
          tf.FixedLenFeature([], dtype=tf.int64, default_value=-1),
      'image/class/text':
          tf.FixedLenFeature([], dtype=tf.string, default_value=''),
      'image/object/bbox/xmin':
          tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/ymin':
          tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/xmax':
          tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/ymax':
          tf.VarLenFeature(dtype=tf.float32),
      'image/object/class/label':
          tf.VarLenFeature(dtype=tf.int64),
  }
  features = tf.parse_single_example(raw_record, feature_spec)

  # NOTE(review): decode_image does not set a static shape and yields a 4-D
  # tensor for GIF input; this relies on the records holding still images.
  encoded_bytes = tf.reshape(features['image/encoded'], shape=[])
  decoded = tf.image.decode_image(encoded_bytes, _NUM_CHANNELS)
  float_image = tf.image.convert_image_dtype(decoded, dtype=tf.float32)
  processed_image = vgg_preprocessing.preprocess_image(
      image=float_image,
      output_height=_DEFAULT_IMAGE_SIZE,
      output_width=_DEFAULT_IMAGE_SIZE,
      is_training=is_training)

  scalar_label = tf.reshape(features['image/class/label'], shape=[])
  class_id = tf.cast(scalar_label, dtype=tf.int32)
  return processed_image, tf.one_hot(class_id, _NUM_CLASSES)
98
99


100
def input_fn(is_training, data_dir, batch_size, num_epochs=1,
             num_parallel_calls=5):
  """Input function which provides batches for train or eval.

  Args:
    is_training: if True, read the training shards and shuffle both the
      shard order and the parsed records; otherwise read the validation
      shards in order.
    data_dir: directory containing the ImageNet TFRecord shards.
    batch_size: number of examples per returned batch; also used as the
      prefetch buffer size for parsed records.
    num_epochs: number of times the dataset is repeated.
    num_parallel_calls: number of records `parse_record` processes in
      parallel. Defaults to 5, the value previously hard-coded here; tune it
      to the number of available CPU cores.

  Returns:
    A tuple `(images, labels)` of tensors yielding the next batch from a
    one-shot iterator. The final batch may hold fewer than `batch_size`
    examples because the remainder is not dropped.
  """
  dataset = tf.data.Dataset.from_tensor_slices(
      filenames(is_training, data_dir))

  if is_training:
    # Randomize the order in which shard files are read.
    dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)

  # Read the shards one after another as a single stream of serialized
  # records.
  dataset = dataset.flat_map(tf.data.TFRecordDataset)
  dataset = dataset.map(lambda value: parse_record(value, is_training),
                        num_parallel_calls=num_parallel_calls)
  # Keep up to one batch worth of parsed records ready ahead of the
  # shuffle/batch stages.
  dataset = dataset.prefetch(batch_size)

  if is_training:
    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes have better performance.
    dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)

  iterator = dataset.make_one_shot_iterator()
  images, labels = iterator.get_next()
  return images, labels


128
129
130
###############################################################################
# Running the model
###############################################################################
Karmel Allison's avatar
Karmel Allison committed
131
class ImagenetModel(resnet.Model):
  """A `resnet.Model` preconfigured with the ImageNet hyperparameters."""

  def __init__(self, resnet_size, data_format=None):
    """Configure a ResNet of depth `resnet_size` for ImageNet.

    Args:
      resnet_size: network depth; must be one of the sizes known to
        `_get_block_sizes` (18, 34, 50, 101, 152, 200), otherwise that
        helper raises ValueError.
      data_format: forwarded unchanged to `resnet.Model`.
    """
    # Networks shallower than 50 layers use plain building blocks with a
    # 512-wide final layer; deeper ones use "bottleneck" blocks and 2048.
    shallow = resnet_size < 50
    block_fn = resnet.building_block if shallow else resnet.bottleneck_block
    final_size = 512 if shallow else 2048
    layer_blocks = _get_block_sizes(resnet_size)

    super(ImagenetModel, self).__init__(
        resnet_size=resnet_size,
        num_classes=_NUM_CLASSES,
        num_filters=64,
        kernel_size=7,
        conv_stride=2,
        first_pool_size=3,
        first_pool_stride=2,
        second_pool_size=7,
        second_pool_stride=1,
        block_fn=block_fn,
        block_sizes=layer_blocks,
        block_strides=[1, 2, 2, 2],
        final_size=final_size,
        data_format=data_format)


def _get_block_sizes(resnet_size):
  """Return the number of blocks in each of the four block layers.

  The number of block layers used for the Resnet model varies according
  to the size of the model. This helper grabs the layer set we want,
  throwing an error if a non-standard size has been selected.

  Args:
    resnet_size: ResNet depth; one of 18, 34, 50, 101, 152 or 200.

  Returns:
    A list of four ints, the number of blocks in each block layer.

  Raises:
    ValueError: if `resnet_size` is not one of the supported depths.
  """
  choices = {
      18: [2, 2, 2, 2],
      34: [3, 4, 6, 3],
      50: [3, 4, 6, 3],
      101: [3, 4, 23, 3],
      152: [3, 8, 36, 3],
      200: [3, 24, 36, 3]
  }

  # Check membership up front rather than translating a KeyError: raising
  # from inside an `except` clause chains the KeyError into the traceback on
  # Python 3, and `raise ... from None` is not valid Python 2 syntax (this
  # file still targets Python 2 via its __future__ imports).
  if resnet_size not in choices:
    # sorted() gives a stable, readable list instead of the `dict_keys(...)`
    # repr (Python 3) or arbitrary key order (Python 2).
    err = ('Could not find layers for selected Resnet size.\n'
           'Size received: {}; sizes allowed: {}.'.format(
               resnet_size, sorted(choices)))
    raise ValueError(err)
  return choices[resnet_size]
182
183


184
185
def imagenet_model_fn(features, labels, mode, params):
  """Estimator model_fn that builds an ImagenetModel via resnet_model_fn.

  Args:
    features: input features, forwarded to `resnet.resnet_model_fn`.
    labels: labels, forwarded to `resnet.resnet_model_fn`.
    mode: Estimator mode, forwarded to `resnet.resnet_model_fn`.
    params: dict that must provide 'batch_size', 'resnet_size' and
      'data_format'.

  Returns:
    Whatever `resnet.resnet_model_fn` returns for these arguments.
  """
  # Decay factors 1, 0.1, 0.01, 0.001, 1e-4 around epoch boundaries
  # 30/60/80/90. NOTE(review): batch_denom=256 presumably scales the base
  # rate by batch_size / 256 -- see resnet.learning_rate_with_decay.
  lr_schedule = resnet.learning_rate_with_decay(
      batch_size=params['batch_size'],
      batch_denom=256,
      num_images=_NUM_IMAGES['train'],
      boundary_epochs=[30, 60, 80, 90],
      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])

  return resnet.resnet_model_fn(
      features, labels, mode, ImagenetModel,
      resnet_size=params['resnet_size'],
      weight_decay=1e-4,
      learning_rate_fn=lr_schedule,
      momentum=0.9,
      data_format=params['data_format'],
      loss_filter_fn=None)
198
199
200


def main(unused_argv):
  """Hand FLAGS, imagenet_model_fn and input_fn to resnet.resnet_main.

  Args:
    unused_argv: argv passed in by tf.app.run; ignored because the flags
      have already been parsed into the module-level FLAGS.

  Note: FLAGS is only assigned in the `__main__` block, so this function
  works only when the file is run as a script.
  """
  del unused_argv  # Not needed; see docstring.
  resnet.resnet_main(FLAGS, imagenet_model_fn, input_fn)
202
203
204
205


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)

  # Accept only the ResNet depths that _get_block_sizes has layouts for.
  parser = resnet.ResnetArgParser(
      resnet_size_choices=[18, 34, 50, 101, 152, 200])
  # Our own flags are parsed here into the global FLAGS that main() reads;
  # tf.app.run receives just the program name plus any unrecognized args.
  FLAGS, unparsed = parser.parse_known_args()
  forwarded_argv = [sys.argv[0]] + unparsed
  tf.app.run(argv=forwarded_argv)