mobilenet.py 23.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains definitions of Mobilenet Networks."""

from typing import Text, Optional, Dict

# Import libraries
import tensorflow as tf
Shixin Luo's avatar
Shixin Luo committed
21
from official.vision.beta.modeling.backbones import factory
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
from official.vision.beta.modeling.layers import nn_blocks
from official.vision.beta.modeling.layers import nn_layers

layers = tf.keras.layers
regularizers = tf.keras.regularizers


class GlobalPoolingBlock(tf.keras.layers.Layer):
  """Global average pooling that keeps the two spatial axes as size-1 dims.

  Averages over axes 1 and 2, so an input shaped [batch, height, width,
  channels] (assumes channels_last layout — TODO confirm for other data
  formats) becomes [batch, 1, 1, channels].
  """

  def __init__(self, **kwargs):
    super(GlobalPoolingBlock, self).__init__(**kwargs)

  def call(self, inputs, training=None):
    # A single reduction op: no new Keras layers are instantiated on every
    # call, and the channel dimension does not need to be statically known
    # (a Reshape to (1, 1, C) fails when C is None).
    return tf.reduce_mean(inputs, axis=[1, 2], keepdims=True)


"""
Architecture: https://arxiv.org/abs/1704.04861.

"MobileNets: Efficient Convolutional Neural Networks for
  Mobile Vision Applications"
Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang,
  Tobias Weyand, Marco Andreetto, Hartwig Adam
"""
# MobileNetV1: one stride-2 stem conv followed by 13 depthwise-separable
# blocks. The stride-2 entries multiply to an overall output stride of 32.
MNV1_BLOCK_SPECS = {
    'spec_name': 'MobileNetV1',
    'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters'],
    'block_specs': [
        # Stem.
        ('convbn', 3, 2, 32),
        ('depsepconv', 3, 1, 64),
        # 128-channel stage.
        ('depsepconv', 3, 2, 128),
        ('depsepconv', 3, 1, 128),
        # 256-channel stage.
        ('depsepconv', 3, 2, 256),
        ('depsepconv', 3, 1, 256),
        # 512-channel stage: one downsampling block, then five identical ones.
        ('depsepconv', 3, 2, 512),
        *[('depsepconv', 3, 1, 512)] * 5,
        # 1024-channel stage.
        ('depsepconv', 3, 2, 1024),
        ('depsepconv', 3, 1, 1024),
    ]
}

"""
Architecture: https://arxiv.org/abs/1801.04381

"MobileNetV2: Inverted Residuals and Linear Bottlenecks"
Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen
"""
MNV2_BLOCK_SPECS = {
    'spec_name': 'MobileNetV2',
    'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
                          'expand_ratio'],
    'block_specs': [
        ('convbn', 3, 2, 32, None),

        ('mbconv', 3, 1, 16, 1.),

        ('mbconv', 3, 2, 24, 6.),
        ('mbconv', 3, 1, 24, 6.),

        ('mbconv', 3, 2, 32, 6.),
        ('mbconv', 3, 1, 32, 6.),
        ('mbconv', 3, 1, 32, 6.),

        ('mbconv', 3, 2, 64, 6.),
        ('mbconv', 3, 1, 64, 6.),
        ('mbconv', 3, 1, 64, 6.),
        ('mbconv', 3, 1, 64, 6.),

        ('mbconv', 3, 1, 96, 6.),
        ('mbconv', 3, 1, 96, 6.),
        ('mbconv', 3, 1, 96, 6.),

        ('mbconv', 3, 2, 160, 6.),
        ('mbconv', 3, 1, 160, 6.),
        ('mbconv', 3, 1, 160, 6.),

        ('mbconv', 3, 1, 320, 6.),

        # The final 1x1 conv keeps the spatial resolution. The five stride-2
        # layers above already give the network an output stride of 32; a
        # stride of 2 here would downsample to 1/64 of the input.
        ('convbn', 1, 1, 1280, None),
    ]
}

"""
Architecture: https://arxiv.org/abs/1905.02244

"Searching for MobileNetV3"
Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, 
Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam
"""
MNV3Large_BLOCK_SPECS = {
    'spec_name': 'MobileNetV3Large',
    'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
                          'activation', 'se_ratio', 'expand_ratio',
                          'use_normalization', 'use_bias'],
    'block_specs': [
        # Stem.
        ('convbn', 3, 2, 16, 'hard_swish', None, None, True, False),
        # ReLU stages.
        ('mbconv', 3, 1, 16, 'relu', None, 1., None, False),
        ('mbconv', 3, 2, 24, 'relu', None, 4., None, False),
        ('mbconv', 3, 1, 24, 'relu', None, 3., None, False),
        ('mbconv', 5, 2, 40, 'relu', 1. / 4, 3., None, False),
        *[('mbconv', 5, 1, 40, 'relu', 1. / 4, 3., None, False)] * 2,
        # Hard-swish stages.
        ('mbconv', 3, 2, 80, 'hard_swish', None, 6., None, False),
        ('mbconv', 3, 1, 80, 'hard_swish', None, 2.5, None, False),
        *[('mbconv', 3, 1, 80, 'hard_swish', None, 2.3, None, False)] * 2,
        *[('mbconv', 3, 1, 112, 'hard_swish', 1. / 4, 6., None, False)] * 2,
        ('mbconv', 5, 2, 160, 'hard_swish', 1. / 4, 6, None, False),
        *[('mbconv', 5, 1, 160, 'hard_swish', 1. / 4, 6, None, False)] * 2,
        # Head: 1x1 conv, global pooling, then a 1x1 conv with
        # use_normalization=False and use_bias=True.
        ('convbn', 1, 1, 960, 'hard_swish', None, None, True, False),
        ('gpooling',) + (None,) * 8,
        ('convbn', 1, 1, 1280, 'hard_swish', None, None, False, True),
    ]
}

MNV3Small_BLOCK_SPECS = {
    'spec_name': 'MobileNetV3Small',
    'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
                          'activation', 'se_ratio', 'expand_ratio',
                          'use_normalization', 'use_bias'],
    'block_specs': [
        # Stem.
        ('convbn', 3, 2, 16, 'hard_swish', None, None, True, False),
        # ReLU stages.
        ('mbconv', 3, 2, 16, 'relu', 1. / 4, 1, None, False),
        ('mbconv', 3, 2, 24, 'relu', None, 72. / 16, None, False),
        ('mbconv', 3, 1, 24, 'relu', None, 88. / 24, None, False),
        # Hard-swish stages, all with se_ratio 1/4.
        ('mbconv', 5, 2, 40, 'hard_swish', 1. / 4, 4., None, False),
        *[('mbconv', 5, 1, 40, 'hard_swish', 1. / 4, 6., None, False)] * 2,
        *[('mbconv', 5, 1, 48, 'hard_swish', 1. / 4, 3., None, False)] * 2,
        ('mbconv', 5, 2, 96, 'hard_swish', 1. / 4, 6., None, False),
        *[('mbconv', 5, 1, 96, 'hard_swish', 1. / 4, 6., None, False)] * 2,
        # Head: 1x1 conv, global pooling, then a 1x1 conv with
        # use_normalization=False and use_bias=True.
        ('convbn', 1, 1, 576, 'hard_swish', None, None, True, False),
        ('gpooling',) + (None,) * 8,
        ('convbn', 1, 1, 1024, 'hard_swish', None, None, False, True),
    ]
}

"""
The EdgeTPU version is taken from
github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py
"""
MNV3EdgeTPU_BLOCK_SPECS = {
    'spec_name': 'MobileNetV3EdgeTPU',
    'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
                          'activation', 'se_ratio', 'expand_ratio',
                          'use_residual', 'use_depthwise'],
    'block_specs': [
        # Stem.
        ('convbn', 3, 2, 32, 'relu', None, None, None, None),
        # Non-depthwise (use_depthwise=False) stages.
        ('mbconv', 3, 1, 16, 'relu', None, 1., True, False),
        ('mbconv', 3, 2, 32, 'relu', None, 8., True, False),
        *[('mbconv', 3, 1, 32, 'relu', None, 4., True, False)] * 3,
        ('mbconv', 3, 2, 48, 'relu', None, 8., True, False),
        *[('mbconv', 3, 1, 48, 'relu', None, 4., True, False)] * 3,
        # Depthwise stages.
        ('mbconv', 3, 2, 96, 'relu', None, 8., True, True),
        *[('mbconv', 3, 1, 96, 'relu', None, 4., True, True)] * 3,
        ('mbconv', 3, 1, 96, 'relu', None, 8., False, True),
        *[('mbconv', 3, 1, 96, 'relu', None, 4., True, True)] * 3,
        ('mbconv', 5, 2, 160, 'relu', None, 8., True, True),
        *[('mbconv', 5, 1, 160, 'relu', None, 4., True, True)] * 3,
        ('mbconv', 3, 1, 192, 'relu', None, 8., True, True),
        # Head.
        ('convbn', 1, 1, 1280, 'relu', None, None, None, None),
    ]
}

# Registry of supported MobileNet variants, keyed by each table's own
# 'spec_name' (which is also the accepted `model_id`).
SUPPORTED_SPECS_MAP = {
    spec['spec_name']: spec
    for spec in (
        MNV1_BLOCK_SPECS,
        MNV2_BLOCK_SPECS,
        MNV3Large_BLOCK_SPECS,
        MNV3Small_BLOCK_SPECS,
        MNV3EdgeTPU_BLOCK_SPECS,
    )
}

# Maps a spec row's 'block_fn' string to the layer class that implements it.
BLOCK_FN_MAP = dict(
    convbn=nn_blocks.Conv2DBNBlock,
    depsepconv=nn_blocks.DepthwiseSeparableConvBlock,
    mbconv=nn_blocks.InvertedBottleneckBlock,
    gpooling=GlobalPoolingBlock,
)


class BlockSpec(object):
  """Plain record describing how to build a single MobileNet block.

  Each constructor argument is stored unchanged on an attribute of the same
  name; no validation or conversion happens here.
  """

  def __init__(self,
               block_fn: Text = 'convbn',
               kernel_size: int = 3,
               strides: int = 1,
               filters: int = 32,
               use_bias: bool = False,
               use_normalization: bool = True,
               activation: Text = 'relu6',
               # used for block type InvertedResConv
               expand_ratio: Optional[float] = 6.,
               # used for block type InvertedResConv with SE
               se_ratio: Optional[float] = None,
               use_depthwise: bool = True,
               use_residual: bool = True):
    # Which builder to use, and its convolution geometry.
    self.block_fn = block_fn
    self.kernel_size, self.strides, self.filters = kernel_size, strides, filters
    # Options consumed by the 'convbn' builder.
    self.use_bias, self.use_normalization = use_bias, use_normalization
    # Activation shared by all block types.
    self.activation = activation
    # Options consumed by the 'mbconv' (inverted bottleneck) builder.
    self.expand_ratio, self.se_ratio = expand_ratio, se_ratio
    self.use_depthwise, self.use_residual = use_depthwise, use_residual


def block_spec_decoder(specs: Dict,
                       width_multiplier: float,
                       # set to 1 for mobilenetv1
                       divisible_by: int = 8,
                       finegrain_classification_mode: bool = True,
                       min_depth: int = 8):
  """Decodes a block-spec table into a list of `BlockSpec` objects.

  Args:
    specs: `dict` specification of block specs of a mobilenet version. Must
      provide the keys 'spec_name', 'block_spec_schema' and 'block_specs'.
    width_multiplier: `float` multiplier for the depth (number of channels)
      for all convolution ops. The value must be greater than zero. Typical
      usage will be to set this value in (0, 1) to reduce the number of
      parameters or computation cost of the model.
    divisible_by: `int` ensures all inner dimensions are divisible by
      this number.
    finegrain_classification_mode: if True, the model
      will keep the last layer large even for small multipliers. Following
      https://arxiv.org/abs/1801.04381
    min_depth: `int` passed as `min_depth` to `nn_layers.round_filters` for
      every block with filters. Defaults to 8, the previously fixed value.

  Returns:
    `List[BlockSpec]` defines structure of the base network.

  Raises:
    ValueError: if the spec table is empty, or if any row's length differs
      from the schema length.
  """
  spec_name = specs['spec_name']
  block_spec_schema = specs['block_spec_schema']
  block_specs = specs['block_specs']

  if not block_specs:
    raise ValueError('The block spec cannot be empty for {} !'.format(spec_name))

  # Validate every row, not just the first: zip() below would silently drop
  # extra values of a long row, or leave fields of a short row at their
  # BlockSpec defaults.
  for block_spec in block_specs:
    if len(block_spec) != len(block_spec_schema):
      raise ValueError('The block spec values {} do not match with '
                       'the schema {}'.format(block_spec, block_spec_schema))

  decoded_specs = [BlockSpec(**dict(zip(block_spec_schema, block_spec)))
                   for block_spec in block_specs]

  # This adjustment applies to V2 and V3: pre-divide the last layer's filters
  # so that it stays large after the width_multiplier scaling below
  # (assumes round_filters scales `filters` by `multiplier`).
  if (spec_name != 'MobileNetV1'
      and finegrain_classification_mode
      and width_multiplier < 1.0):
    decoded_specs[-1].filters /= width_multiplier

  for ds in decoded_specs:
    # Rows with no filters (e.g. 'gpooling', filters=None) are left untouched.
    if ds.filters:
      ds.filters = nn_layers.round_filters(filters=ds.filters,
                                           multiplier=width_multiplier,
                                           divisor=divisible_by,
                                           min_depth=min_depth)

  return decoded_specs


@tf.keras.utils.register_keras_serializable(package='Vision')
class MobileNet(tf.keras.Model):
  """MobileNet backbone (V1, V2, V3Large, V3Small or V3EdgeTPU).

  Built as a functional Keras model. The model output is a dict of endpoints
  keyed by integer level: level i (1-based) is the output of the i-th decoded
  block, and one extra level after the last block holds the final output.
  """

  def __init__(self,
               model_id: Text = 'MobileNetV2',
               width_multiplier: float = 1.0,
               input_specs: layers.InputSpec = layers.InputSpec(
                   shape=[None, None, None, 3]),
               # The followings are for hyper-parameter tuning
               norm_momentum: float = 0.99,
               norm_epsilon: float = 0.001,
               kernel_initializer: Text = 'VarianceScaling',
               kernel_regularizer: Optional[regularizers.Regularizer] = None,
               bias_regularizer: Optional[regularizers.Regularizer] = None,
               # The followings should be kept the same most of the times
               output_stride: Optional[int] = None,
               min_depth: int = 8,
               # divisible is not used in MobileNetV1
               divisible_by: int = 8,
               stochastic_depth_drop_rate: float = 0.0,
               regularize_depthwise: bool = False,
               use_sync_bn: bool = False,
               # finegrain is not used in MobileNetV1
               finegrain_classification_mode: bool = True,
               **kwargs):
    """Initializes a MobileNet backbone.

    Args:
      model_id: `str` version of MobileNet. The supported values
        are 'MobileNetV1', 'MobileNetV2', 'MobileNetV3Large', 'MobileNetV3Small',
        and 'MobileNetV3EdgeTPU'.
      width_multiplier: `float` multiplier for the depth (number of channels)
        for all convolution ops. The value must be greater than zero. Typical
        usage will be to set this value in (0, 1) to reduce the number of
        parameters or computation cost of the model.
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      kernel_initializer: `str` kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
        Default to None.
      output_stride: `int` specifies the requested ratio of input to
        output spatial resolution. If not None, then we invoke atrous convolution
        if necessary to prevent the network from reducing the spatial resolution
        of activation maps. Allowed values are 8 (accurate fully convolutional
        mode), 16 (fast fully convolutional mode), 32 (classification mode).
      min_depth: `int` minimum depth (number of channels) for all conv ops.
        Enforced when width_multiplier < 1, and not an active constraint when
        width_multiplier >= 1. NOTE(review): this value is stored and returned
        by get_config(), but it is not passed to block_spec_decoder below, so
        channel rounding uses that function's own minimum.
      divisible_by: `int` ensures all inner dimensions are divisible by
        this number.
      stochastic_depth_drop_rate: `float` drop rate for drop connect layer.
      regularize_depthwise: if True, apply regularization on depthwise.
      use_sync_bn: if True, use synchronized batch normalization.
      finegrain_classification_mode: if True, the model
        will keep the last layer large even for small multipliers. Following
        https://arxiv.org/abs/1801.04381
      **kwargs: keyword arguments to be passed.

    Raises:
      ValueError: if `model_id` is not supported, `width_multiplier` is not
        positive, or `output_stride` is not an allowed value.
    """
    if model_id not in SUPPORTED_SPECS_MAP:
      raise ValueError('The MobileNet version {} '
                       'is not supported'.format(model_id))

    if width_multiplier <= 0:
      raise ValueError('width_multiplier is not greater than zero.')

    if output_stride is not None:
      if model_id == 'MobileNetV1':
        if output_stride not in [8, 16, 32]:
          raise ValueError('Only allowed output_stride values are 8, 16, 32.')
      else:
        # Reject zero and negative values as well as odd values above 1.
        if output_stride < 1 or (output_stride > 1 and output_stride % 2):
          raise ValueError('Output stride must be None, 1 or a multiple of 2.')

    self._model_id = model_id
    self._input_specs = input_specs
    self._width_multiplier = width_multiplier
    self._min_depth = min_depth
    self._output_stride = output_stride
    self._divisible_by = divisible_by
    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
    self._regularize_depthwise = regularize_depthwise
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._finegrain_classification_mode = finegrain_classification_mode

    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    block_specs = SUPPORTED_SPECS_MAP.get(model_id)
    self._decoded_specs = block_spec_decoder(
        specs=block_specs,
        width_multiplier=self._width_multiplier,
        divisible_by=self._get_divisible_by(),
        finegrain_classification_mode=self._finegrain_classification_mode)

    x, endpoints = self._mobilenet_base(inputs=inputs)

    # One extra level after the per-block endpoints holds the final output.
    endpoints[max(endpoints.keys()) + 1] = x
    self._output_specs = {
        level: endpoints[level].get_shape() for level in endpoints}

    super(MobileNet, self).__init__(
        inputs=inputs, outputs=endpoints, **kwargs)

  def _get_divisible_by(self):
    """Returns the channel divisor: 1 for MobileNetV1, else `divisible_by`."""
    if self._model_id == 'MobileNetV1':
      return 1
    else:
      return self._divisible_by

  def _mobilenet_base(self,
                      inputs: tf.Tensor
                      ) -> (tf.Tensor, Dict[int, tf.Tensor]):
    """Build the base MobileNet architecture.

    Args:
      inputs: Input tensor of shape [batch_size, height, width, channels].

    Returns:
      A tuple of output Tensor and dictionary that collects endpoints.

    Raises:
      ValueError: if `inputs` is not rank 4, or a decoded block has an
        unknown `block_fn`.
    """

    input_shape = inputs.get_shape().as_list()
    if len(input_shape) != 4:
      raise ValueError('Expected rank 4 input, was: %d' % len(input_shape))

    # The current_stride variable keeps track of the output stride of the
    # activations, i.e., the running product of convolution strides up to the
    # current network layer. This allows us to invoke atrous convolution
    # whenever applying the next convolution would result in the activations
    # having output stride larger than the target output_stride.
    current_stride = 1

    # The atrous convolution rate parameter.
    rate = 1

    net = inputs
    endpoints = {}
    endpoint_level = 1
    for i, block_def in enumerate(self._decoded_specs):
      block_name = 'block_group_{}_{}'.format(block_def.block_fn, i)
      # A small catch for gpooling block with None strides
      if not block_def.strides:
        block_def.strides = 1
      if (self._output_stride is not None
          and current_stride == self._output_stride):
        # If we have reached the target output_stride, then we need to employ
        # atrous convolution with stride=1 and multiply the atrous rate by the
        # current unit's stride for use in subsequent layers.
        layer_stride = 1
        layer_rate = rate
        rate *= block_def.strides
      else:
        layer_stride = block_def.strides
        layer_rate = 1
        current_stride *= block_def.strides

      # Every strided block below uses layer_stride (not block_def.strides);
      # otherwise blocks past the requested output_stride would keep
      # downsampling and current_stride would no longer match the network.
      if block_def.block_fn == 'convbn':
        # NOTE(review): no dilation is passed to Conv2DBNBlock, so a strided
        # 3x3 convbn that is converted to stride 1 is not made atrous.
        net = nn_blocks.Conv2DBNBlock(
            filters=block_def.filters,
            kernel_size=block_def.kernel_size,
            strides=layer_stride,
            activation=block_def.activation,
            use_bias=block_def.use_bias,
            use_normalization=block_def.use_normalization,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            use_sync_bn=self._use_sync_bn,
            norm_momentum=self._norm_momentum,
            norm_epsilon=self._norm_epsilon
        )(net)

      elif block_def.block_fn == 'depsepconv':
        net = nn_blocks.DepthwiseSeparableConvBlock(
            filters=block_def.filters,
            kernel_size=block_def.kernel_size,
            strides=layer_stride,
            activation=block_def.activation,
            dilation_rate=layer_rate,
            regularize_depthwise=self._regularize_depthwise,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            use_sync_bn=self._use_sync_bn,
            norm_momentum=self._norm_momentum,
            norm_epsilon=self._norm_epsilon,
        )(net)

      elif block_def.block_fn == 'mbconv':
        # NOTE(review): when layer_rate == 1 or kernel_size == 1 this falls
        # back to the running `rate`, which was already multiplied by the
        # current block's stride above once output_stride is reached —
        # confirm this is the intended dilation for the transition block.
        use_rate = rate
        if layer_rate > 1 and block_def.kernel_size != 1:
          # We will apply atrous rate in the following cases:
          # 1) When kernel_size is not in params, the operation then uses
          #   default kernel size 3x3.
          # 2) When kernel_size is in params, and if the kernel_size is not
          #   equal to (1, 1) (there is no need to apply atrous convolution to
          #   any 1x1 convolution).
          use_rate = layer_rate
        in_filters = net.shape.as_list()[-1]
        net = nn_blocks.InvertedBottleneckBlock(
            in_filters=in_filters,
            out_filters=block_def.filters,
            kernel_size=block_def.kernel_size,
            strides=layer_stride,
            expand_ratio=block_def.expand_ratio,
            se_ratio=block_def.se_ratio,
            activation=block_def.activation,
            use_depthwise=block_def.use_depthwise,
            use_residual=block_def.use_residual,
            dilation_rate=use_rate,
            regularize_depthwise=self._regularize_depthwise,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            use_sync_bn=self._use_sync_bn,
            norm_momentum=self._norm_momentum,
            norm_epsilon=self._norm_epsilon,
            stochastic_depth_drop_rate=self._stochastic_depth_drop_rate,
            divisible_by=self._get_divisible_by(),
            target_backbone='mobilenet'
        )(net)

      elif block_def.block_fn == 'gpooling':
        net = GlobalPoolingBlock()(net)

      else:
        raise ValueError('Unknown block type {} for layer {}'.format(
            block_def.block_fn, i))

      endpoints[endpoint_level] = net
      endpoint_level += 1
      net = tf.identity(net, name=block_name)
    return net, endpoints

  def get_config(self):
    """Returns the constructor arguments needed to rebuild this model."""
    config_dict = {
        'model_id': self._model_id,
        'width_multiplier': self._width_multiplier,
        'min_depth': self._min_depth,
        'output_stride': self._output_stride,
        'divisible_by': self._divisible_by,
        'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
        'regularize_depthwise': self._regularize_depthwise,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'finegrain_classification_mode': self._finegrain_classification_mode,
    }
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Builds a MobileNet from a dict produced by `get_config`."""
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs
Shixin Luo's avatar
Shixin Luo committed
595

596

Shixin Luo's avatar
Shixin Luo committed
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
@factory.register_backbone_builder('mobilenet')
def build_mobilenet(
    input_specs: tf.keras.layers.InputSpec,
    model_config,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
  """Builds a MobileNet (2D image) backbone from a config.

  Args:
    input_specs: `tf.keras.layers.InputSpec` of the input tensor.
    model_config: model config whose `backbone.type` must be 'mobilenet'. The
      MobileNet arguments are read from `backbone.get()` (model_id,
      width_multiplier, stochastic_depth_drop_rate) and from `norm_activation`
      (use_sync_bn, norm_momentum, norm_epsilon).
    l2_regularizer: optional regularizer applied to convolution kernels.

  Returns:
    A `MobileNet` Keras model.
  """
  backbone_type = model_config.backbone.type
  backbone_cfg = model_config.backbone.get()
  norm_activation_config = model_config.norm_activation
  assert backbone_type == 'mobilenet', (f'Inconsistent backbone type '
                                        f'{backbone_type}')

  return MobileNet(
      model_id=backbone_cfg.model_id,
      width_multiplier=backbone_cfg.width_multiplier,
      input_specs=input_specs,
      stochastic_depth_drop_rate=backbone_cfg.stochastic_depth_drop_rate,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)