# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the definition of the Inception Resnet V2 architecture.

As described in http://arxiv.org/abs/1602.07261.

  Inception-v4, Inception-ResNet and the Impact of Residual Connections
    on Learning
  Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import tensorflow as tf
from tensorflow.contrib import slim as contrib_slim

slim = contrib_slim


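# The three helper functions below implement the Inception-ResNet residual
# blocks at the 35x35, 17x17 and 8x8 grid sizes. Each block runs a few
# parallel convolution towers, projects their depth-concatenation back to the
# input depth with a 1x1 convolution, scales that result by `scale`, and adds
# it to the block input (the residual connection) before applying
# `activation_fn`.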
def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
  """Builds the 35x35 resnet block."""
  with tf.compat.v1.variable_scope(scope, 'Block35', [net], reuse=reuse):
    with tf.compat.v1.variable_scope('Branch_0'):
      tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1')
    with tf.compat.v1.variable_scope('Branch_1'):
      tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
      tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3')
    with tf.compat.v1.variable_scope('Branch_2'):
      tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
      tower_conv2_1 = slim.conv2d(tower_conv2_0, 48, 3, scope='Conv2d_0b_3x3')
      tower_conv2_2 = slim.conv2d(tower_conv2_1, 64, 3, scope='Conv2d_0c_3x3')
    mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_1, tower_conv2_2])
    up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
                     activation_fn=None, scope='Conv2d_1x1')

    scaled_up = up * scale
    if activation_fn == tf.nn.relu6:
      # Use clip_by_value to simulate bandpass activation.
      scaled_up = tf.clip_by_value(scaled_up, -6.0, 6.0)

    net += scaled_up
    if activation_fn:
      net = activation_fn(net)
  return net


def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
  """Builds the 17x17 resnet block."""
  with tf.compat.v1.variable_scope(scope, 'Block17', [net], reuse=reuse):
    with tf.compat.v1.variable_scope('Branch_0'):
      tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
    with tf.compat.v1.variable_scope('Branch_1'):
      tower_conv1_0 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1')
      tower_conv1_1 = slim.conv2d(tower_conv1_0, 160, [1, 7],
                                  scope='Conv2d_0b_1x7')
      tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [7, 1],
                                  scope='Conv2d_0c_7x1')
    mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_2])
    up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
                     activation_fn=None, scope='Conv2d_1x1')

    scaled_up = up * scale
    if activation_fn == tf.nn.relu6:
      # Use clip_by_value to simulate bandpass activation.
      scaled_up = tf.clip_by_value(scaled_up, -6.0, 6.0)

    net += scaled_up
    if activation_fn:
      net = activation_fn(net)
  return net


def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
  """Builds the 8x8 resnet block."""
  with tf.compat.v1.variable_scope(scope, 'Block8', [net], reuse=reuse):
    with tf.compat.v1.variable_scope('Branch_0'):
      tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
    with tf.compat.v1.variable_scope('Branch_1'):
      tower_conv1_0 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1')
      tower_conv1_1 = slim.conv2d(tower_conv1_0, 224, [1, 3],
                                  scope='Conv2d_0b_1x3')
      tower_conv1_2 = slim.conv2d(tower_conv1_1, 256, [3, 1],
                                  scope='Conv2d_0c_3x1')
    mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_2])
    up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
                     activation_fn=None, scope='Conv2d_1x1')

    scaled_up = up * scale
    if activation_fn == tf.nn.relu6:
      # Use clip_by_value to simulate bandpass activation.
      scaled_up = tf.clip_by_value(scaled_up, -6.0, 6.0)

    net += scaled_up
    if activation_fn:
      net = activation_fn(net)
  return net


def inception_resnet_v2_base(inputs,
                             final_endpoint='Conv2d_7b_1x1',
                             output_stride=16,
                             align_feature_maps=False,
                             scope=None,
                             activation_fn=tf.nn.relu):
  """Inception model from  http://arxiv.org/abs/1602.07261.

  Constructs an Inception Resnet v2 network from inputs to the given final
  endpoint. This method can construct the network up to the final inception
  block Conv2d_7b_1x1.

  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    final_endpoint: specifies the endpoint to construct the network up to. It
      can be one of ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
      'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3',
      'Mixed_5b', 'Mixed_6a', 'PreAuxLogits', 'Mixed_7a', 'Conv2d_7b_1x1']
    output_stride: A scalar that specifies the requested ratio of input to
      output spatial resolution. Only supports 8 and 16.
    align_feature_maps: When true, changes all the VALID paddings in the network
      to SAME padding so that the feature maps are aligned.
    scope: Optional variable_scope.
    activation_fn: Activation function for block scopes.

  Returns:
    tensor_out: output tensor corresponding to the final_endpoint.
    end_points: a set of activations for external use, for example summaries or
                losses.

  Raises:
    ValueError: if final_endpoint is not set to one of the predefined values,
      or if the output_stride is not 8 or 16, or if the output_stride is 8 and
      we request an end point after 'PreAuxLogits'.
  """
  if output_stride != 8 and output_stride != 16:
    raise ValueError('output_stride must be 8 or 16.')

  padding = 'SAME' if align_feature_maps else 'VALID'

  end_points = {}

  def add_and_check_final(name, net):
    end_points[name] = net
    return name == final_endpoint

  with tf.compat.v1.variable_scope(scope, 'InceptionResnetV2', [inputs]):
    with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
                        stride=1, padding='SAME'):
      # 149 x 149 x 32
      net = slim.conv2d(inputs, 32, 3, stride=2, padding=padding,
                        scope='Conv2d_1a_3x3')
      if add_and_check_final('Conv2d_1a_3x3', net): return net, end_points

      # 147 x 147 x 32
      net = slim.conv2d(net, 32, 3, padding=padding,
                        scope='Conv2d_2a_3x3')
      if add_and_check_final('Conv2d_2a_3x3', net): return net, end_points
      # 147 x 147 x 64
      net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3')
      if add_and_check_final('Conv2d_2b_3x3', net): return net, end_points
      # 73 x 73 x 64
      net = slim.max_pool2d(net, 3, stride=2, padding=padding,
                            scope='MaxPool_3a_3x3')
      if add_and_check_final('MaxPool_3a_3x3', net): return net, end_points
      # 73 x 73 x 80
      net = slim.conv2d(net, 80, 1, padding=padding,
                        scope='Conv2d_3b_1x1')
      if add_and_check_final('Conv2d_3b_1x1', net): return net, end_points
      # 71 x 71 x 192
      net = slim.conv2d(net, 192, 3, padding=padding,
                        scope='Conv2d_4a_3x3')
      if add_and_check_final('Conv2d_4a_3x3', net): return net, end_points
      # 35 x 35 x 192
      net = slim.max_pool2d(net, 3, stride=2, padding=padding,
                            scope='MaxPool_5a_3x3')
      if add_and_check_final('MaxPool_5a_3x3', net): return net, end_points

      # 35 x 35 x 320
      with tf.compat.v1.variable_scope('Mixed_5b'):
        with tf.compat.v1.variable_scope('Branch_0'):
          tower_conv = slim.conv2d(net, 96, 1, scope='Conv2d_1x1')
        with tf.compat.v1.variable_scope('Branch_1'):
          tower_conv1_0 = slim.conv2d(net, 48, 1, scope='Conv2d_0a_1x1')
          tower_conv1_1 = slim.conv2d(tower_conv1_0, 64, 5,
                                      scope='Conv2d_0b_5x5')
        with tf.compat.v1.variable_scope('Branch_2'):
          tower_conv2_0 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1')
          tower_conv2_1 = slim.conv2d(tower_conv2_0, 96, 3,
                                      scope='Conv2d_0b_3x3')
          tower_conv2_2 = slim.conv2d(tower_conv2_1, 96, 3,
                                      scope='Conv2d_0c_3x3')
        with tf.compat.v1.variable_scope('Branch_3'):
          tower_pool = slim.avg_pool2d(net, 3, stride=1, padding='SAME',
                                       scope='AvgPool_0a_3x3')
          tower_pool_1 = slim.conv2d(tower_pool, 64, 1,
                                     scope='Conv2d_0b_1x1')
        net = tf.concat(
            [tower_conv, tower_conv1_1, tower_conv2_2, tower_pool_1], 3)

      if add_and_check_final('Mixed_5b', net): return net, end_points
      # TODO(alemi): Register intermediate endpoints
      net = slim.repeat(net, 10, block35, scale=0.17,
                        activation_fn=activation_fn)

      # 33 x 33 x 1088 if output_stride == 8,
      # 17 x 17 x 1088 if output_stride == 16
      use_atrous = output_stride == 8

      with tf.compat.v1.variable_scope('Mixed_6a'):
        with tf.compat.v1.variable_scope('Branch_0'):
          tower_conv = slim.conv2d(net, 384, 3, stride=1 if use_atrous else 2,
                                   padding=padding,
                                   scope='Conv2d_1a_3x3')
        with tf.compat.v1.variable_scope('Branch_1'):
          tower_conv1_0 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
          tower_conv1_1 = slim.conv2d(tower_conv1_0, 256, 3,
                                      scope='Conv2d_0b_3x3')
          tower_conv1_2 = slim.conv2d(tower_conv1_1, 384, 3,
                                      stride=1 if use_atrous else 2,
                                      padding=padding,
                                      scope='Conv2d_1a_3x3')
        with tf.compat.v1.variable_scope('Branch_2'):
          tower_pool = slim.max_pool2d(net, 3, stride=1 if use_atrous else 2,
                                       padding=padding,
                                       scope='MaxPool_1a_3x3')
        net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3)

      if add_and_check_final('Mixed_6a', net): return net, end_points

      # TODO(alemi): register intermediate endpoints
      with slim.arg_scope([slim.conv2d], rate=2 if use_atrous else 1):
        net = slim.repeat(net, 20, block17, scale=0.10,
                          activation_fn=activation_fn)
      if add_and_check_final('PreAuxLogits', net): return net, end_points

      if output_stride == 8:
        # TODO(gpapan): Properly support output_stride for the rest of the net.
        raise ValueError('output_stride==8 is only supported up to the '
                         'PreAuxLogits end_point for now.')

      # 8 x 8 x 2080
      with tf.compat.v1.variable_scope('Mixed_7a'):
        with tf.compat.v1.variable_scope('Branch_0'):
          tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
          tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2,
                                     padding=padding,
                                     scope='Conv2d_1a_3x3')
        with tf.compat.v1.variable_scope('Branch_1'):
          tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
          tower_conv1_1 = slim.conv2d(tower_conv1, 288, 3, stride=2,
                                      padding=padding,
                                      scope='Conv2d_1a_3x3')
        with tf.compat.v1.variable_scope('Branch_2'):
          tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
          tower_conv2_1 = slim.conv2d(tower_conv2, 288, 3,
                                      scope='Conv2d_0b_3x3')
          tower_conv2_2 = slim.conv2d(tower_conv2_1, 320, 3, stride=2,
                                      padding=padding,
                                      scope='Conv2d_1a_3x3')
        with tf.compat.v1.variable_scope('Branch_3'):
          tower_pool = slim.max_pool2d(net, 3, stride=2,
                                       padding=padding,
                                       scope='MaxPool_1a_3x3')
        net = tf.concat(
            [tower_conv_1, tower_conv1_1, tower_conv2_2, tower_pool], 3)

      if add_and_check_final('Mixed_7a', net): return net, end_points

      # TODO(alemi): register intermediate endpoints
      net = slim.repeat(net, 9, block8, scale=0.20, activation_fn=activation_fn)
      net = block8(net, activation_fn=None)

      # 8 x 8 x 1536
      net = slim.conv2d(net, 1536, 1, scope='Conv2d_7b_1x1')
      if add_and_check_final('Conv2d_7b_1x1', net): return net, end_points

    raise ValueError('final_endpoint (%s) not recognized' % final_endpoint)


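# A minimal usage sketch (illustrative, not part of the original API): building
# the network as a dense feature extractor up to a named endpoint. The input
# shape and the helper name `_example_base_features` are assumptions;
# `inception_resnet_v2_arg_scope` is defined later in this file and is only
# resolved when the function is called.
def _example_base_features():
  images = tf.compat.v1.placeholder(tf.float32, [None, 299, 299, 3])
  with slim.arg_scope(inception_resnet_v2_arg_scope()):
    # Stop at the 'PreAuxLogits' endpoint; output_stride=8 keeps a denser
    # feature map by switching the later blocks to atrous convolution.
    net, end_points = inception_resnet_v2_base(
        images, final_endpoint='PreAuxLogits', output_stride=8)
  return net, end_points

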
def inception_resnet_v2(inputs, num_classes=1001, is_training=True,
                        dropout_keep_prob=0.8,
                        reuse=None,
                        scope='InceptionResnetV2',
                        create_aux_logits=True,
                        activation_fn=tf.nn.relu):
  """Creates the Inception Resnet V2 model.

  Args:
    inputs: a 4-D tensor of size [batch_size, height, width, 3].
      Dimension batch_size may be undefined. If create_aux_logits is false,
      height and width may also be undefined.
    num_classes: number of predicted classes. If 0 or None, the logits layer
      is omitted and the input features to the logits layer (before dropout)
      are returned instead.
    is_training: whether the network is being trained.
    dropout_keep_prob: float, the fraction of activations to keep before the
      final layer.
    reuse: whether or not the network and its variables should be reused. To be
      able to reuse, 'scope' must be given.
    scope: Optional variable_scope.
    create_aux_logits: Whether to include the auxiliary logits.
    activation_fn: Activation function for conv2d.

  Returns:
    net: the output of the logits layer (if num_classes is a non-zero integer),
      or the non-dropped-out input to the logits layer (if num_classes is 0 or
      None).
    end_points: the set of end_points from the inception model.
  """
  end_points = {}

  with tf.compat.v1.variable_scope(
      scope, 'InceptionResnetV2', [inputs], reuse=reuse) as scope:
    with slim.arg_scope([slim.batch_norm, slim.dropout],
                        is_training=is_training):

      net, end_points = inception_resnet_v2_base(inputs, scope=scope,
                                                 activation_fn=activation_fn)

      if create_aux_logits and num_classes:
        with tf.compat.v1.variable_scope('AuxLogits'):
          aux = end_points['PreAuxLogits']
          aux = slim.avg_pool2d(aux, 5, stride=3, padding='VALID',
                                scope='Conv2d_1a_3x3')
          aux = slim.conv2d(aux, 128, 1, scope='Conv2d_1b_1x1')
          aux = slim.conv2d(aux, 768, aux.get_shape()[1:3],
                            padding='VALID', scope='Conv2d_2a_5x5')
          aux = slim.flatten(aux)
          aux = slim.fully_connected(aux, num_classes, activation_fn=None,
                                     scope='Logits')
          end_points['AuxLogits'] = aux

      with tf.compat.v1.variable_scope('Logits'):
        # TODO(sguada,arnoegw): Consider adding a parameter global_pool which
        # can be set to False to disable pooling here (as in resnet_*()).
        kernel_size = net.get_shape()[1:3]
        if kernel_size.is_fully_defined():
          net = slim.avg_pool2d(net, kernel_size, padding='VALID',
                                scope='AvgPool_1a_8x8')
        else:
          net = tf.reduce_mean(
              input_tensor=net, axis=[1, 2], keepdims=True, name='global_pool')
        end_points['global_pool'] = net
        if not num_classes:
          return net, end_points
        net = slim.flatten(net)
        net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                           scope='Dropout')
        end_points['PreLogitsFlatten'] = net
        logits = slim.fully_connected(net, num_classes, activation_fn=None,
                                      scope='Logits')
        end_points['Logits'] = logits
        end_points['Predictions'] = tf.nn.softmax(logits, name='Predictions')

    return logits, end_points
inception_resnet_v2.default_image_size = 299


def inception_resnet_v2_arg_scope(
    weight_decay=0.00004,
    batch_norm_decay=0.9997,
    batch_norm_epsilon=0.001,
    activation_fn=tf.nn.relu,
    batch_norm_updates_collections=tf.compat.v1.GraphKeys.UPDATE_OPS,
    batch_norm_scale=False):
  """Returns the scope with the default parameters for inception_resnet_v2.

  Args:
    weight_decay: the weight decay for weights variables.
    batch_norm_decay: decay for the moving averages used by batch norm.
    batch_norm_epsilon: small float added to variance to avoid dividing by zero.
    activation_fn: Activation function for conv2d.
    batch_norm_updates_collections: Collection for the update ops for
      batch norm.
    batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
      activations in the batch normalization layer.

  Returns:
    an arg_scope with the parameters needed for inception_resnet_v2.
  """
  # Set weight_decay for weights in conv2d and fully_connected layers.
  with slim.arg_scope([slim.conv2d, slim.fully_connected],
                      weights_regularizer=slim.l2_regularizer(weight_decay),
                      biases_regularizer=slim.l2_regularizer(weight_decay)):

    batch_norm_params = {
        'decay': batch_norm_decay,
        'epsilon': batch_norm_epsilon,
        'updates_collections': batch_norm_updates_collections,
        'fused': None,  # Use fused batch norm if possible.
        'scale': batch_norm_scale,
    }
    # Set activation_fn and parameters for batch_norm.
    with slim.arg_scope([slim.conv2d], activation_fn=activation_fn,
                        normalizer_fn=slim.batch_norm,
                        normalizer_params=batch_norm_params) as scope:
      return scope
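

# A minimal end-to-end sketch (illustrative, not part of the original API):
# building the full classifier under the recommended arg_scope. The batch
# size, the eval-mode setting and the helper name `_example_full_model` are
# assumptions.
def _example_full_model():
  size = inception_resnet_v2.default_image_size  # 299
  images = tf.compat.v1.placeholder(tf.float32, [None, size, size, 3])
  with slim.arg_scope(inception_resnet_v2_arg_scope()):
    logits, end_points = inception_resnet_v2(images, num_classes=1001,
                                             is_training=False)
  # end_points['Predictions'] holds the softmax class probabilities.
  return logits, end_points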