resnet_cifar_model.py 10.2 KB
Newer Older
1
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Shining Sun's avatar
Shining Sun committed
15
"""CIFAR ResNet (e.g. ResNet56) model for Keras, adapted from tf.keras.applications.ResNet50.

# Reference:
- [Deep Residual Learning for Image Recognition](
    https://arxiv.org/abs/1512.03385)

Adapted from code contributed by BigMoyan.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

26
import functools
Hongkun Yu's avatar
Hongkun Yu committed
27

28
29
import tensorflow as tf

30
# Hyper-parameters shared by the layers built in this module:
# momentum/epsilon for every BatchNormalization layer, and the L2 penalty
# factor for the kernel (and classifier bias) regularizers.
BATCH_NORM_DECAY = 0.997
BATCH_NORM_EPSILON = 0.00001
L2_WEIGHT_DECAY = 0.0002
33
34


Shining Sun's avatar
Shining Sun committed
35
36
37
38
39
40
def identity_building_block(input_tensor,
                            kernel_size,
                            filters,
                            stage,
                            block,
                            training=None):
  """Residual block whose shortcut is the unmodified input.

  The main path is conv -> BN -> ReLU -> conv -> BN. Its result is added
  element-wise to `input_tensor` and passed through a final ReLU. Because
  the shortcut has no projection, the main path output must have the same
  shape as `input_tensor`; presumably this means `filters[1]` must equal the
  input's channel count (the add fails otherwise).

  Arguments:
    input_tensor: input tensor
    kernel_size: kernel size of both conv layers on the main path
    filters: pair of integers `(filters1, filters2)`, the number of filters
      of the two conv layers on the main path
    stage: integer, current stage label, used for generating layer names
    block: string, current block label, used for generating layer names
    training: Only used if training keras model with Estimator.  In other
      scenarios it is handled automatically.

  Returns:
    Output tensor for the block.
  """
  filters1, filters2 = filters
  channels_last = tf.keras.backend.image_data_format() == 'channels_last'
  bn_axis = 3 if channels_last else 1
  conv_prefix = ''.join(('res', str(stage), block, '_branch'))
  bn_prefix = ''.join(('bn', str(stage), block, '_branch'))

  def conv_bn(inputs, num_filters, suffix):
    """Apply a bias-free 'same'-padded conv, then batch normalization."""
    out = tf.keras.layers.Conv2D(
        num_filters,
        kernel_size,
        padding='same',
        use_bias=False,
        kernel_initializer='he_normal',
        kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
        name=conv_prefix + suffix)(inputs)
    return tf.keras.layers.BatchNormalization(
        axis=bn_axis,
        momentum=BATCH_NORM_DECAY,
        epsilon=BATCH_NORM_EPSILON,
        name=bn_prefix + suffix)(out, training=training)

  # Layer creation order matches the historical layout (conv/bn 2a, relu,
  # conv/bn 2b, add, relu) so auto-generated layer names stay the same.
  out = conv_bn(input_tensor, filters1, '2a')
  out = tf.keras.layers.Activation('relu')(out)
  out = conv_bn(out, filters2, '2b')

  out = tf.keras.layers.add([out, input_tensor])
  return tf.keras.layers.Activation('relu')(out)


def conv_building_block(input_tensor,
                        kernel_size,
                        filters,
                        stage,
                        block,
                        strides=(2, 2),
                        training=None):
  """Residual block whose shortcut is a strided 1x1 projection.

  Main path: conv(kernel_size, strides) -> BN -> ReLU -> conv(kernel_size)
  -> BN. Shortcut: 1x1 conv with the same `strides` -> BN. The two branches
  are added and passed through a final ReLU. Because of the projection, the
  block can change both the spatial resolution and the channel count.

  Arguments:
    input_tensor: input tensor
    kernel_size: kernel size of both conv layers on the main path
    filters: pair of integers `(filters1, filters2)`, the number of filters
      of the two main-path conv layers; `filters2` is also the number of
      filters of the shortcut projection
    stage: integer, current stage label, used for generating layer names
    block: string, current block label, used for generating layer names
    strides: Strides for the first main-path conv layer and for the
      shortcut projection.
    training: Only used if training keras model with Estimator.  In other
      scenarios it is handled automatically.

  Returns:
    Output tensor for the block.

  Note that from stage 3,
  the first conv layer at main path is with strides=(2, 2)
  And the shortcut should have strides=(2, 2) as well
  """
  filters1, filters2 = filters
  channels_last = tf.keras.backend.image_data_format() == 'channels_last'
  bn_axis = 3 if channels_last else 1
  conv_prefix = ''.join(('res', str(stage), block, '_branch'))
  bn_prefix = ''.join(('bn', str(stage), block, '_branch'))

  def conv_bn(inputs, num_filters, window, suffix, **conv_options):
    """Apply a bias-free conv (extra Conv2D kwargs passed through), then BN."""
    out = tf.keras.layers.Conv2D(
        num_filters,
        window,
        use_bias=False,
        kernel_initializer='he_normal',
        kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
        name=conv_prefix + suffix,
        **conv_options)(inputs)
    return tf.keras.layers.BatchNormalization(
        axis=bn_axis,
        momentum=BATCH_NORM_DECAY,
        epsilon=BATCH_NORM_EPSILON,
        name=bn_prefix + suffix)(out, training=training)

  # Main path first, then the shortcut, matching the historical layer
  # creation order.
  main = conv_bn(input_tensor, filters1, kernel_size, '2a',
                 strides=strides, padding='same')
  main = tf.keras.layers.Activation('relu')(main)
  main = conv_bn(main, filters2, kernel_size, '2b', padding='same')

  # The 1x1 projection deliberately keeps Keras' default padding.
  shortcut = conv_bn(input_tensor, filters2, (1, 1), '1', strides=strides)

  merged = tf.keras.layers.add([main, shortcut])
  return tf.keras.layers.Activation('relu')(merged)


189
190
191
192
193
194
195
196
def resnet_block(input_tensor,
                 size,
                 kernel_size,
                 filters,
                 stage,
                 conv_strides=(2, 2),
                 training=None):
  """Stack one projection (conv) block followed by `size - 1` identity blocks.

  The blocks are labelled 'block_0' (the conv block) through
  'block_<size-1>', and these labels are used in the generated layer names.

  Arguments:
    input_tensor: input tensor
    size: integer, number of constituent conv/identity building blocks. A conv
      block is applied once, followed by (size - 1) identity blocks.
    kernel_size: kernel size of the main-path conv layers in every block
    filters: pair of integers, the filters of the 2 main-path conv layers in
      every block
    stage: integer, current stage label, used for generating layer names
    conv_strides: Strides for the first conv layer in the block.
    training: Only used if training keras model with Estimator.  In other
      scenarios it is handled automatically.

  Returns:
    Output tensor after applying conv and identity blocks.
  """
  # Arguments forwarded unchanged to every building block in this stage.
  common = {
      'kernel_size': kernel_size,
      'filters': filters,
      'stage': stage,
      'training': training,
  }

  out = conv_building_block(
      input_tensor, block='block_0', strides=conv_strides, **common)
  for index in range(1, size):
    out = identity_building_block(out, block='block_%d' % index, **common)
  return out

231

232
233
234
235
def resnet(num_blocks, classes=10, training=None):
  """Instantiates the CIFAR-style ResNet architecture.

  Arguments:
    num_blocks: integer >= 1, the number of conv/identity building blocks in
      each of the 3 stages. Each stage contains one conv block followed by
      (num_blocks - 1) identity blocks, and every building block has 2
      convolutional layers. Adding the input convolutional layer and the
      final fully connected layer gives a network of (6 * num_blocks + 2)
      layers.
    classes: optional number of classes to classify images into
    training: Only used if training keras model with Estimator.  In other
      scenarios it is handled automatically.

  Returns:
    A Keras model instance named 'resnet<depth>', where depth is
    6 * num_blocks + 2 (e.g. 'resnet56' for num_blocks=9).

  Raises:
    ValueError: if `num_blocks` is less than 1.
  """
  # With num_blocks <= 0 each stage would still get its conv block, so the
  # network would not match the 6 * num_blocks + 2 depth formula; reject it.
  if num_blocks < 1:
    raise ValueError(
        'num_blocks must be a positive integer, got %r' % (num_blocks,))
  depth = 6 * num_blocks + 2

  input_shape = (32, 32, 3)
  img_input = tf.keras.Input(shape=input_shape)

  if tf.keras.backend.image_data_format() == 'channels_first':
    # The input is declared channels-last (32, 32, 3); transpose it to
    # channels-first for the rest of the network.
    x = tf.keras.layers.Lambda(
        lambda x: tf.keras.backend.permute_dimensions(x, (0, 3, 1, 2)),
        name='transpose')(
            img_input)
    bn_axis = 1
  else:  # channels_last
    x = img_input
    bn_axis = 3

  # Stem: explicit 1-pixel zero padding + 'valid' 3x3 conv.
  x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
  x = tf.keras.layers.Conv2D(
      16, (3, 3),
      strides=(1, 1),
      padding='valid',
      use_bias=False,
      kernel_initializer='he_normal',
      kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
      name='conv1')(
          x)
  x = tf.keras.layers.BatchNormalization(
      axis=bn_axis,
      momentum=BATCH_NORM_DECAY,
      epsilon=BATCH_NORM_EPSILON,
      name='bn_conv1',
  )(x, training=training)
  x = tf.keras.layers.Activation('relu')(x)

  # Three stages with 16, 32 and 64 filters. Stage 2 keeps the resolution;
  # stages 3 and 4 each halve it via their first (conv) block.
  for stage, width, strides in ((2, 16, (1, 1)),
                                (3, 32, (2, 2)),
                                (4, 64, (2, 2))):
    x = resnet_block(
        x,
        size=num_blocks,
        kernel_size=3,
        filters=[width, width],
        stage=stage,
        conv_strides=strides,
        training=training)

  # Global average pooling over the spatial axes.
  if tf.keras.backend.image_data_format() == 'channels_last':
    rm_axes = [1, 2]
  else:
    rm_axes = [2, 3]
  x = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.mean(x, rm_axes), name='reduce_mean')(x)
  # NOTE(review): the layer name stays 'fc10' even when classes != 10;
  # left unchanged because renaming layers can break name-based weight
  # loading.
  x = tf.keras.layers.Dense(
      classes,
      activation='softmax',
      kernel_initializer=tf.keras.initializers.RandomNormal(
          stddev=0.01),
      kernel_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
      bias_regularizer=tf.keras.regularizers.L2(L2_WEIGHT_DECAY),
      name='fc10')(
          x)

  # Name the model after its actual depth instead of always 'resnet56'.
  return tf.keras.models.Model(img_input, x, name='resnet%d' % depth)
329
330
331
332
333
334


# Standard CIFAR ResNet variants; depth = 6 * num_blocks + 2.
resnet20 = functools.partial(resnet, num_blocks=3)
resnet32 = functools.partial(resnet, num_blocks=5)
resnet56 = functools.partial(resnet, num_blocks=9)
# 6 * 18 + 2 = 110 layers.
resnet110 = functools.partial(resnet, num_blocks=18)
# NOTE(review): `resnet10` is a misnomer kept only for backward
# compatibility. num_blocks=110 builds a 6 * 110 + 2 = 662-layer network,
# not a 10-layer one; it was most likely meant to be `resnet110` above,
# which should be preferred.
resnet10 = functools.partial(resnet, num_blocks=110)