# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet56 model for Keras adapted from tf.keras.applications.ResNet50.

# Reference:
- [Deep Residual Learning for Image Recognition](
    https://arxiv.org/abs/1512.03385)
Adapted from code contributed by BigMoyan.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import tensorflow as tf
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras import layers
from tensorflow.python.keras import regularizers


BATCH_NORM_DECAY = 0.997
BATCH_NORM_EPSILON = 1e-5
L2_WEIGHT_DECAY = 2e-4


def identity_building_block(input_tensor,
                            kernel_size,
                            filters,
                            stage,
                            block,
                            training=None):
  """The identity block is the block that has no conv layer at shortcut.

  Arguments:
    input_tensor: input tensor
    kernel_size: default 3, the kernel size of
        the conv layers at main path
    filters: list of integers, the filters of the 2 conv layers at main path
    stage: integer, current stage label, used for generating layer names
    block: current block label, used for generating layer names
    training: Only used if training the Keras model with Estimator.  In other
      scenarios it is handled automatically.

  Returns:
    Output tensor for the block.
  """
  filters1, filters2 = filters
  if backend.image_data_format() == 'channels_last':
    bn_axis = 3
  else:
    bn_axis = 1
  conv_name_base = 'res' + str(stage) + block + '_branch'
  bn_name_base = 'bn' + str(stage) + block + '_branch'

  x = layers.Conv2D(filters1, kernel_size,
                    padding='same', use_bias=False,
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name=conv_name_base + '2a')(input_tensor)
  x = layers.BatchNormalization(
      axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
      name=bn_name_base + '2a')(x, training=training)
  x = layers.Activation('relu')(x)

  x = layers.Conv2D(filters2, kernel_size,
                    padding='same', use_bias=False,
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name=conv_name_base + '2b')(x)
  x = layers.BatchNormalization(
      axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
      name=bn_name_base + '2b')(x, training=training)

  x = layers.add([x, input_tensor])
  x = layers.Activation('relu')(x)
  return x


def conv_building_block(input_tensor,
                        kernel_size,
                        filters,
                        stage,
                        block,
                        strides=(2, 2),
                        training=None):
  """A block that has a conv layer at shortcut.

  Arguments:
    input_tensor: input tensor
    kernel_size: default 3, the kernel size of
        the conv layers at main path
    filters: list of integers, the filters of the 2 conv layers at main path
    stage: integer, current stage label, used for generating layer names
    block: current block label, used for generating layer names
    strides: Strides for the first conv layer in the block.
    training: Only used if training the Keras model with Estimator.  In other
      scenarios it is handled automatically.

  Returns:
    Output tensor for the block.

  Note that from stage 3,
  the first conv layer at main path is with strides=(2, 2),
  and the shortcut should have strides=(2, 2) as well.
  """
  filters1, filters2 = filters
  if tf.keras.backend.image_data_format() == 'channels_last':
    bn_axis = 3
  else:
    bn_axis = 1
  conv_name_base = 'res' + str(stage) + block + '_branch'
  bn_name_base = 'bn' + str(stage) + block + '_branch'

  x = layers.Conv2D(filters1, kernel_size, strides=strides,
                    padding='same', use_bias=False,
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name=conv_name_base + '2a')(input_tensor)
  x = layers.BatchNormalization(
      axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
      name=bn_name_base + '2a')(x, training=training)
  x = layers.Activation('relu')(x)

  x = layers.Conv2D(filters2, kernel_size, padding='same', use_bias=False,
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name=conv_name_base + '2b')(x)
  x = layers.BatchNormalization(
      axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
      name=bn_name_base + '2b')(x, training=training)

  shortcut = layers.Conv2D(filters2, (1, 1), strides=strides, use_bias=False,
                           kernel_initializer='he_normal',
                           kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                           name=conv_name_base + '1')(input_tensor)
  shortcut = layers.BatchNormalization(
      axis=bn_axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON,
      name=bn_name_base + '1')(shortcut, training=training)

  x = layers.add([x, shortcut])
  x = layers.Activation('relu')(x)
  return x


def resnet_block(input_tensor,
                 size,
                 kernel_size,
                 filters,
                 stage,
                 conv_strides=(2, 2),
                 training=None):
  """A block which applies conv followed by multiple identity blocks.

  Arguments:
    input_tensor: input tensor
    size: integer, number of constituent conv/identity building blocks.
    A conv block is applied once, followed by (size - 1) identity blocks.
    kernel_size: default 3, the kernel size of
        the conv layers at main path
    filters: list of integers, the filters of the 2 conv layers at main path
    stage: integer, current stage label, used for generating layer names
    conv_strides: Strides for the first conv layer in the block.
    training: Only used if training the Keras model with Estimator.  In other
      scenarios it is handled automatically.

  Returns:
    Output tensor after applying conv and identity blocks.
  """

  x = conv_building_block(input_tensor, kernel_size, filters, stage=stage,
                          strides=conv_strides, block='block_0',
                          training=training)
  for i in range(size - 1):
    x = identity_building_block(x, kernel_size, filters, stage=stage,
                                block='block_%d' % (i + 1), training=training)
  return x
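
# For example, resnet() below builds its first stage with
#   resnet_block(x, size=num_blocks, kernel_size=3, filters=[16, 16],
#                stage=2, conv_strides=(1, 1), training=training)
# which expands into one conv_building_block ('block_0') followed by
# (num_blocks - 1) identity_building_block calls ('block_1', 'block_2', ...).
# With 2 convolutions per building block, each stage contributes
# 2 * num_blocks conv layers toward the 6 * num_blocks + 2 total.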


def resnet(num_blocks, classes=10, training=None):
  """Instantiates the ResNet architecture.

  Arguments:
    num_blocks: integer, the number of conv/identity blocks in each block.
      The ResNet contains 3 blocks with each block containing one conv block
      followed by (num_blocks - 1) identity blocks. Each conv/identity block
      has 2 convolutional layers. With the input convolutional layer and the
      final dense layer, this brings the total depth of the network to
      (6*num_blocks + 2).
    classes: optional number of classes to classify images into
    training: Only used if training the Keras model with Estimator.  In other
      scenarios it is handled automatically.

  Returns:
    A Keras model instance.
  """

  input_shape = (32, 32, 3)
  img_input = layers.Input(shape=input_shape)

  if backend.image_data_format() == 'channels_first':
    x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
                      name='transpose')(img_input)
    bn_axis = 1
  else:  # channels_last
    x = img_input
    bn_axis = 3

  x = layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
  x = layers.Conv2D(16, (3, 3),
                    strides=(1, 1),
                    padding='valid', use_bias=False,
                    kernel_initializer='he_normal',
                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                    name='conv1')(x)
  x = layers.BatchNormalization(axis=bn_axis,
                                momentum=BATCH_NORM_DECAY,
                                epsilon=BATCH_NORM_EPSILON,
                                name='bn_conv1',)(x, training=training)
  x = layers.Activation('relu')(x)

  x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[16, 16],
                   stage=2, conv_strides=(1, 1), training=training)

  x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[32, 32],
                   stage=3, conv_strides=(2, 2), training=training)

  x = resnet_block(x, size=num_blocks, kernel_size=3, filters=[64, 64],
                   stage=4, conv_strides=(2, 2), training=training)

  rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
  x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
  x = layers.Dense(classes,
                   activation='softmax',
                   kernel_initializer=initializers.RandomNormal(stddev=0.01),
                   kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                   bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                   name='fc10')(x)

  inputs = img_input
  # Create model.
  model = tf.keras.models.Model(inputs, x, name='resnet56')

  return model


resnet20 = functools.partial(resnet, num_blocks=3)
resnet32 = functools.partial(resnet, num_blocks=5)
resnet56 = functools.partial(resnet, num_blocks=9)
resnet110 = functools.partial(resnet, num_blocks=18)
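

# A minimal usage sketch (the optimizer, loss, and metrics below are
# illustrative choices, not mandated by this module). The depth follows
# 6 * num_blocks + 2, e.g. resnet56 binds num_blocks=9: 6*9 + 2 = 56.
if __name__ == '__main__':
  # Build the 56-layer CIFAR-10 variant and print its layer summary.
  model = resnet56(classes=10)
  model.compile(optimizer='sgd',
                loss='categorical_crossentropy',
                metrics=['accuracy'])
  model.summary()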