mobilenet.py 36.4 KB
Newer Older
Yeqing Li's avatar
Yeqing Li committed
1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Yeqing Li's avatar
Yeqing Li committed
14

Fan Yang's avatar
Fan Yang committed
15
"""Contains definitions of MobileNet Networks."""
16

Yuqi Li's avatar
Yuqi Li committed
17
import dataclasses
Fan Yang's avatar
Fan Yang committed
18
from typing import Optional, Dict, Any, Tuple
19
20
21

# Import libraries
import tensorflow as tf
22
from official.modeling import hyperparams
23
from official.modeling import tf_utils
Shixin Luo's avatar
Shixin Luo committed
24
from official.vision.beta.modeling.backbones import factory
25
26
27
28
29
from official.vision.beta.modeling.layers import nn_blocks
from official.vision.beta.modeling.layers import nn_layers

layers = tf.keras.layers

30

31
32
33
#  pylint: disable=pointless-string-statement


34
@tf.keras.utils.register_keras_serializable(package='Vision')
class Conv2DBNBlock(tf.keras.layers.Layer):
  """A convolution block with batch normalization.

  The block computes `activation(norm(conv(pad(inputs))))`. The explicit
  padding step is applied only when `use_explicit_padding=True` and
  `kernel_size > 1`. The normalization step is applied only when
  `use_normalization=True`.
  """

  def __init__(
      self,
      filters: int,
      kernel_size: int = 3,
      strides: int = 1,
      use_bias: bool = False,
      use_explicit_padding: bool = False,
      activation: str = 'relu6',
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      use_normalization: bool = True,
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      **kwargs):
    """Initializes a convolution block with batch normalization.

    Args:
      filters: An `int` number of output filters of the block's single
        convolution.
      kernel_size: An `int` specifying the height and width of the 2D
        convolution window.
      strides: An `int` of block stride. If greater than 1, this block will
        ultimately downsample the input.
      use_bias: If True, use bias in the convolution layer.
      use_explicit_padding: Use 'VALID' padding for convolutions, but prepad
        inputs so that the output dimensions are the same as if 'SAME' padding
        were used. Ignored when `kernel_size` is 1, in which case 'SAME'
        padding is used.
      activation: A `str` name of the activation function.
      kernel_initializer: A `str` for kernel initializer of convolutional
        layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
        Default to None.
      use_normalization: If True, use batch normalization.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      **kwargs: Additional keyword arguments to be passed.
    """
    super().__init__(**kwargs)
    self._filters = filters
    self._kernel_size = kernel_size
    self._strides = strides
    self._activation = activation
    self._use_bias = use_bias
    self._use_explicit_padding = use_explicit_padding
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._use_normalization = use_normalization
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon

    # A 1x1 kernel needs no padding. Explicit prepadding (plus a 'valid'
    # convolution) is therefore only used for larger kernels.
    if use_explicit_padding and kernel_size > 1:
      self._padding = 'valid'
    else:
      self._padding = 'same'

    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    # Normalize over the channel axis, wherever the backend puts it.
    if tf.keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1

  def get_config(self):
    """Returns the layer config, including all constructor arguments."""
    config = {
        'filters': self._filters,
        'strides': self._strides,
        'kernel_size': self._kernel_size,
        'use_bias': self._use_bias,
        'use_explicit_padding': self._use_explicit_padding,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'use_normalization': self._use_normalization,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon
    }
    base_config = super().get_config()
    # Block-specific entries take precedence over base-layer entries.
    return {**base_config, **config}

  def build(self, input_shape):
    """Creates the padding, convolution, normalization and activation layers."""
    if self._use_explicit_padding and self._kernel_size > 1:
      padding_size = nn_layers.get_padding_for_kernel_size(self._kernel_size)
      self._pad = tf.keras.layers.ZeroPadding2D(padding_size)
    self._conv0 = tf.keras.layers.Conv2D(
        filters=self._filters,
        kernel_size=self._kernel_size,
        strides=self._strides,
        padding=self._padding,
        use_bias=self._use_bias,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    if self._use_normalization:
      self._norm0 = self._norm(
          axis=self._bn_axis,
          momentum=self._norm_momentum,
          epsilon=self._norm_epsilon)
    self._activation_layer = tf_utils.get_activation(
        self._activation, use_keras_layer=True)

    super().build(input_shape)

  def call(self, inputs, training=None):
    """Applies optional padding, conv, optional norm and the activation."""
    if self._use_explicit_padding and self._kernel_size > 1:
      inputs = self._pad(inputs)
    x = self._conv0(inputs)
    if self._use_normalization:
      x = self._norm0(x)
    return self._activation_layer(x)
157
158
159
160

"""
Architecture: https://arxiv.org/abs/1704.04861.

161
162
163
"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision
Applications" Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko,
Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam
164
165
166
"""
MNV1_BLOCK_SPECS = {
    'spec_name': 'MobileNetV1',
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
167
168
    'block_spec_schema': ['block_fn', 'kernel_size', 'strides',
                          'filters', 'is_output'],
169
    'block_specs': [
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
170
171
172
173
174
175
176
177
178
179
180
181
182
183
        ('convbn', 3, 2, 32, False),
        ('depsepconv', 3, 1, 64, False),
        ('depsepconv', 3, 2, 128, False),
        ('depsepconv', 3, 1, 128, True),
        ('depsepconv', 3, 2, 256, False),
        ('depsepconv', 3, 1, 256, True),
        ('depsepconv', 3, 2, 512, False),
        ('depsepconv', 3, 1, 512, False),
        ('depsepconv', 3, 1, 512, False),
        ('depsepconv', 3, 1, 512, False),
        ('depsepconv', 3, 1, 512, False),
        ('depsepconv', 3, 1, 512, True),
        ('depsepconv', 3, 2, 1024, False),
        ('depsepconv', 3, 1, 1024, True),
184
185
186
187
188
189
190
191
192
193
194
195
    ]
}

"""
Architecture: https://arxiv.org/abs/1801.04381

"MobileNetV2: Inverted Residuals and Linear Bottlenecks"
Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen
"""
MNV2_BLOCK_SPECS = {
    'spec_name': 'MobileNetV2',
    'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
196
                          'expand_ratio', 'is_output'],
197
    'block_specs': [
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
        ('convbn', 3, 2, 32, None, False),
        ('invertedbottleneck', 3, 1, 16, 1., False),
        ('invertedbottleneck', 3, 2, 24, 6., False),
        ('invertedbottleneck', 3, 1, 24, 6., True),
        ('invertedbottleneck', 3, 2, 32, 6., False),
        ('invertedbottleneck', 3, 1, 32, 6., False),
        ('invertedbottleneck', 3, 1, 32, 6., True),
        ('invertedbottleneck', 3, 2, 64, 6., False),
        ('invertedbottleneck', 3, 1, 64, 6., False),
        ('invertedbottleneck', 3, 1, 64, 6., False),
        ('invertedbottleneck', 3, 1, 64, 6., False),
        ('invertedbottleneck', 3, 1, 96, 6., False),
        ('invertedbottleneck', 3, 1, 96, 6., False),
        ('invertedbottleneck', 3, 1, 96, 6., True),
        ('invertedbottleneck', 3, 2, 160, 6., False),
        ('invertedbottleneck', 3, 1, 160, 6., False),
        ('invertedbottleneck', 3, 1, 160, 6., False),
        ('invertedbottleneck', 3, 1, 320, 6., True),
        ('convbn', 1, 1, 1280, None, False),
217
218
219
220
221
222
223
    ]
}

"""
Architecture: https://arxiv.org/abs/1905.02244

"Searching for MobileNetV3"
224
Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan,
225
226
227
228
229
230
Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam
"""
MNV3Large_BLOCK_SPECS = {
    'spec_name': 'MobileNetV3Large',
    'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
                          'activation', 'se_ratio', 'expand_ratio',
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
231
                          'use_normalization', 'use_bias', 'is_output'],
232
    'block_specs': [
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
        ('convbn', 3, 2, 16,
         'hard_swish', None, None, True, False, False),
        ('invertedbottleneck', 3, 1, 16,
         'relu', None, 1., None, False, False),
        ('invertedbottleneck', 3, 2, 24,
         'relu', None, 4., None, False, False),
        ('invertedbottleneck', 3, 1, 24,
         'relu', None, 3., None, False, True),
        ('invertedbottleneck', 5, 2, 40,
         'relu', 0.25, 3., None, False, False),
        ('invertedbottleneck', 5, 1, 40,
         'relu', 0.25, 3., None, False, False),
        ('invertedbottleneck', 5, 1, 40,
         'relu', 0.25, 3., None, False, True),
        ('invertedbottleneck', 3, 2, 80,
         'hard_swish', None, 6., None, False, False),
        ('invertedbottleneck', 3, 1, 80,
         'hard_swish', None, 2.5, None, False, False),
        ('invertedbottleneck', 3, 1, 80,
         'hard_swish', None, 2.3, None, False, False),
        ('invertedbottleneck', 3, 1, 80,
         'hard_swish', None, 2.3, None, False, False),
        ('invertedbottleneck', 3, 1, 112,
         'hard_swish', 0.25, 6., None, False, False),
        ('invertedbottleneck', 3, 1, 112,
         'hard_swish', 0.25, 6., None, False, True),
        ('invertedbottleneck', 5, 2, 160,
         'hard_swish', 0.25, 6., None, False, False),
        ('invertedbottleneck', 5, 1, 160,
         'hard_swish', 0.25, 6., None, False, False),
        ('invertedbottleneck', 5, 1, 160,
         'hard_swish', 0.25, 6., None, False, True),
        ('convbn', 1, 1, 960,
         'hard_swish', None, None, True, False, False),
        ('gpooling', None, None, None,
         None, None, None, None, None, False),
        ('convbn', 1, 1, 1280,
         'hard_swish', None, None, False, True, False),
271
272
273
274
275
276
277
    ]
}

# Small variant from "Searching for MobileNetV3" (same schema as the Large one).
MNV3Small_BLOCK_SPECS = dict(
    spec_name='MobileNetV3Small',
    block_spec_schema=['block_fn', 'kernel_size', 'strides', 'filters',
                       'activation', 'se_ratio', 'expand_ratio',
                       'use_normalization', 'use_bias', 'is_output'],
    block_specs=[
        ('convbn', 3, 2, 16,
         'hard_swish', None, None, True, False, False),
        ('invertedbottleneck', 3, 2, 16,
         'relu', 0.25, 1, None, False, True),
        ('invertedbottleneck', 3, 2, 24,
         'relu', None, 72. / 16, None, False, False),
        ('invertedbottleneck', 3, 1, 24,
         'relu', None, 88. / 24, None, False, True),
        ('invertedbottleneck', 5, 2, 40,
         'hard_swish', 0.25, 4., None, False, False),
        ('invertedbottleneck', 5, 1, 40,
         'hard_swish', 0.25, 6., None, False, False),
        ('invertedbottleneck', 5, 1, 40,
         'hard_swish', 0.25, 6., None, False, False),
        ('invertedbottleneck', 5, 1, 48,
         'hard_swish', 0.25, 3., None, False, False),
        ('invertedbottleneck', 5, 1, 48,
         'hard_swish', 0.25, 3., None, False, True),
        ('invertedbottleneck', 5, 2, 96,
         'hard_swish', 0.25, 6., None, False, False),
        ('invertedbottleneck', 5, 1, 96,
         'hard_swish', 0.25, 6., None, False, False),
        ('invertedbottleneck', 5, 1, 96,
         'hard_swish', 0.25, 6., None, False, True),
        ('convbn', 1, 1, 576,
         'hard_swish', None, None, True, False, False),
        ('gpooling', None, None, None,
         None, None, None, None, None, False),
        ('convbn', 1, 1, 1024,
         'hard_swish', None, None, False, True, False),
    ],
)

"""
The EdgeTPU version is taken from
github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py
"""
MNV3EdgeTPU_BLOCK_SPECS = {
    'spec_name': 'MobileNetV3EdgeTPU',
    'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
                          'activation', 'se_ratio', 'expand_ratio',
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
321
                          'use_residual', 'use_depthwise', 'is_output'],
322
    'block_specs': [
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
        ('convbn', 3, 2, 32, 'relu', None, None, None, None, False),
        ('invertedbottleneck', 3, 1, 16, 'relu', None, 1., True, False, False),
        ('invertedbottleneck', 3, 2, 32, 'relu', None, 8., True, False, False),
        ('invertedbottleneck', 3, 1, 32, 'relu', None, 4., True, False, False),
        ('invertedbottleneck', 3, 1, 32, 'relu', None, 4., True, False, False),
        ('invertedbottleneck', 3, 1, 32, 'relu', None, 4., True, False, True),
        ('invertedbottleneck', 3, 2, 48, 'relu', None, 8., True, False, False),
        ('invertedbottleneck', 3, 1, 48, 'relu', None, 4., True, False, False),
        ('invertedbottleneck', 3, 1, 48, 'relu', None, 4., True, False, False),
        ('invertedbottleneck', 3, 1, 48, 'relu', None, 4., True, False, True),
        ('invertedbottleneck', 3, 2, 96, 'relu', None, 8., True, True, False),
        ('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True, False),
        ('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True, False),
        ('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True, False),
        ('invertedbottleneck', 3, 1, 96, 'relu', None, 8., False, True, False),
        ('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True, False),
        ('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True, False),
        ('invertedbottleneck', 3, 1, 96, 'relu', None, 4., True, True, True),
        ('invertedbottleneck', 5, 2, 160, 'relu', None, 8., True, True, False),
        ('invertedbottleneck', 5, 1, 160, 'relu', None, 4., True, True, False),
        ('invertedbottleneck', 5, 1, 160, 'relu', None, 4., True, True, False),
        ('invertedbottleneck', 5, 1, 160, 'relu', None, 4., True, True, False),
        ('invertedbottleneck', 3, 1, 192, 'relu', None, 8., True, True, True),
        ('convbn', 1, 1, 1280, 'relu', None, None, None, None, False),
347
348
349
    ]
}

A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
350
351
352
353
354
355
356
357
358
359
"""
Architecture: https://arxiv.org/pdf/2008.08178.pdf

"Discovering Multi-Hardware Mobile Models via Architecture Search"
Grace Chu, Okan Arikan, Gabriel Bender, Weijun Wang,
Achille Brighton, Pieter-Jan Kindermans, Hanxiao Liu,
Berkin Akin, Suyog Gupta, and Andrew Howard
"""
MNMultiMAX_BLOCK_SPECS = {
    'spec_name': 'MobileNetMultiMAX',
Xianzhi Du's avatar
Xianzhi Du committed
360
361
362
363
    'block_spec_schema': [
        'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
        'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
    ],
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
364
    'block_specs': [
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
        ('convbn', 3, 2, 32, 'relu', None, True, False, False),
        ('invertedbottleneck', 3, 2, 32, 'relu', 3., None, False, True),
        ('invertedbottleneck', 5, 2, 64, 'relu', 6., None, False, False),
        ('invertedbottleneck', 3, 1, 64, 'relu', 2., None, False, False),
        ('invertedbottleneck', 3, 1, 64, 'relu', 2., None, False, True),
        ('invertedbottleneck', 5, 2, 128, 'relu', 6., None, False, False),
        ('invertedbottleneck', 3, 1, 128, 'relu', 4., None, False, False),
        ('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False, False),
        ('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False, False),
        ('invertedbottleneck', 3, 1, 128, 'relu', 6., None, False, False),
        ('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False, True),
        ('invertedbottleneck', 3, 2, 160, 'relu', 6., None, False, False),
        ('invertedbottleneck', 5, 1, 160, 'relu', 4., None, False, False),
        ('invertedbottleneck', 3, 1, 160, 'relu', 5., None, False, False),
        ('invertedbottleneck', 5, 1, 160, 'relu', 4., None, False, True),
        ('convbn', 1, 1, 960, 'relu', None, True, False, False),
        ('gpooling', None, None, None, None, None, None, None, False),
Xianzhi Du's avatar
Xianzhi Du committed
382
383
384
        # Remove bias and add batch norm for the last layer to support QAT
        # and achieve slightly better accuracy.
        ('convbn', 1, 1, 1280, 'relu', None, True, False, False),
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
385
386
387
388
389
    ]
}

# MobileNetMultiAVG, from the same multi-hardware architecture search paper
# as MobileNetMultiMAX (https://arxiv.org/pdf/2008.08178.pdf).
MNMultiAVG_BLOCK_SPECS = dict(
    spec_name='MobileNetMultiAVG',
    block_spec_schema=[
        'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
        'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
    ],
    block_specs=[
        ('convbn', 3, 2, 32, 'relu', None, True, False, False),
        ('invertedbottleneck', 3, 2, 32, 'relu', 3., None, False, False),
        ('invertedbottleneck', 3, 1, 32, 'relu', 2., None, False, True),
        ('invertedbottleneck', 5, 2, 64, 'relu', 5., None, False, False),
        ('invertedbottleneck', 3, 1, 64, 'relu', 3., None, False, False),
        ('invertedbottleneck', 3, 1, 64, 'relu', 2., None, False, False),
        ('invertedbottleneck', 3, 1, 64, 'relu', 3., None, False, True),
        ('invertedbottleneck', 5, 2, 128, 'relu', 6., None, False, False),
        ('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False, False),
        ('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False, False),
        ('invertedbottleneck', 3, 1, 128, 'relu', 3., None, False, False),
        ('invertedbottleneck', 3, 1, 160, 'relu', 6., None, False, False),
        ('invertedbottleneck', 3, 1, 160, 'relu', 4., None, False, True),
        ('invertedbottleneck', 3, 2, 192, 'relu', 6., None, False, False),
        ('invertedbottleneck', 5, 1, 192, 'relu', 4., None, False, False),
        ('invertedbottleneck', 5, 1, 192, 'relu', 4., None, False, False),
        ('invertedbottleneck', 5, 1, 192, 'relu', 4., None, False, True),
        ('convbn', 1, 1, 960, 'relu', None, True, False, False),
        ('gpooling', None, None, None, None, None, None, None, False),
        # Remove bias and add batch norm for the last layer to support QAT
        # and achieve slightly better accuracy.
        ('convbn', 1, 1, 1280, 'relu', None, True, False, False),
    ],
)

Yuqi Li's avatar
Yuqi Li committed
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
# Segmentation variant of MobileNetMultiAVG: normalization is enabled in every
# inverted bottleneck, and the final stage uses half the filters (96 instead
# of 192, and 480 instead of 960 for the following 1x1 conv).
MNMultiAVG_SEG_BLOCK_SPECS = dict(
    spec_name='MobileNetMultiAVGSeg',
    block_spec_schema=[
        'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
        'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
    ],
    block_specs=[
        ('convbn', 3, 2, 32, 'relu', None, True, False, False),
        ('invertedbottleneck', 3, 2, 32, 'relu', 3., True, False, False),
        ('invertedbottleneck', 3, 1, 32, 'relu', 2., True, False, True),
        ('invertedbottleneck', 5, 2, 64, 'relu', 5., True, False, False),
        ('invertedbottleneck', 3, 1, 64, 'relu', 3., True, False, False),
        ('invertedbottleneck', 3, 1, 64, 'relu', 2., True, False, False),
        ('invertedbottleneck', 3, 1, 64, 'relu', 3., True, False, True),
        ('invertedbottleneck', 5, 2, 128, 'relu', 6., True, False, False),
        ('invertedbottleneck', 3, 1, 128, 'relu', 3., True, False, False),
        ('invertedbottleneck', 3, 1, 128, 'relu', 3., True, False, False),
        ('invertedbottleneck', 3, 1, 128, 'relu', 3., True, False, False),
        ('invertedbottleneck', 3, 1, 160, 'relu', 6., True, False, False),
        ('invertedbottleneck', 3, 1, 160, 'relu', 4., True, False, True),
        ('invertedbottleneck', 3, 2, 192, 'relu', 6., True, False, False),
        ('invertedbottleneck', 5, 1, 96, 'relu', 2., True, False, False),
        ('invertedbottleneck', 5, 1, 96, 'relu', 4., True, False, False),
        ('invertedbottleneck', 5, 1, 96, 'relu', 4., True, False, True),
        ('convbn', 1, 1, 480, 'relu', None, True, False, False),
        ('gpooling', None, None, None, None, None, None, None, False),
        # Remove bias and add batch norm for the last layer to support QAT
        # and achieve slightly better accuracy.
        ('convbn', 1, 1, 1280, 'relu', None, True, False, False),
    ],
)

454
455
456
457
458
459
# Maps each supported `model_id` to its block spec table. Each key is the
# table's own 'spec_name', so a key cannot disagree with its table.
SUPPORTED_SPECS_MAP = {
    _spec_table['spec_name']: _spec_table
    for _spec_table in (
        MNV1_BLOCK_SPECS,
        MNV2_BLOCK_SPECS,
        MNV3Large_BLOCK_SPECS,
        MNV3Small_BLOCK_SPECS,
        MNV3EdgeTPU_BLOCK_SPECS,
        MNMultiMAX_BLOCK_SPECS,
        MNMultiAVG_BLOCK_SPECS,
        MNMultiAVG_SEG_BLOCK_SPECS,
    )
}


466
@dataclasses.dataclass
class BlockSpec(hyperparams.Config):
  """Configuration of a single MobileNet block.

  `block_spec_decoder` builds one instance from each row of a spec table by
  mapping the row's values onto the table's schema. Fields that a schema
  does not mention keep the defaults below.
  """

  # Fields shared by every block type.
  block_fn: str = 'convbn'
  kernel_size: int = 3
  strides: int = 1
  filters: int = 32
  use_bias: bool = False
  use_normalization: bool = True
  activation: str = 'relu6'

  # Expansion factor of the inverted residual (InvertedResConv) blocks.
  expand_ratio: Optional[float] = 6.0

  # Squeeze-and-excitation ratio for InvertedResConv blocks with SE, plus the
  # depthwise / residual toggles.
  se_ratio: Optional[float] = None
  use_depthwise: bool = True
  use_residual: bool = True

  # Whether this block's output is one of the network's endpoints.
  is_output: bool = True
484
485


Fan Yang's avatar
Fan Yang committed
486
487
488
489
490
491
def block_spec_decoder(
    specs: Dict[Any, Any],
    filter_size_scale: float,
    # Set to 1 for mobilenetv1.
    divisible_by: int = 8,
    finegrain_classification_mode: bool = True):
  """Decodes specs for a block.

  Args:
    specs: A `dict` specification of block specs of a mobilenet version. It
      must contain the keys `spec_name`, `block_spec_schema` and
      `block_specs`. Each entry of `block_specs` is a tuple whose values are
      matched positionally against `block_spec_schema`.
    filter_size_scale: A `float` multiplier for the filter size for all
      convolution ops. The value must be greater than zero. Typical usage will
      be to set this value in (0, 1) to reduce the number of parameters or
      computation cost of the model.
    divisible_by: An `int` that ensures all inner dimensions are divisible by
      this number.
    finegrain_classification_mode: If True, the model will keep the last layer
      large even for small multipliers, following
      https://arxiv.org/abs/1801.04381.

  Returns:
    A list of `BlockSpec` that defines structure of the base network.

  Raises:
    ValueError: If `block_specs` is empty, or if any entry of `block_specs`
      does not have exactly as many values as `block_spec_schema`.
  """

  spec_name = specs['spec_name']
  block_spec_schema = specs['block_spec_schema']
  block_specs = specs['block_specs']

  if not block_specs:
    raise ValueError(
        'The block spec cannot be empty for {} !'.format(spec_name))

  decoded_specs = []
  for s in block_specs:
    # Validate every row, not only the first one. `zip` would otherwise
    # silently drop extra values, or leave fields at their defaults, for a
    # malformed row.
    if len(s) != len(block_spec_schema):
      raise ValueError('The block spec values {} do not match with '
                       'the schema {}'.format(s, block_spec_schema))
    kw_s = dict(zip(block_spec_schema, s))
    decoded_specs.append(BlockSpec(**kw_s))

  # Fine-grain classification mode applies to every spec except MobileNetV1.
  # Pre-dividing by the multiplier keeps the last layer's width unscaled after
  # the rounding step below. The last row must actually have filters
  # (e.g. not a 'gpooling' row).
  if (spec_name != 'MobileNetV1'
      and finegrain_classification_mode
      and filter_size_scale < 1.0
      and decoded_specs[-1].filters):
    decoded_specs[-1].filters /= filter_size_scale  # pytype: disable=annotation-type-mismatch

  for ds in decoded_specs:
    # Blocks without filters (e.g. 'gpooling', where filters is None) are
    # left untouched.
    if ds.filters:
      ds.filters = nn_layers.round_filters(filters=ds.filters,
                                           multiplier=filter_size_scale,
                                           divisor=divisible_by,
                                           min_depth=8)

  return decoded_specs


@tf.keras.utils.register_keras_serializable(package='Vision')
class MobileNet(tf.keras.Model):
  """Creates a MobileNet family model."""

  def __init__(
      self,
      model_id: str = 'MobileNetV2',
      filter_size_scale: float = 1.0,
      input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
          shape=[None, None, None, 3]),
      # The followings are for hyper-parameter tuning.
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      # The followings should be kept the same most of the times.
      output_stride: Optional[int] = None,
      min_depth: int = 8,
      # divisible is not used in MobileNetV1.
      divisible_by: int = 8,
      stochastic_depth_drop_rate: float = 0.0,
      regularize_depthwise: bool = False,
      use_sync_bn: bool = False,
      # finegrain is not used in MobileNetV1.
      finegrain_classification_mode: bool = True,
      output_intermediate_endpoints: bool = False,
      **kwargs):
    """Initializes a MobileNet model.

    Args:
      model_id: A `str` of MobileNet version. The supported values are
        `MobileNetV1`, `MobileNetV2`, `MobileNetV3Large`, `MobileNetV3Small`,
        `MobileNetV3EdgeTPU`, `MobileNetMultiMAX` and `MobileNetMultiAVG`.
      filter_size_scale: A `float` of multiplier for the filters (number of
        channels) for all convolution ops. The value must be greater than zero.
        Typical usage will be to set this value in (0, 1) to reduce the number
        of parameters or computation cost of the model.
      input_specs: A `tf.keras.layers.InputSpec` of specs of the input tensor.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_initializer: A `str` for kernel initializer of convolutional
        layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Default to None.
      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
        Default to None.
      output_stride: An `int` that specifies the requested ratio of input to
        output spatial resolution. If not None, then we invoke atrous
        convolution if necessary to prevent the network from reducing the
        spatial resolution of activation maps. For `MobileNetV1` the allowed
        values are 8 (accurate fully convolutional mode), 16 (fast fully
        convolutional mode) and 32 (classification mode); for the other
        models the value must be 1 or a positive multiple of 2.
      min_depth: An `int` of minimum depth (number of channels) for all
        convolution ops. Enforced when filter_size_scale < 1, and not an active
        constraint when filter_size_scale >= 1.
        NOTE(review): this value is stored and serialized but is not passed to
        `block_spec_decoder`, which rounds filters with a fixed `min_depth=8`;
        confirm whether it should be forwarded.
      divisible_by: An `int` that ensures all inner dimensions are divisible by
        this number. Ignored (treated as 1) for `MobileNetV1`.
      stochastic_depth_drop_rate: A `float` of drop rate for drop connect layer.
      regularize_depthwise: If True, apply regularization on depthwise.
      use_sync_bn: If True, use synchronized batch normalization.
      finegrain_classification_mode: If True, the model will keep the last layer
        large even for small multipliers, following
        https://arxiv.org/abs/1801.04381.
      output_intermediate_endpoints: A `bool` of whether or not output the
        intermediate endpoints.
      **kwargs: Additional keyword arguments to be passed.

    Raises:
      ValueError: If `model_id` is not supported, `filter_size_scale` is not
        positive, or `output_stride` is not an allowed value for `model_id`.
    """
    if model_id not in SUPPORTED_SPECS_MAP:
      raise ValueError('The MobileNet version {} '
                       'is not supported'.format(model_id))

    if filter_size_scale <= 0:
      raise ValueError('filter_size_scale is not greater than zero.')

    if output_stride is not None:
      if model_id == 'MobileNetV1':
        if output_stride not in [8, 16, 32]:
          raise ValueError('Only allowed output_stride values are 8, 16, 32.')
      else:
        # Reject zero and negative values as well as odd values above 1.
        if output_stride <= 0 or (output_stride > 1 and output_stride % 2):
          raise ValueError('Output stride must be None, 1 or a multiple of 2.')

    self._model_id = model_id
    self._input_specs = input_specs
    self._filter_size_scale = filter_size_scale
    self._min_depth = min_depth
    self._output_stride = output_stride
    self._divisible_by = divisible_by
    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
    self._regularize_depthwise = regularize_depthwise
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._finegrain_classification_mode = finegrain_classification_mode
    self._output_intermediate_endpoints = output_intermediate_endpoints

    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    block_specs = SUPPORTED_SPECS_MAP.get(model_id)
    self._decoded_specs = block_spec_decoder(
        specs=block_specs,
        filter_size_scale=self._filter_size_scale,
        divisible_by=self._get_divisible_by(),
        finegrain_classification_mode=self._finegrain_classification_mode)

    x, endpoints, next_endpoint_level = self._mobilenet_base(inputs=inputs)

    self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
    # Don't include the final layer in `self._output_specs` to support decoders.
    endpoints[str(next_endpoint_level)] = x

    super(MobileNet, self).__init__(
        inputs=inputs, outputs=endpoints, **kwargs)

  def _get_divisible_by(self):
    """Returns the channel divisor: 1 for MobileNetV1, else `divisible_by`."""
    if self._model_id == 'MobileNetV1':
      return 1
    else:
      return self._divisible_by

  def _mobilenet_base(self,
                      inputs: tf.Tensor
                      ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor], int]:
    """Builds the base MobileNet architecture.

    Args:
      inputs: A `tf.Tensor` of shape `[batch_size, height, width, channels]`.

    Returns:
      A tuple `(net, endpoints, endpoint_level)`:
        net: the output tensor of the last block.
        endpoints: a dict mapping endpoint-level strings (and, when
          `output_intermediate_endpoints` is set, `'<level>/<key>'` strings)
          to tensors of the blocks marked `is_output`.
        endpoint_level: the next endpoint level not present in `endpoints`.

    Raises:
      ValueError: If `inputs` is not rank 4 or a block spec has an unknown
        `block_fn`.
    """

    input_shape = inputs.get_shape().as_list()
    if len(input_shape) != 4:
      raise ValueError('Expected rank 4 input, was: %d' % len(input_shape))

    # The current_stride variable keeps track of the output stride of the
    # activations, i.e., the running product of convolution strides up to the
    # current network layer. This allows us to invoke atrous convolution
    # whenever applying the next convolution would result in the activations
    # having output stride larger than the target output_stride.
    current_stride = 1

    # The atrous convolution rate parameter.
    rate = 1

    net = inputs
    endpoints = {}
    endpoint_level = 2
    for i, block_def in enumerate(self._decoded_specs):
      block_name = 'block_group_{}_{}'.format(block_def.block_fn, i)
      # A small catch for gpooling block with None strides
      if not block_def.strides:
        block_def.strides = 1
      if (self._output_stride is not None and
          current_stride == self._output_stride):
        # If we have reached the target output_stride, then we need to employ
        # atrous convolution with stride=1 and multiply the atrous rate by the
        # current unit's stride for use in subsequent layers.
        layer_stride = 1
        layer_rate = rate
        rate *= block_def.strides
      else:
        layer_stride = block_def.strides
        layer_rate = 1
        current_stride *= block_def.strides

      intermediate_endpoints = {}
      if block_def.block_fn == 'convbn':

        net = Conv2DBNBlock(
            filters=block_def.filters,
            kernel_size=block_def.kernel_size,
            # Use `layer_stride` (like the other block types) so a strided
            # convbn block does not downsample past the target output_stride;
            # the stride bookkeeping above already assumes `layer_stride`.
            strides=layer_stride,
            activation=block_def.activation,
            use_bias=block_def.use_bias,
            use_normalization=block_def.use_normalization,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            use_sync_bn=self._use_sync_bn,
            norm_momentum=self._norm_momentum,
            norm_epsilon=self._norm_epsilon
        )(net)

      elif block_def.block_fn == 'depsepconv':
        net = nn_blocks.DepthwiseSeparableConvBlock(
            filters=block_def.filters,
            kernel_size=block_def.kernel_size,
            strides=layer_stride,
            activation=block_def.activation,
            dilation_rate=layer_rate,
            regularize_depthwise=self._regularize_depthwise,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            use_sync_bn=self._use_sync_bn,
            norm_momentum=self._norm_momentum,
            norm_epsilon=self._norm_epsilon,
        )(net)

      elif block_def.block_fn == 'invertedbottleneck':
        # NOTE(review): when the branch below is not taken, the running `rate`
        # (which may already include this block's stride) is used rather than
        # `layer_rate`; kept as-is to preserve existing model behavior.
        use_rate = rate
        if layer_rate > 1 and block_def.kernel_size != 1:
          # We will apply atrous rate in the following cases:
          # 1) When kernel_size is not in params, the operation then uses
          #   default kernel size 3x3.
          # 2) When kernel_size is in params, and if the kernel_size is not
          #   equal to (1, 1) (there is no need to apply atrous convolution to
          #   any 1x1 convolution).
          use_rate = layer_rate
        in_filters = net.shape.as_list()[-1]
        block = nn_blocks.InvertedBottleneckBlock(
            in_filters=in_filters,
            out_filters=block_def.filters,
            kernel_size=block_def.kernel_size,
            strides=layer_stride,
            expand_ratio=block_def.expand_ratio,
            se_ratio=block_def.se_ratio,
            expand_se_in_filters=True,
            se_gating_activation='hard_sigmoid',
            activation=block_def.activation,
            use_depthwise=block_def.use_depthwise,
            use_residual=block_def.use_residual,
            dilation_rate=use_rate,
            regularize_depthwise=self._regularize_depthwise,
            kernel_initializer=self._kernel_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            use_sync_bn=self._use_sync_bn,
            norm_momentum=self._norm_momentum,
            norm_epsilon=self._norm_epsilon,
            stochastic_depth_drop_rate=self._stochastic_depth_drop_rate,
            divisible_by=self._get_divisible_by(),
            output_intermediate_endpoints=self._output_intermediate_endpoints,
        )
        if self._output_intermediate_endpoints:
          net, intermediate_endpoints = block(net)
        else:
          net = block(net)

      elif block_def.block_fn == 'gpooling':
        net = layers.GlobalAveragePooling2D()(net)
        # Restore a rank-4 (1x1 spatial) tensor for subsequent conv blocks.
        net = layers.Reshape((1, 1, net.shape[1]))(net)

      else:
        raise ValueError('Unknown block type {} for layer {}'.format(
            block_def.block_fn, i))

      # Identity layer that gives each block output a stable, readable name.
      net = tf.keras.layers.Activation('linear', name=block_name)(net)

      if block_def.is_output:
        endpoints[str(endpoint_level)] = net
        for key, tensor in intermediate_endpoints.items():
          endpoints[str(endpoint_level) + '/' + key] = tensor
        # Once output_stride is reached, later outputs share the same level.
        if current_stride != self._output_stride:
          endpoint_level += 1

    if str(endpoint_level) in endpoints:
      endpoint_level += 1
    return net, endpoints, endpoint_level

  def get_config(self):
    """Returns the constructor arguments needed by `from_config`."""
    config_dict = {
        'model_id': self._model_id,
        'filter_size_scale': self._filter_size_scale,
        'min_depth': self._min_depth,
        'output_stride': self._output_stride,
        'divisible_by': self._divisible_by,
        'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
        'regularize_depthwise': self._regularize_depthwise,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'finegrain_classification_mode': self._finegrain_classification_mode,
        # Without this, a model rebuilt via `from_config` would silently drop
        # its intermediate endpoints.
        'output_intermediate_endpoints': self._output_intermediate_endpoints,
    }
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs
Shixin Luo's avatar
Shixin Luo committed
836

837

Shixin Luo's avatar
Shixin Luo committed
838
839
840
@factory.register_backbone_builder('mobilenet')
def build_mobilenet(
    input_specs: tf.keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
  """Builds MobileNet backbone from a config.

  Args:
    input_specs: A `tf.keras.layers.InputSpec` describing the model input.
    backbone_config: A `hyperparams.Config` whose `type` must be
      `'mobilenet'`; its active sub-config (from `.get()`) supplies the
      MobileNet-specific options.
    norm_activation_config: A `hyperparams.Config` supplying `use_sync_bn`,
      `norm_momentum` and `norm_epsilon`.
    l2_regularizer: An optional regularizer, passed to `MobileNet` as
      `kernel_regularizer`.

  Returns:
    A `MobileNet` instance.
  """
  kind = backbone_config.type
  cfg = backbone_config.get()
  assert kind == 'mobilenet', f'Inconsistent backbone type {kind}'

  mobilenet_kwargs = dict(
      model_id=cfg.model_id,
      filter_size_scale=cfg.filter_size_scale,
      input_specs=input_specs,
      stochastic_depth_drop_rate=cfg.stochastic_depth_drop_rate,
      output_stride=cfg.output_stride,
      output_intermediate_endpoints=cfg.output_intermediate_endpoints,
  )
  # Normalization settings come from the shared norm/activation config.
  mobilenet_kwargs.update(
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer,
  )
  return MobileNet(**mobilenet_kwargs)