resnet.py 15.2 KB
Newer Older
Yeqing Li's avatar
Yeqing Li committed
1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Abdullah Rashwan's avatar
Abdullah Rashwan committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Yeqing Li's avatar
Yeqing Li committed
14

15
"""Contains definitions of ResNet and ResNet-RS models."""
Abdullah Rashwan's avatar
Abdullah Rashwan committed
16

Fan Yang's avatar
Fan Yang committed
17
18
from typing import Callable, Optional

Abdullah Rashwan's avatar
Abdullah Rashwan committed
19
20
# Import libraries
import tensorflow as tf
Fan Yang's avatar
Fan Yang committed
21
22

from official.modeling import hyperparams
Abdullah Rashwan's avatar
Abdullah Rashwan committed
23
from official.modeling import tf_utils
Yeqing Li's avatar
Yeqing Li committed
24
from official.vision.beta.modeling.backbones import factory
Abdullah Rashwan's avatar
Abdullah Rashwan committed
25
from official.vision.beta.modeling.layers import nn_blocks
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
26
from official.vision.beta.modeling.layers import nn_layers
Abdullah Rashwan's avatar
Abdullah Rashwan committed
27
28
29
30
31
32
33
34

# Short module-level alias for the Keras layers namespace; it is also used in
# the `ResNet.__init__` default argument for `input_specs`.
layers = tf.keras.layers

# Specifications for different ResNet variants, keyed by model depth.
# Each value is a list of block-group configurations, one tuple per group
# (the `ResNet` class builds them as levels l2 through l5), in the format:
#   (block_fn, num_filters, block_repeats)
# Every variant uses the same per-group base widths 64/128/256/512; variants
# differ only in the block type and in how often each block is repeated.
RESNET_SPECS = {
    depth: [
        (block_fn, num_filters, block_repeats)
        for num_filters, block_repeats in zip((64, 128, 256, 512), repeats)
    ]
    for depth, block_fn, repeats in (
        (10, 'residual', (1, 1, 1, 1)),
        (18, 'residual', (2, 2, 2, 2)),
        (34, 'residual', (3, 4, 6, 3)),
        (50, 'bottleneck', (3, 4, 6, 3)),
        (101, 'bottleneck', (3, 4, 23, 3)),
        (152, 'bottleneck', (3, 8, 36, 3)),
        (200, 'bottleneck', (3, 24, 36, 3)),
        (270, 'bottleneck', (4, 29, 53, 4)),
        (350, 'bottleneck', (4, 36, 72, 4)),
        (420, 'bottleneck', (4, 44, 87, 4)),
    )
}


@tf.keras.utils.register_keras_serializable(package='Vision')
class ResNet(tf.keras.Model):
  """Creates ResNet and ResNet-RS family models.

  This implements the Deep Residual Network from:
    Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
    Deep Residual Learning for Image Recognition.
    (https://arxiv.org/pdf/1512.03385) and
    Irwan Bello, William Fedus, Xianzhi Du, Ekin D. Cubuk, Aravind Srinivas,
    Tsung-Yi Lin, Jonathon Shlens, Barret Zoph.
    Revisiting ResNets: Improved Training and Scaling Strategies.
    (https://arxiv.org/abs/2103.07579).

  The model's output is a dict mapping level strings ('2', '3', ...) to the
  output tensor of the corresponding block group.
  """

  def __init__(
      self,
      model_id: int,
      input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
          shape=[None, None, None, 3]),
      depth_multiplier: float = 1.0,
      stem_type: str = 'v0',
      resnetd_shortcut: bool = False,
      replace_stem_max_pool: bool = False,
      se_ratio: Optional[float] = None,
      init_stochastic_depth_rate: float = 0.0,
      activation: str = 'relu',
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      bn_trainable: bool = True,
      **kwargs):
    """Initializes a ResNet model.

    Args:
      model_id: An `int` of the depth of ResNet backbone model; must be a key
        of `RESNET_SPECS`.
      input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
      depth_multiplier: A `float` of the depth multiplier to uniformly scale up
        all layers in channel size. This argument is also referred to as
        `width_multiplier` in (https://arxiv.org/abs/2103.07579).
      stem_type: A `str` of stem type of ResNet. Defaults to `v0`. If set to
        `v1`, use ResNet-D type stem (https://arxiv.org/abs/1812.01187).
      resnetd_shortcut: A `bool` of whether to use ResNet-D shortcut in
        downsampling blocks.
      replace_stem_max_pool: A `bool` of whether to replace the max pool in stem
        with a stride-2 conv.
      se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
      init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
      activation: A `str` name of the activation function.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A small `float` added to variance to avoid dividing by zero.
      kernel_initializer: A `str` for kernel initializer of convolutional
        layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Defaults to None.
      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
        Defaults to None.
      bn_trainable: A `bool` that indicates whether batch norm layers should be
        trainable. Defaults to True.
      **kwargs: Additional keyword arguments passed to `tf.keras.Model`.

    Raises:
      ValueError: If `stem_type` is neither `v0` nor `v1`, or if a spec in
        `RESNET_SPECS` names a block type other than `residual` or
        `bottleneck`.
      KeyError: If `model_id` is not a key of `RESNET_SPECS`.
    """
    self._model_id = model_id
    self._input_specs = input_specs
    self._depth_multiplier = depth_multiplier
    self._stem_type = stem_type
    self._resnetd_shortcut = resnetd_shortcut
    self._replace_stem_max_pool = replace_stem_max_pool
    self._se_ratio = se_ratio
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    if use_sync_bn:
      self._norm = layers.experimental.SyncBatchNormalization
    else:
      self._norm = layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._bn_trainable = bn_trainable

    # Normalize over the channel axis, whose position depends on the global
    # Keras image data format.
    if tf.keras.backend.image_data_format() == 'channels_last':
      bn_axis = -1
    else:
      bn_axis = 1

    # Build ResNet.
    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    # Stem. Layers are created in the same order as the original inline code,
    # so Keras auto-generated layer names and weight ordering are unchanged.
    if stem_type == 'v0':
      # A single 7x7 stride-2 convolution.
      x = self._conv_bn_act(
          inputs, base_filters=64, kernel_size=7, strides=2, bn_axis=bn_axis)
    elif stem_type == 'v1':
      # ResNet-D style stem: three 3x3 convolutions, only the first strided.
      x = inputs
      for base_filters, strides in ((32, 2), (32, 1), (64, 1)):
        x = self._conv_bn_act(
            x,
            base_filters=base_filters,
            kernel_size=3,
            strides=strides,
            bn_axis=bn_axis)
    else:
      raise ValueError('Stem type {} not supported.'.format(stem_type))

    # Second stem downsampling: either a stride-2 3x3 conv or a max pool.
    if replace_stem_max_pool:
      x = self._conv_bn_act(
          x, base_filters=64, kernel_size=3, strides=2, bn_axis=bn_axis)
    else:
      x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

    endpoints = {}
    for i, spec in enumerate(RESNET_SPECS[model_id]):
      if spec[0] == 'residual':
        block_fn = nn_blocks.ResidualBlock
      elif spec[0] == 'bottleneck':
        block_fn = nn_blocks.BottleneckBlock
      else:
        raise ValueError('Block fn `{}` is not supported.'.format(spec[0]))
      # Block groups are numbered from level 2; every group after the first
      # downsamples by 2.
      x = self._block_group(
          inputs=x,
          filters=int(spec[1] * self._depth_multiplier),
          strides=(1 if i == 0 else 2),
          block_fn=block_fn,
          block_repeats=spec[2],
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 2, 5),
          name='block_group_l{}'.format(i + 2))
      endpoints[str(i + 2)] = x

    self._output_specs = {
        level: tensor.get_shape() for level, tensor in endpoints.items()
    }

    super(ResNet, self).__init__(inputs=inputs, outputs=endpoints, **kwargs)

  def _conv_bn_act(self,
                   inputs: tf.Tensor,
                   base_filters: int,
                   kernel_size: int,
                   strides: int,
                   bn_axis: int) -> tf.Tensor:
    """Applies a bias-free Conv2D, then normalization, then activation.

    Args:
      inputs: The input `tf.Tensor`.
      base_filters: An `int` number of filters before `depth_multiplier` is
        applied.
      kernel_size: An `int` convolution kernel size.
      strides: An `int` convolution stride.
      bn_axis: An `int` axis normalized by the normalization layer.

    Returns:
      The output `tf.Tensor`.
    """
    x = layers.Conv2D(
        filters=int(base_filters * self._depth_multiplier),
        kernel_size=kernel_size,
        strides=strides,
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(
            inputs)
    x = self._norm(
        axis=bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon,
        trainable=self._bn_trainable)(
            x)
    return tf_utils.get_activation(self._activation, use_keras_layer=True)(x)

  def _block_group(self,
                   inputs: tf.Tensor,
                   filters: int,
                   strides: int,
                   block_fn: Callable[..., tf.keras.layers.Layer],
                   block_repeats: int = 1,
                   stochastic_depth_drop_rate: float = 0.0,
                   name: str = 'block_group'):
    """Creates one group of blocks for the ResNet model.

    Args:
      inputs: A 4D `tf.Tensor` input; `[batch, height, width, channels]` for
        `channels_last` data format, `[batch, channels, height, width]` for
        `channels_first`.
      filters: An `int` number of filters for the first convolution of the
        layer.
      strides: An `int` stride to use for the first convolution of the layer.
        If greater than 1, this layer will downsample the input.
      block_fn: The type of block group. Either `nn_blocks.ResidualBlock` or
        `nn_blocks.BottleneckBlock`.
      block_repeats: An `int` number of blocks contained in the layer. The
        first block is always created, even if this is less than 1.
      stochastic_depth_drop_rate: A `float` of drop rate of the current block
        group.
      name: A `str` name for the block.

    Returns:
      The output `tf.Tensor` of the block layer.
    """
    # Keyword arguments shared by every block in the group. Only the first
    # block receives `strides` and `use_projection=True`.
    block_kwargs = dict(
        filters=filters,
        stochastic_depth_drop_rate=stochastic_depth_drop_rate,
        se_ratio=self._se_ratio,
        resnetd_shortcut=self._resnetd_shortcut,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activation=self._activation,
        use_sync_bn=self._use_sync_bn,
        norm_momentum=self._norm_momentum,
        norm_epsilon=self._norm_epsilon,
        bn_trainable=self._bn_trainable)

    x = block_fn(strides=strides, use_projection=True, **block_kwargs)(inputs)
    for _ in range(1, block_repeats):
      x = block_fn(strides=1, use_projection=False, **block_kwargs)(x)

    # A 'linear' (identity) activation layer that names the group's output.
    return tf.keras.layers.Activation('linear', name=name)(x)

  def get_config(self):
    """Returns the constructor arguments used to build this model.

    Note that `input_specs` is not included, so `from_config` rebuilds the
    model with the default `input_specs`.
    """
    config_dict = {
        'model_id': self._model_id,
        'depth_multiplier': self._depth_multiplier,
        'stem_type': self._stem_type,
        'resnetd_shortcut': self._resnetd_shortcut,
        'replace_stem_max_pool': self._replace_stem_max_pool,
        'activation': self._activation,
        'se_ratio': self._se_ratio,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'bn_trainable': self._bn_trainable
    }
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates a model from a `get_config` dict; `custom_objects` is unused."""
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs


@factory.register_backbone_builder('resnet')
def build_resnet(
    input_specs: tf.keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
  """Builds ResNet backbone from a config.

  Args:
    input_specs: A `tf.keras.layers.InputSpec` of the input tensor.
    backbone_config: A `hyperparams.Config` whose `type` must be `resnet`; the
      value returned by its `get()` supplies the ResNet-specific fields.
    norm_activation_config: A `hyperparams.Config` supplying `activation`,
      `use_sync_bn`, `norm_momentum` and `norm_epsilon`.
    l2_regularizer: An optional `tf.keras.regularizers.Regularizer` passed to
      the model as its kernel regularizer.

  Returns:
    A `ResNet` backbone model.

  Raises:
    ValueError: If `backbone_config.type` is not `resnet`.
  """
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  # Validate explicitly: an `assert` would be stripped under `python -O`.
  if backbone_type != 'resnet':
    raise ValueError(f'Inconsistent backbone type {backbone_type}')

  return ResNet(
      model_id=backbone_cfg.model_id,
      input_specs=input_specs,
      depth_multiplier=backbone_cfg.depth_multiplier,
      stem_type=backbone_cfg.stem_type,
      resnetd_shortcut=backbone_cfg.resnetd_shortcut,
      replace_stem_max_pool=backbone_cfg.replace_stem_max_pool,
      se_ratio=backbone_cfg.se_ratio,
      init_stochastic_depth_rate=backbone_cfg.stochastic_depth_drop_rate,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer,
      bn_trainable=backbone_cfg.bn_trainable)