darknet.py 21.5 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Yeqing Li's avatar
Yeqing Li committed
14

15
"""Contains definitions of Darknet Backbone Networks.
16

anivegesana's avatar
anivegesana committed
17
The models are inspired by ResNet and CSPNet.

Residual networks (ResNets) were proposed in:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385

Cross Stage Partial networks (CSPNets) were proposed in:
[1] Chien-Yao Wang, Hong-Yuan Mark Liao, I-Hau Yeh, Yueh-Hua Wu, Ping-Yang Chen,
    Jun-Wei Hsieh
    CSPNet: A New Backbone that can Enhance Learning Capability of CNN.
    arXiv:1911.11929

Darknets are used mainly for object detection in:
[1] Joseph Redmon, Ali Farhadi
    YOLOv3: An Incremental Improvement. arXiv:1804.02767

[2] Alexey Bochkovskiy, Chien-Yao Wang, Hong-Yuan Mark Liao
    YOLOv4: Optimal Speed and Accuracy of Object Detection. arXiv:2004.10934
"""

Jaeyoun Kim's avatar
Jaeyoun Kim committed
38
import collections
Abdullah Rashwan's avatar
Abdullah Rashwan committed
39

vishnubanna's avatar
vishnubanna committed
40
import tensorflow as tf
vishnubanna's avatar
vishnubanna committed
41

A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
42
from official.modeling import hyperparams
Abdullah Rashwan's avatar
Abdullah Rashwan committed
43
from official.projects.yolo.modeling.layers import nn_blocks
Abdullah Rashwan's avatar
Abdullah Rashwan committed
44
from official.vision.modeling.backbones import factory
vishnubanna's avatar
vishnubanna committed
45

46

Jaeyoun Kim's avatar
Jaeyoun Kim committed
47
48
class BlockConfig:
  """Mutable record holding the settings of one backbone level.

  Instances are created from the positional rows of the backbone tables in this
  module and are modified in place while the model is being built (for example
  `filters`, `repetitions` and `stack` are rewritten by the model builder).
  """

  def __init__(self, layer, stack, reps, bottleneck, filters, pool_size,
               kernel_size, strides, padding, activation, route, dilation_rate,
               output_name, is_output):
    """Initializing method for BlockConfig.

    Several fields are `None` in the tables of this module when the selected
    stack type does not use them.

    Args:
      layer: A `str` for layer name.
      stack: A `str` (or `None`) for the type of layer ordering to use for this
        specific level.
      reps: An `int` for the number of times to repeat block. Stored as
        `self.repetitions`.
      bottleneck: A `bool` for whether this stack has a bottle neck layer.
      filters: An `int` for the output depth of the level.
      pool_size: An `int` (or `None`) for the pool_size of max pool layers.
      kernel_size: An `int` (or `None`) for convolution kernel size.
      strides: A `Union[int, tuple]` (or `None`) that indicates convolution
        strides.
      padding: A `str` such as 'same' (or `None`) for the padding to apply to
        layers in this stack.
      activation: A `str` for the activation to use for this stack.
      route: An `int` index into the list of previous level outputs that
        selects the input of this level (-1 is the most recent one).
      dilation_rate: An `int` for the scale used in dilated Darknet.
      output_name: An `int` level id; it is compared against the min/max level
        and its string form is used as the key of the output endpoint.
      is_output: A `bool` for whether this layer is an output in the default
        model.
    """
    self.layer = layer
    self.stack = stack
    self.repetitions = reps
    self.bottleneck = bottleneck
    self.filters = filters
    self.kernel_size = kernel_size
    self.pool_size = pool_size
    self.strides = strides
    self.padding = padding
    self.activation = activation
    self.route = route
    self.dilation_rate = dilation_rate
    self.output_name = output_name
    self.is_output = is_output

  def __repr__(self):
    """Returns a `BlockConfig(field=value, ...)` string of the current state."""
    fields = ', '.join(f'{key}={value!r}' for key, value in vars(self).items())
    return f'{type(self).__name__}({fields})'

vishnubanna's avatar
vishnubanna committed
88

vishnubanna's avatar
vishnubanna committed
89
def build_block_specs(config):
  """Converts raw backbone rows into `BlockConfig` objects.

  Args:
    config: An iterable of rows; every row is unpacked positionally into the
      `BlockConfig` constructor.

  Returns:
    A `list` of `BlockConfig` instances in the same order as `config`.
  """
  return [BlockConfig(*row) for row in config]

95

Jaeyoun Kim's avatar
Jaeyoun Kim committed
96
97
class LayerBuilder:
  """Registry that turns a block config into a layer instance.

  Each registered layer name maps to a layer class plus a helper that assembles
  the constructor arguments for that class. Looking layers up this way replaces
  an if/switch chain, so supporting another layer only needs a new entry.
  """

  def __init__(self):
    # layer name -> (layer class, function producing its constructor kwargs)
    registry = {}
    registry['ConvBN'] = (nn_blocks.ConvBN, self.conv_bn_config_todict)
    registry['MaxPool'] = (tf.keras.layers.MaxPool2D,
                           self.maxpool_config_todict)
    self._layer_dict = registry

  def conv_bn_config_todict(self, config, kwargs):
    """Builds ConvBN arguments; entries in `kwargs` override the config."""
    conv_args = dict(
        filters=config.filters,
        kernel_size=config.kernel_size,
        strides=config.strides,
        padding=config.padding)
    return {**conv_args, **kwargs}

  def darktiny_config_todict(self, config, kwargs):
    """Builds filters/strides arguments; `kwargs` entries take precedence."""
    return {'filters': config.filters, 'strides': config.strides, **kwargs}

  def maxpool_config_todict(self, config, kwargs):
    """Builds max-pool arguments; only 'name' is taken from `kwargs`."""
    return dict(
        pool_size=config.pool_size,
        strides=config.strides,
        padding=config.padding,
        name=kwargs['name'])

  def __call__(self, config, kwargs):
    """Instantiates the layer registered under `config.layer`.

    Args:
      config: An object exposing `layer` and the attributes read by the
        registered argument helper.
      kwargs: A `dict` of extra arguments handed to the argument helper.

    Returns:
      The constructed layer instance.
    """
    layer_cls, make_params = self._layer_dict[config.layer]
    return layer_cls(**make_params(config, kwargs))
vishnubanna's avatar
vishnubanna committed
137
138


vishnubanna's avatar
vishnubanna committed
139
# model configs
140
# Column names for the positional rows of every 'backbone' table below, listed
# in the order in which `BlockConfig.__init__` consumes the row values (note
# that pool_size comes before kernel_size).
LISTNAMES = [
    'default_layer_name', 'level_type', 'number_of_layers_in_level',
    'bottleneck', 'filters', 'pool_size', 'kernel_size', 'strides', 'padding',
    'default_activation', 'route', 'dilation', 'level/name', 'is_output'
]

vishnubanna's avatar
vishnubanna committed
146
# CSPDarknet53: a ConvBN stem followed by five 'csp' DarkRes levels, all using
# the 'mish' activation. Each 'backbone' row is unpacked positionally into
# `BlockConfig`; levels 3, 4 and 5 are flagged as default outputs.
# NOTE(review): the meaning of the 'splits' indices is not visible in this
# file; they are only passed through by `Darknet.get_model_config`.
CSPDARKNET53 = {
    'list_names':
        LISTNAMES,
    'splits': {
        'backbone_split': 106,
        'neck_split': 132
    },
    'backbone': [
        [
            'ConvBN', None, 1, False, 32, None, 3, 1, 'same', 'mish', -1, 1, 0,
            False
        ],
        [
            'DarkRes', 'csp', 1, True, 64, None, None, None, None, 'mish', -1,
            1, 1, False
        ],
        [
            'DarkRes', 'csp', 2, False, 128, None, None, None, None, 'mish', -1,
            1, 2, False
        ],
        [
            'DarkRes', 'csp', 8, False, 256, None, None, None, None, 'mish', -1,
            1, 3, True
        ],
        [
            'DarkRes', 'csp', 8, False, 512, None, None, None, None, 'mish', -1,
            2, 4, True
        ],
        [
            'DarkRes', 'csp', 4, False, 1024, None, None, None, None, 'mish',
            -1, 4, 5, True
        ],
    ]
}

# Altered CSPDarknet53: same table as CSPDARKNET53 except that level 1 uses a
# plain 'residual' stack instead of 'csp', and the 'splits' values differ.
# Each 'backbone' row is unpacked positionally into `BlockConfig`.
CSPADARKNET53 = {
    'list_names':
        LISTNAMES,
    'splits': {
        'backbone_split': 100,
        'neck_split': 135
    },
    'backbone': [
        [
            'ConvBN', None, 1, False, 32, None, 3, 1, 'same', 'mish', -1, 1, 0,
            False
        ],
        [
            'DarkRes', 'residual', 1, True, 64, None, None, None, None, 'mish',
            -1, 1, 1, False
        ],
        [
            'DarkRes', 'csp', 2, False, 128, None, None, None, None, 'mish', -1,
            1, 2, False
        ],
        [
            'DarkRes', 'csp', 8, False, 256, None, None, None, None, 'mish', -1,
            1, 3, True
        ],
        [
            'DarkRes', 'csp', 8, False, 512, None, None, None, None, 'mish', -1,
            2, 4, True
        ],
        [
            'DarkRes', 'csp', 4, False, 1024, None, None, None, None, 'mish',
            -1, 4, 5, True
        ],
    ]
}

# Large CSP backbone: a ConvBN stem followed by seven 'csp' DarkRes levels
# (levels 1-7) with more repetitions per level than CSPDARKNET53; levels 3 to
# 7 are flagged as default outputs. Each 'backbone' row is unpacked
# positionally into `BlockConfig`.
LARGECSP53 = {
    'list_names':
        LISTNAMES,
    'splits': {
        'backbone_split': 100,
        'neck_split': 135
    },
    'backbone': [
        [
            'ConvBN', None, 1, False, 32, None, 3, 1, 'same', 'mish', -1, 1, 0,
            False
        ],
        [
            'DarkRes', 'csp', 1, False, 64, None, None, None, None, 'mish', -1,
            1, 1, False
        ],
        [
            'DarkRes', 'csp', 3, False, 128, None, None, None, None, 'mish', -1,
            1, 2, False
        ],
        [
            'DarkRes', 'csp', 15, False, 256, None, None, None, None, 'mish',
            -1, 1, 3, True
        ],
        [
            'DarkRes', 'csp', 15, False, 512, None, None, None, None, 'mish',
            -1, 2, 4, True
        ],
        [
            'DarkRes', 'csp', 7, False, 1024, None, None, None, None, 'mish',
            -1, 4, 5, True
        ],
        [
            'DarkRes', 'csp', 7, False, 1024, None, None, None, None, 'mish',
            -1, 8, 6, True
        ],
        [
            'DarkRes', 'csp', 7, False, 1024, None, None, None, None, 'mish',
            -1, 16, 7, True
        ],
    ]
}

# Darknet53: a ConvBN stem followed by five 'residual' DarkRes levels, all
# using the 'leaky' activation; levels 3, 4 and 5 are flagged as default
# outputs. Each 'backbone' row is unpacked positionally into `BlockConfig`.
DARKNET53 = {
    'list_names':
        LISTNAMES,
    'splits': {
        'backbone_split': 76
    },
    'backbone': [
        [
            'ConvBN', None, 1, False, 32, None, 3, 1, 'same', 'leaky', -1, 1, 0,
            False
        ],
        [
            'DarkRes', 'residual', 1, True, 64, None, None, None, None, 'leaky',
            -1, 1, 1, False
        ],
        [
            'DarkRes', 'residual', 2, False, 128, None, None, None, None,
            'leaky', -1, 1, 2, False
        ],
        [
            'DarkRes', 'residual', 8, False, 256, None, None, None, None,
            'leaky', -1, 1, 3, True
        ],
        [
            'DarkRes', 'residual', 8, False, 512, None, None, None, None,
            'leaky', -1, 2, 4, True
        ],
        [
            'DarkRes', 'residual', 4, False, 1024, None, None, None, None,
            'leaky', -1, 4, 5, True
        ],
    ]
}

# CSPDarknet-tiny: two strided ConvBN levels, three 'csp_tiny' levels and a
# final ConvBN level, all using the 'leaky' activation; levels 4 and 5 are
# flagged as default outputs. Each 'backbone' row is unpacked positionally
# into `BlockConfig`.
CSPDARKNETTINY = {
    'list_names':
        LISTNAMES,
    'splits': {
        'backbone_split': 28
    },
    'backbone': [
        [
            'ConvBN', None, 1, False, 32, None, 3, 2, 'same', 'leaky', -1, 1, 0,
            False
        ],
        [
            'ConvBN', None, 1, False, 64, None, 3, 2, 'same', 'leaky', -1, 1, 1,
            False
        ],
        [
            'CSPTiny', 'csp_tiny', 1, False, 64, None, 3, 2, 'same', 'leaky',
            -1, 1, 2, False
        ],
        [
            'CSPTiny', 'csp_tiny', 1, False, 128, None, 3, 2, 'same', 'leaky',
            -1, 1, 3, False
        ],
        [
            'CSPTiny', 'csp_tiny', 1, False, 256, None, 3, 2, 'same', 'leaky',
            -1, 1, 4, True
        ],
        [
            'ConvBN', None, 1, False, 512, None, 3, 1, 'same', 'leaky', -1, 1,
            5, True
        ],
    ]
}

# Darknet-tiny: a ConvBN stem followed by six 'tiny' levels, all using the
# 'leaky' activation. The last two rows share level id 5; only the final
# (stride 1, 1024 filter) row and the level 4 row are flagged as default
# outputs. Each 'backbone' row is unpacked positionally into `BlockConfig`.
DARKNETTINY = {
    'list_names':
        LISTNAMES,
    'splits': {
        'backbone_split': 14
    },
    'backbone': [
        [
            'ConvBN', None, 1, False, 16, None, 3, 1, 'same', 'leaky', -1, 1, 0,
            False
        ],
        [
            'DarkTiny', 'tiny', 1, True, 32, None, 3, 2, 'same', 'leaky', -1, 1,
            1, False
        ],
        [
            'DarkTiny', 'tiny', 1, True, 64, None, 3, 2, 'same', 'leaky', -1, 1,
            2, False
        ],
        [
            'DarkTiny', 'tiny', 1, False, 128, None, 3, 2, 'same', 'leaky', -1,
            1, 3, False
        ],
        [
            'DarkTiny', 'tiny', 1, False, 256, None, 3, 2, 'same', 'leaky', -1,
            1, 4, True
        ],
        [
            'DarkTiny', 'tiny', 1, False, 512, None, 3, 2, 'same', 'leaky', -1,
            1, 5, False
        ],
        [
            'DarkTiny', 'tiny', 1, False, 1024, None, 3, 1, 'same', 'leaky', -1,
            1, 5, True
        ],
    ]
}

# Registry of the available backbone tables, keyed by model id.
# `Darknet.get_model_config` lower-cases the requested name before the lookup,
# so every key here must be lower case.
BACKBONES = {
    'darknettiny': DARKNETTINY,
    'darknet53': DARKNET53,
    'cspdarknet53': CSPDARKNET53,
    'altered_cspdarknet53': CSPADARKNET53,
    'cspdarknettiny': CSPDARKNETTINY,
    'csp-large': LARGECSP53,
}


375
class Darknet(tf.keras.Model):
  """The Darknet backbone architecture.

  The network is assembled as a functional Keras model from one of the tables
  registered in `BACKBONES`. Its outputs are a dict that maps the string form
  of a level id to the feature tensor produced at that level.
  """

  def __init__(
      self,
      model_id='darknet53',
      input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
      min_level=None,
      max_level=5,
      width_scale=1.0,
      depth_scale=1.0,
      use_reorg_input=False,
      csp_level_mod=(),
      activation=None,
      use_sync_bn=False,
      use_separable_conv=False,
      norm_momentum=0.99,
      norm_epsilon=0.001,
      dilate=False,
      kernel_initializer='VarianceScaling',
      kernel_regularizer=None,
      bias_regularizer=None,
      **kwargs):
    """Builds the backbone graph for the requested model.

    Args:
      model_id: A `str` key into `BACKBONES` (matched case-insensitively).
      input_specs: A `tf.keras.layers.InputSpec`; `shape[1:]` is used as the
        shape of the Keras input.
      min_level: An `int` or `None`. When `None`, the endpoints are the levels
        flagged `is_output` in the table; otherwise every level in
        `[min_level, max_level]` becomes an endpoint.
      max_level: An `int`; levels with a larger id are not built.
      width_scale: A `float` multiplier applied to the filters of every level
        (the result is truncated to `int`).
      depth_scale: A `float` multiplier applied to the repetitions of every
        level (the result is truncated to `int`).
      use_reorg_input: A `bool`; when True the input is passed through
        `nn_blocks.Reorg` and the first two table rows are merged.
      csp_level_mod: An iterable of level ids whose stack type is replaced by
        'residual'.
      activation: A `str` or `None`; when not `None` it overrides the
        activation listed in the table for every level.
      use_sync_bn: Forwarded to the `nn_blocks` layers.
      use_separable_conv: Forwarded to the `nn_blocks` layers.
      norm_momentum: Forwarded to the `nn_blocks` layers.
      norm_epsilon: Forwarded to the `nn_blocks` layers.
      dilate: A `bool`; when True the 'csp' and 'residual' stacks use the
        dilation rate listed in the table instead of 1.
      kernel_initializer: Forwarded to the `nn_blocks` layers.
      kernel_regularizer: Forwarded to the `nn_blocks` layers.
      bias_regularizer: Forwarded to the `nn_blocks` layers.
      **kwargs: Accepted for call-site compatibility; not used here.
    """
    layer_specs, splits = Darknet.get_model_config(model_id)

    self._model_name = model_id
    self._splits = splits
    self._input_shape = input_specs
    self._registry = LayerBuilder()

    # default layer look up
    self._min_size = min_level
    self._max_size = max_level
    self._output_specs = None
    self._csp_level_mod = set(csp_level_mod)

    self._kernel_initializer = kernel_initializer
    self._bias_regularizer = bias_regularizer
    self._norm_momentum = norm_momentum
    self._norm_epislon = norm_epsilon
    self._use_sync_bn = use_sync_bn
    self._use_separable_conv = use_separable_conv
    self._activation = activation
    self._kernel_regularizer = kernel_regularizer
    self._dilate = dilate
    self._width_scale = width_scale
    self._depth_scale = depth_scale
    self._use_reorg_input = use_reorg_input

    # Shared keyword arguments handed to every nn_blocks layer. The stack
    # builders temporarily overwrite 'activation', 'name' and 'dilation_rate'
    # and restore them before returning.
    self._default_dict = {
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epislon,
        'use_sync_bn': self._use_sync_bn,
        'activation': self._activation,
        'use_separable_conv': self._use_separable_conv,
        'dilation_rate': 1,
        'name': None
    }

    inputs = tf.keras.layers.Input(shape=self._input_shape.shape[1:])
    output = self._build_struct(layer_specs, inputs)
    super().__init__(inputs=inputs, outputs=output, name=self._model_name)

  @property
  def input_specs(self):
    """The `InputSpec` this model was built with."""
    return self._input_shape

  @property
  def output_specs(self):
    """A dict mapping endpoint name to the shape of that endpoint."""
    return self._output_specs

  @property
  def splits(self):
    """The 'splits' entry of the selected backbone table."""
    return self._splits

  def _build_struct(self, net, inputs):
    """Wires all levels together and collects the output endpoints.

    The specs in `net` are modified in place (stack type, filters and
    repetitions).

    Args:
      net: A `list` of `BlockConfig`, one per level, in build order.
      inputs: The Keras input tensor.

    Returns:
      An `OrderedDict` mapping `str(level id)` to the tensor of that level.

    Raises:
      ValueError: If a spec names a stack type that has no builder.
    """
    if self._use_reorg_input:
      inputs = nn_blocks.Reorg()(inputs)
      # The first spec takes over the filters and level id of the second spec,
      # which is then dropped.
      net[0].filters = net[1].filters
      net[0].output_name = net[1].output_name
      del net[1]

    endpoints = collections.OrderedDict()
    stack_outputs = [inputs]
    for i, config in enumerate(net):
      if config.output_name > self._max_size:
        break
      if config.output_name in self._csp_level_mod:
        config.stack = 'residual'

      config.filters = int(config.filters * self._width_scale)
      config.repetitions = int(config.repetitions * self._depth_scale)

      if config.stack is None:
        x = self._build_block(
            stack_outputs[config.route], config, name=f'{config.layer}_{i}')
        stack_outputs.append(x)
      elif config.stack == 'residual':
        x = self._residual_stack(
            stack_outputs[config.route], config, name=f'{config.layer}_{i}')
        stack_outputs.append(x)
      elif config.stack == 'csp':
        x = self._csp_stack(
            stack_outputs[config.route], config, name=f'{config.layer}_{i}')
        stack_outputs.append(x)
      elif config.stack == 'csp_tiny':
        # The first tensor feeds the following level; the second one is the
        # tensor exposed as this level's endpoint.
        x_pass, x = self._csp_tiny_stack(
            stack_outputs[config.route], config, name=f'{config.layer}_{i}')
        stack_outputs.append(x_pass)
      elif config.stack == 'tiny':
        x = self._tiny_stack(
            stack_outputs[config.route], config, name=f'{config.layer}_{i}')
        stack_outputs.append(x)
      else:
        # Without this guard `x` would be stale (or unbound) below and the
        # routing indices of later levels would be silently shifted.
        raise ValueError(
            f'Unknown stack type {config.stack!r} for level '
            f'{config.output_name} of model {self._model_name!r}.')

      if (config.is_output and self._min_size is None):
        endpoints[str(config.output_name)] = x
      elif (self._min_size is not None and
            config.output_name >= self._min_size and
            config.output_name <= self._max_size):
        endpoints[str(config.output_name)] = x

    self._output_specs = {
        level: tensor.get_shape() for level, tensor in endpoints.items()
    }
    return endpoints

  def _get_activation(self, activation):
    """Returns the model-wide activation if set, else the given one."""
    if self._activation is None:
      return activation
    return self._activation

  def _csp_stack(self, inputs, config, name):
    """Builds one 'csp' level: CSPRoute, DarkResidual blocks, CSPConnect.

    Args:
      inputs: The input tensor of the level.
      config: The `BlockConfig` of the level.
      name: A `str` prefix for the names of the created layers.

    Returns:
      The output tensor of the level.
    """
    if config.bottleneck:
      csp_filter_scale = 1
      residual_filter_scale = 2
      scale_filters = 1
    else:
      csp_filter_scale = 2
      residual_filter_scale = 1
      scale_filters = 2
    self._default_dict['activation'] = self._get_activation(config.activation)
    self._default_dict['name'] = f'{name}_csp_down'
    if self._dilate:
      self._default_dict['dilation_rate'] = config.dilation_rate
      # Number of trailing blocks used to halve the dilation back down to 1,
      # i.e. floor(log2(rate)). Integer arithmetic keeps this exact and avoids
      # creating TF ops for a plain Python scalar.
      degrid = int(config.dilation_rate).bit_length() - 1
    else:
      self._default_dict['dilation_rate'] = 1
      degrid = 0

    # swap/add dilation
    x, x_route = nn_blocks.CSPRoute(
        filters=config.filters,
        filter_scale=csp_filter_scale,
        downsample=True,
        **self._default_dict)(
            inputs)

    dilated_reps = config.repetitions - degrid
    for i in range(dilated_reps):
      self._default_dict['name'] = f'{name}_{i}'
      x = nn_blocks.DarkResidual(
          filters=config.filters // scale_filters,
          filter_scale=residual_filter_scale,
          **self._default_dict)(
              x)

    for i in range(dilated_reps, config.repetitions):
      self._default_dict['dilation_rate'] = max(
          1, self._default_dict['dilation_rate'] // 2)
      self._default_dict[
          'name'] = f"{name}_{i}_degridded_{self._default_dict['dilation_rate']}"
      x = nn_blocks.DarkResidual(
          filters=config.filters // scale_filters,
          filter_scale=residual_filter_scale,
          **self._default_dict)(
              x)

    self._default_dict['name'] = f'{name}_csp_connect'
    output = nn_blocks.CSPConnect(
        filters=config.filters,
        filter_scale=csp_filter_scale,
        **self._default_dict)([x, x_route])
    self._default_dict['activation'] = self._activation
    self._default_dict['name'] = None
    # Restore the shared default, as `_residual_stack` does, so that a level
    # built afterwards never inherits this level's dilation.
    self._default_dict['dilation_rate'] = 1
    return output

  def _csp_tiny_stack(self, inputs, config, name):
    """Builds one 'csp_tiny' level and returns both CSPTiny outputs."""
    self._default_dict['activation'] = self._get_activation(config.activation)
    self._default_dict['name'] = f'{name}_csp_tiny'
    x, x_route = nn_blocks.CSPTiny(
        filters=config.filters, **self._default_dict)(
            inputs)
    self._default_dict['activation'] = self._activation
    self._default_dict['name'] = None
    return x, x_route

  def _tiny_stack(self, inputs, config, name):
    """Builds one 'tiny' level: a 2x2 max pool followed by a 3x3 ConvBN."""
    x = tf.keras.layers.MaxPool2D(
        pool_size=2,
        strides=config.strides,
        padding='same',
        data_format=None,
        name=f'{name}_tiny/pool')(
            inputs)
    self._default_dict['activation'] = self._get_activation(config.activation)
    self._default_dict['name'] = f'{name}_tiny/conv'
    x = nn_blocks.ConvBN(
        filters=config.filters,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding='same',
        **self._default_dict)(
            x)
    self._default_dict['activation'] = self._activation
    self._default_dict['name'] = None
    return x

  def _residual_stack(self, inputs, config, name):
    """Builds one 'residual' level: a downsampling block plus repeats.

    When dilation is enabled, levels with fewer than 8 repetitions get two
    extra repetitions (this updates `config.repetitions` in place), and the
    trailing blocks halve the dilation rate step by step.

    Args:
      inputs: The input tensor of the level.
      config: The `BlockConfig` of the level.
      name: A `str` prefix for the names of the created layers.

    Returns:
      The output tensor of the level.
    """
    self._default_dict['activation'] = self._get_activation(config.activation)
    self._default_dict['name'] = f'{name}_residual_down'
    if self._dilate:
      self._default_dict['dilation_rate'] = config.dilation_rate
      if config.repetitions < 8:
        config.repetitions += 2
    else:
      self._default_dict['dilation_rate'] = 1

    x = nn_blocks.DarkResidual(
        filters=config.filters, downsample=True, **self._default_dict)(
            inputs)

    dilated_reps = config.repetitions - self._default_dict[
        'dilation_rate'] // 2 - 1
    for i in range(dilated_reps):
      self._default_dict['name'] = f'{name}_{i}'
      x = nn_blocks.DarkResidual(
          filters=config.filters, **self._default_dict)(
              x)

    for i in range(dilated_reps, config.repetitions - 1):
      # Clamp at 1 (as `_csp_stack` does): this loop runs rate // 2 times, so
      # for rates of 8 and above plain halving would reach the invalid
      # dilation rate 0.
      self._default_dict['dilation_rate'] = max(
          1, self._default_dict['dilation_rate'] // 2)
      self._default_dict[
          'name'] = f"{name}_{i}_degridded_{self._default_dict['dilation_rate']}"
      x = nn_blocks.DarkResidual(
          filters=config.filters, **self._default_dict)(
              x)

    self._default_dict['activation'] = self._activation
    self._default_dict['name'] = None
    self._default_dict['dilation_rate'] = 1
    return x

  def _build_block(self, inputs, config, name):
    """Chains `config.repetitions` layers created by the layer registry."""
    x = inputs
    self._default_dict['activation'] = self._get_activation(config.activation)
    for i in range(config.repetitions):
      self._default_dict['name'] = f'{name}_{i}'
      layer = self._registry(config, self._default_dict)
      x = layer(x)
    self._default_dict['activation'] = self._activation
    self._default_dict['name'] = None
    return x

  @staticmethod
  def get_model_config(name):
    """Looks up a backbone table by (case-insensitive) name.

    Args:
      name: A `str` model id.

    Returns:
      A tuple `(specs, splits)` with a fresh list of `BlockConfig` built from
      the table rows and the table's 'splits' entry.

    Raises:
      KeyError: If `name` is not registered in `BACKBONES`.
    """
    name = name.lower()
    backbone = BACKBONES[name]['backbone']
    splits = BACKBONES[name]['splits']
    return build_block_specs(backbone), splits

  @property
  def model_id(self):
    """The model id this backbone was built from."""
    return self._model_name

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds a model from `get_config` output (`custom_objects` unused)."""
    return cls(**config)

  def get_config(self):
    """Returns the constructor arguments recorded for serialization."""
    # NOTE(review): width_scale, depth_scale, dilate, use_reorg_input,
    # csp_level_mod, use_separable_conv and input_specs are not part of this
    # dict, so `from_config` rebuilds the model with their default values.
    # Left unchanged because callers may compare this dict exactly.
    layer_config = {
        'model_id': self._model_name,
        'min_level': self._min_size,
        'max_level': self._max_size,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epislon,
        'use_sync_bn': self._use_sync_bn,
        'activation': self._activation,
    }
    return layer_config
vishnubanna's avatar
vishnubanna committed
671

672

Jaeyoun Kim's avatar
Jaeyoun Kim committed
673
@factory.register_backbone_builder('darknet')
def build_darknet(
    input_specs: tf.keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Builds darknet.

  Registered with the backbone factory under the name 'darknet'.

  Args:
    input_specs: The input spec handed to the `Darknet` constructor.
    backbone_config: The backbone config; the object returned by its `get()`
      supplies the architecture settings (model id, levels, scaling, ...).
    norm_activation_config: Supplies the activation and normalization settings.
    l2_regularizer: Optional regularizer used as the kernel regularizer.

  Returns:
    The constructed `Darknet` model.
  """
  darknet_cfg = backbone_config.get()
  norm_cfg = norm_activation_config
  darknet_kwargs = dict(
      model_id=darknet_cfg.model_id,
      min_level=darknet_cfg.min_level,
      max_level=darknet_cfg.max_level,
      input_specs=input_specs,
      dilate=darknet_cfg.dilate,
      width_scale=darknet_cfg.width_scale,
      depth_scale=darknet_cfg.depth_scale,
      use_reorg_input=darknet_cfg.use_reorg_input,
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      use_separable_conv=darknet_cfg.use_separable_conv,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer)
  return Darknet(**darknet_kwargs)