# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

Vishnu Banna's avatar
Vishnu Banna committed
15
"""YOLO configuration definition."""
16
17
import dataclasses
import os
Vishnu Banna's avatar
Vishnu Banna committed
18
from typing import Any, List, Optional, Union
Vishnu Banna's avatar
Vishnu Banna committed
19

20
21
22
import numpy as np

from official.core import config_definitions as cfg
Vishnu Banna's avatar
Vishnu Banna committed
23
24
from official.core import exp_factory
from official.modeling import hyperparams
Abdullah Rashwan's avatar
Abdullah Rashwan committed
25
26
27
from official.projects.yolo import optimization
from official.projects.yolo.configs import backbones
from official.projects.yolo.configs import decoders
Abdullah Rashwan's avatar
Abdullah Rashwan committed
28
from official.vision.configs import common
29
30
31


# pytype: disable=annotation-type-mismatch
Vishnu Banna's avatar
Vishnu Banna committed
32
33
34

# Inclusive level range used to build the per-level default dictionaries
# for the FPN-style config fields in this file.
MIN_LEVEL = 1
MAX_LEVEL = 7
# Value assigned to `YoloTask.seed`.
GLOBAL_SEED = 1000
Vishnu Banna's avatar
Vishnu Banna committed
36

37

Vishnu Banna's avatar
Vishnu Banna committed
38
39
def _build_dict(min_level, max_level, value):
  vals = {str(key): value for key in range(min_level, max_level + 1)}
40
  vals['all'] = None
Vishnu Banna's avatar
Vishnu Banna committed
41
42
  return lambda: vals

43

Vishnu Banna's avatar
Vishnu Banna committed
44
45
46
def _build_path_scales(min_level, max_level):
  return lambda: {str(key): 2**key for key in range(min_level, max_level + 1)}

47

Vishnu Banna's avatar
Vishnu Banna committed
48
49
@dataclasses.dataclass
class FPNConfig(hyperparams.Config):
  """FPN config."""
  # When not None, `get()` copies this value onto every other key.
  all: Optional[Any] = None

  def get(self):
    """Allow for a key for each level or a single key for all the levels.

    Returns:
      The dict produced by `as_dict()`. If its `'all'` entry is not None,
      every other entry is overwritten with that value; the `'all'` entry
      itself is kept.
    """
    values = self.as_dict()
    shared = values.get('all')
    if shared is not None:
      # Only existing keys are reassigned, so the dict does not change size
      # while it is being iterated.
      for key in values:
        if key != 'all':
          values[key] = shared
    return values

62

Vishnu Banna's avatar
Vishnu Banna committed
63
64
65
66
# pylint: disable=missing-class-docstring
@dataclasses.dataclass
class TfExampleDecoder(hyperparams.Config):
  """Decoder options selected by `DataDecoder` under `simple_decoder`."""
  # The consumer of these flags is outside this file.
  regenerate_source_id: bool = False
  # Presumably remaps COCO's 91 category ids to the 80-class set — TODO
  # confirm against the decoder implementation.
  coco91_to_80: bool = True

Vishnu Banna's avatar
Vishnu Banna committed
69
70
71
72
73
74

@dataclasses.dataclass
class TfExampleDecoderLabelMap(hyperparams.Config):
  """Decoder options selected by `DataDecoder` under `label_map_decoder`."""
  # The consumer of these fields is outside this file.
  regenerate_source_id: bool = False
  # Empty by default; presumably a path to a label-map file — TODO confirm.
  label_map: str = ''

75

Vishnu Banna's avatar
Vishnu Banna committed
76
77
78
79
80
81
@dataclasses.dataclass
class DataDecoder(hyperparams.OneOfConfig):
  """One-of selector between the two decoder configs defined in this file."""
  # Name of the field below that is active; defaults to the simple decoder.
  type: Optional[str] = 'simple_decoder'
  # default_factory gives each DataDecoder its own sub-config. A config
  # instance used directly as the default is rejected by dataclasses on
  # Python 3.11+ because it is unhashable.
  simple_decoder: TfExampleDecoder = dataclasses.field(
      default_factory=TfExampleDecoder)
  label_map_decoder: TfExampleDecoderLabelMap = dataclasses.field(
      default_factory=TfExampleDecoderLabelMap)

82

Vishnu Banna's avatar
Vishnu Banna committed
83
84
85
86
87
88
89
90
91
92
@dataclasses.dataclass
class Mosaic(hyperparams.Config):
  """Mosaic/mixup options nested under `Parser.mosaic`."""
  # The code that consumes these values is outside this file. Both
  # frequencies default to 0.0.
  mosaic_frequency: float = 0.0
  mixup_frequency: float = 0.0
  mosaic_center: float = 0.2
  # None by default; the experiment factories in this file set it to 'crop'
  # or 'scale'.
  mosaic_crop_mode: Optional[str] = None
  aug_scale_min: float = 1.0
  aug_scale_max: float = 1.0
  jitter: float = 0.0

93

Vishnu Banna's avatar
Vishnu Banna committed
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
@dataclasses.dataclass
class Parser(hyperparams.Config):
  """Parser options nested under `DataConfig.parser`."""
  # The code that consumes these values is outside this file.
  max_num_instances: int = 200
  letter_box: Optional[bool] = True
  random_flip: bool = True
  # Boolean flag: the default is False and the experiment factory in this
  # file also passes False, so it is annotated as bool rather than float.
  random_pad: bool = False
  jitter: float = 0.0
  aug_scale_min: float = 1.0
  aug_scale_max: float = 1.0
  aug_rand_saturation: float = 0.0
  aug_rand_brightness: float = 0.0
  aug_rand_hue: float = 0.0
  aug_rand_angle: float = 0.0
  aug_rand_translate: float = 0.0
  aug_rand_perspective: float = 0.0
  use_tie_breaker: bool = True
  best_match_only: bool = False
  anchor_thresh: float = -0.01
  area_thresh: float = 0.1
  # default_factory gives each Parser its own Mosaic config. A config
  # instance used directly as the default is rejected by dataclasses on
  # Python 3.11+ because it is unhashable.
  mosaic: Mosaic = dataclasses.field(default_factory=Mosaic)

115

Vishnu Banna's avatar
Vishnu Banna committed
116
117
118
119
120
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input config for training."""
  # This field used to be declared twice in this class (64, then 1). The
  # later declaration won, so 1 is the effective default; it is kept once,
  # in the position of the first declaration, so field order is unchanged.
  global_batch_size: int = 1
  input_path: str = ''
  tfds_name: str = ''
  tfds_split: str = ''
  is_training: bool = True
  dtype: str = 'float16'
  # default_factory gives each DataConfig its own sub-configs. A config
  # instance used directly as the default is rejected by dataclasses on
  # Python 3.11+ because it is unhashable.
  decoder: DataDecoder = dataclasses.field(default_factory=DataDecoder)
  parser: Parser = dataclasses.field(default_factory=Parser)
  shuffle_buffer_size: int = 10000
  tfds_download: bool = True
  cache: bool = False
  drop_remainder: bool = True
  file_type: str = 'tfrecord'
133

Vishnu Banna's avatar
Vishnu Banna committed
134
135
136
137
138
139

@dataclasses.dataclass
class YoloHead(hyperparams.Config):
  """Parameterization for the YOLO Head."""
  # Defaults to True; both experiment factories in this file also pass True
  # explicitly. The code that consumes this flag is outside this file.
  smart_bias: bool = True

140

Vishnu Banna's avatar
Vishnu Banna committed
141
142
143
@dataclasses.dataclass
class YoloDetectionGenerator(hyperparams.Config):
  """Config nested under `Yolo.detection_generator`."""
  # The three FPNConfig fields default to dicts keyed by str(level) for
  # MIN_LEVEL..MAX_LEVEL; the `_build_dict` ones also carry an 'all' key.
  box_type: FPNConfig = dataclasses.field(
      default_factory=_build_dict(MIN_LEVEL, MAX_LEVEL, 'original'))
  scale_xy: FPNConfig = dataclasses.field(
      default_factory=_build_dict(MIN_LEVEL, MAX_LEVEL, 1.0))
  # Default maps each level key to 2**level.
  path_scales: FPNConfig = dataclasses.field(
      default_factory=_build_path_scales(MIN_LEVEL, MAX_LEVEL))
  nms_type: str = 'greedy'
  iou_thresh: float = 0.001
  nms_thresh: float = 0.6
  max_boxes: int = 200
  pre_nms_points: int = 5000

155

Vishnu Banna's avatar
Vishnu Banna committed
156
157
158
159
160
161
162
163
164
165
166
167
@dataclasses.dataclass
class YoloLoss(hyperparams.Config):
  """Config nested under `Yolo.loss`."""
  # Every FPNConfig field defaults to a dict with one entry per level in
  # MIN_LEVEL..MAX_LEVEL (keyed by str(level)) plus an 'all' entry of None.
  ignore_thresh: FPNConfig = dataclasses.field(
      default_factory=_build_dict(MIN_LEVEL, MAX_LEVEL, 0.0))
  truth_thresh: FPNConfig = dataclasses.field(
      default_factory=_build_dict(MIN_LEVEL, MAX_LEVEL, 1.0))
  box_loss_type: FPNConfig = dataclasses.field(
      default_factory=_build_dict(MIN_LEVEL, MAX_LEVEL, 'ciou'))
  iou_normalizer: FPNConfig = dataclasses.field(
      default_factory=_build_dict(MIN_LEVEL, MAX_LEVEL, 1.0))
  cls_normalizer: FPNConfig = dataclasses.field(
      default_factory=_build_dict(MIN_LEVEL, MAX_LEVEL, 1.0))
  object_normalizer: FPNConfig = dataclasses.field(
      default_factory=_build_dict(MIN_LEVEL, MAX_LEVEL, 1.0))
  max_delta: FPNConfig = dataclasses.field(
      default_factory=_build_dict(MIN_LEVEL, MAX_LEVEL, np.inf))
  objectness_smooth: FPNConfig = dataclasses.field(
      default_factory=_build_dict(MIN_LEVEL, MAX_LEVEL, 0.0))
  label_smoothing: float = 0.0
  use_scaled_loss: bool = True
  update_on_repeat: bool = True


179
@dataclasses.dataclass
class Box(hyperparams.Config):
  """A single anchor box, held as a list of ints."""
  # `default_factory=list` gives each Box its own empty list. With
  # `default=list` the default value was the `list` type itself.
  box: List[int] = dataclasses.field(default_factory=list)

183

Vishnu Banna's avatar
Vishnu Banna committed
184
185
@dataclasses.dataclass
class AnchorBoxes(hyperparams.Config):
  """Anchor box settings and their distribution across output levels."""
  boxes: Optional[List[Box]] = None
  # When set, `get()` ignores `boxes` and uses one placeholder box per level.
  level_limits: Optional[List[int]] = None
  anchors_per_scale: int = 3

  # The code that consumes the four fields below is outside this file.
  generate_anchors: bool = False
  scaling_mode: str = 'sqrt'
  box_generation_mode: str = 'per_level'
  num_samples: int = 1024

  def get(self, min_level, max_level):
    """Distribute them in order to each level.

    Args:
      min_level: `int` the lowest output level.
      max_level: `int` the highest output level.

    Returns:
      anchors_per_level: A dict mapping `str(level)` to the list of anchor
        boxes assigned to that level, `anchors_per_scale` boxes per level.
      self.level_limits: A `List[int]` of the box size limits to link to each
        level under anchor free conditions.

    Raises:
      ValueError: if both `level_limits` and `boxes` are None.
    """
    if self.level_limits is None:
      if self.boxes is None:
        # Fail with a clear message instead of "NoneType is not iterable".
        raise ValueError(
            'AnchorBoxes.get() requires `boxes` to be set when '
            '`level_limits` is None.')
      boxes = [box.box for box in self.boxes]
    else:
      # One placeholder box per level. Note this overwrites
      # `anchors_per_scale` on the config itself.
      boxes = [[1.0, 1.0]] * ((max_level - min_level) + 1)
      self.anchors_per_scale = 1

    anchors_per_level = {}
    start = 0
    for i in range(min_level, max_level + 1):
      # Consecutive slices of `boxes`, lowest level first.
      anchors_per_level[str(i)] = boxes[start:start + self.anchors_per_scale]
      start += self.anchors_per_scale
    return anchors_per_level, self.level_limits

  def set_boxes(self, boxes):
    """Replaces `self.boxes` with one `Box` per entry of `boxes`."""
    self.boxes = [Box(box=box) for box in boxes]
Vishnu Banna's avatar
kmeans  
Vishnu Banna committed
221

222

Vishnu Banna's avatar
Vishnu Banna committed
223
224
225
226
227
228
229
@dataclasses.dataclass
class Yolo(hyperparams.Config):
  """Model config nested under `YoloTask.model`."""
  # Every sub-config default below is built by a default_factory so each
  # Yolo instance owns fresh objects. A config instance used directly as the
  # default is rejected by dataclasses on Python 3.11+ because it is
  # unhashable.
  input_size: Optional[List[int]] = dataclasses.field(
      default_factory=lambda: [512, 512, 3])
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='darknet',
          darknet=backbones.Darknet(model_id='cspdarknet53')))
  decoder: decoders.Decoder = dataclasses.field(
      default_factory=lambda: decoders.Decoder(
          type='yolo_decoder',
          yolo_decoder=decoders.YoloDecoder(version='v4', type='regular')))
  head: YoloHead = dataclasses.field(default_factory=YoloHead)
  detection_generator: YoloDetectionGenerator = dataclasses.field(
      default_factory=YoloDetectionGenerator)
  loss: YoloLoss = dataclasses.field(default_factory=YoloLoss)
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=lambda: common.NormActivation(
          activation='mish',
          use_sync_bn=True,
          norm_momentum=0.99,
          norm_epsilon=0.001))
  num_classes: int = 80
  anchor_boxes: AnchorBoxes = dataclasses.field(default_factory=AnchorBoxes)
  darknet_based_model: bool = False

244

Vishnu Banna's avatar
Vishnu Banna committed
245
246
247
248
249
250
251
252
253
254
255
256
257
@dataclasses.dataclass
class YoloTask(cfg.TaskConfig):
  """Task config used by the experiment factories in this file."""
  per_category_metrics: bool = False
  smart_bias_lr: float = 0.0
  # Sub-config defaults are built by default_factory so each task owns fresh
  # objects. A config instance used directly as the default is rejected by
  # dataclasses on Python 3.11+ because it is unhashable.
  model: Yolo = dataclasses.field(default_factory=Yolo)
  train_data: DataConfig = dataclasses.field(
      default_factory=lambda: DataConfig(is_training=True))
  validation_data: DataConfig = dataclasses.field(
      default_factory=lambda: DataConfig(is_training=False))
  weight_decay: float = 0.0
  annotation_file: Optional[str] = None
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: Union[
      str, List[str]] = 'all'  # all, backbone, and/or decoder
  gradient_clip_norm: float = 0.0
  # NOTE(review): left without an annotation as in the original, so this is
  # a plain class attribute and not a dataclass field — confirm that is
  # intended before annotating it.
  seed = GLOBAL_SEED
Vishnu Banna's avatar
Vishnu Banna committed
259
260
261
262
263


# Base directory joined with 'train*' / 'val*' to form the input patterns
# used by the experiment factories below.
COCO_INPUT_PATH_BASE = 'coco'
# Example counts used below to derive steps per epoch and validation steps.
COCO_TRAIN_EXAMPLES = 118287
COCO_VAL_EXAMPLES = 5000
Vishnu Banna's avatar
Vishnu Banna committed
264

Vishnu Banna's avatar
Vishnu Banna committed
265
266
267
268
269
270
271
272
273
274
275

@exp_factory.register_config_factory('yolo')
def yolo() -> cfg.ExperimentConfig:
  """Builds the general YOLO experiment config.

  Returns:
    An `ExperimentConfig` holding a default `YoloTask` and two restrictions
    requiring `is_training` to be set on the train and validation data.
  """
  restrictions = [
      'task.train_data.is_training != None',
      'task.validation_data.is_training != None',
  ]
  return cfg.ExperimentConfig(task=YoloTask(), restrictions=restrictions)

276

Vishnu Banna's avatar
Vishnu Banna committed
277
278
@exp_factory.register_config_factory('yolo_darknet')
def yolo_darknet() -> cfg.ExperimentConfig:
  """COCO object detection with YOLOv3 and v4.

  Returns:
    An `ExperimentConfig` with the runtime, task, trainer and restriction
    settings for the 'yolo_darknet' experiment.
  """
  train_batch_size = 256
  eval_batch_size = 8
  train_epochs = 300
  steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
  validation_interval = 5
  # Used by both the optimizer and the learning-rate warmup; kept in one
  # place so the two cannot drift apart.
  warmup_steps = 1000

  max_num_instances = 200
  config = cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
      task=YoloTask(
          smart_bias_lr=0.1,
          init_checkpoint='',
          init_checkpoint_modules='backbone',
          annotation_file=None,
          weight_decay=0.0,
          model=Yolo(
              darknet_based_model=True,
              norm_activation=common.NormActivation(use_sync_bn=True),
              head=YoloHead(smart_bias=True),
              loss=YoloLoss(use_scaled_loss=False, update_on_repeat=True),
              anchor_boxes=AnchorBoxes(
                  anchors_per_scale=3,
                  boxes=[
                      Box(box=[12, 16]),
                      Box(box=[19, 36]),
                      Box(box=[40, 28]),
                      Box(box=[36, 75]),
                      Box(box=[76, 55]),
                      Box(box=[72, 146]),
                      Box(box=[142, 110]),
                      Box(box=[192, 243]),
                      Box(box=[459, 401])
                  ])),
          train_data=DataConfig(
              input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size,
              dtype='float32',
              parser=Parser(
                  letter_box=False,
                  aug_rand_saturation=1.5,
                  aug_rand_brightness=1.5,
                  aug_rand_hue=0.1,
                  use_tie_breaker=True,
                  best_match_only=False,
                  anchor_thresh=0.4,
                  area_thresh=0.1,
                  max_num_instances=max_num_instances,
                  mosaic=Mosaic(
                      mosaic_frequency=0.75,
                      mixup_frequency=0.0,
                      mosaic_crop_mode='crop',
                      mosaic_center=0.2))),
          validation_data=DataConfig(
              input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
              is_training=False,
              global_batch_size=eval_batch_size,
              drop_remainder=True,
              dtype='float32',
              parser=Parser(
                  letter_box=False,
                  use_tie_breaker=True,
                  best_match_only=False,
                  anchor_thresh=0.4,
                  area_thresh=0.1,
                  max_num_instances=max_num_instances,
              ))),
      trainer=cfg.TrainerConfig(
          train_steps=train_epochs * steps_per_epoch,
          validation_steps=COCO_VAL_EXAMPLES // eval_batch_size,
          validation_interval=validation_interval * steps_per_epoch,
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'ema': {
                  'average_decay': 0.9998,
                  'trainable_weights_only': False,
                  'dynamic_decay': True,
              },
              'optimizer': {
                  'type': 'sgd_torch',
                  'sgd_torch': {
                      'momentum': 0.949,
                      'momentum_start': 0.949,
                      'nesterov': True,
                      'warmup_steps': warmup_steps,
                      'weight_decay': 0.0005,
                  }
              },
              'learning_rate': {
                  'type': 'stepwise',
                  'stepwise': {
                      # Single boundary at epoch 240 of 300.
                      'boundaries': [
                          240 * steps_per_epoch
                      ],
                      # Both values are multiplied by train_batch_size / 64.
                      'values': [
                          0.00131 * train_batch_size / 64.0,
                          0.000131 * train_batch_size / 64.0,
                      ]
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': warmup_steps,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])

  return config


@exp_factory.register_config_factory('scaled_yolo')
def scaled_yolo() -> cfg.ExperimentConfig:
  """COCO object detection with YOLOv4-csp and v4.

  Returns:
    An `ExperimentConfig` with the runtime, task, trainer and restriction
    settings for the 'scaled_yolo' experiment.
  """
  train_batch_size = 256
  eval_batch_size = 8
  train_epochs = 300
  warmup_epochs = 3

  validation_interval = 5
  steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
  # Shared by the trainer and the schedules below so they stay in agreement.
  train_steps = train_epochs * steps_per_epoch
  warmup_steps = steps_per_epoch * warmup_epochs

  max_num_instances = 300

  config = cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
      task=YoloTask(
          smart_bias_lr=0.1,
          init_checkpoint_modules='',
          annotation_file=None,
          weight_decay=0.0,
          model=Yolo(
              darknet_based_model=False,
              norm_activation=common.NormActivation(
                  activation='mish',
                  use_sync_bn=True,
                  norm_epsilon=0.001,
                  norm_momentum=0.97),
              head=YoloHead(smart_bias=True),
              loss=YoloLoss(use_scaled_loss=True),
              anchor_boxes=AnchorBoxes(
                  anchors_per_scale=3,
                  boxes=[
                      Box(box=[12, 16]),
                      Box(box=[19, 36]),
                      Box(box=[40, 28]),
                      Box(box=[36, 75]),
                      Box(box=[76, 55]),
                      Box(box=[72, 146]),
                      Box(box=[142, 110]),
                      Box(box=[192, 243]),
                      Box(box=[459, 401])
                  ])),
          train_data=DataConfig(
              input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size,
              dtype='float32',
              parser=Parser(
                  aug_rand_saturation=0.7,
                  aug_rand_brightness=0.4,
                  aug_rand_hue=0.015,
                  letter_box=True,
                  use_tie_breaker=True,
                  best_match_only=True,
                  anchor_thresh=4.0,
                  random_pad=False,
                  area_thresh=0.1,
                  max_num_instances=max_num_instances,
                  mosaic=Mosaic(
                      mosaic_crop_mode='scale',
                      mosaic_frequency=1.0,
                      mixup_frequency=0.0,
                  ))),
          validation_data=DataConfig(
              input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
              is_training=False,
              global_batch_size=eval_batch_size,
              drop_remainder=True,
              dtype='float32',
              parser=Parser(
                  letter_box=True,
                  use_tie_breaker=True,
                  best_match_only=True,
                  anchor_thresh=4.0,
                  area_thresh=0.1,
                  max_num_instances=max_num_instances,
              ))),
      trainer=cfg.TrainerConfig(
          train_steps=train_steps,
          validation_steps=COCO_VAL_EXAMPLES // eval_batch_size,
          validation_interval=validation_interval * steps_per_epoch,
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=5 * steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'ema': {
                  'average_decay': 0.9999,
                  'trainable_weights_only': False,
                  'dynamic_decay': True,
              },
              'optimizer': {
                  'type': 'sgd_torch',
                  'sgd_torch': {
                      'momentum': 0.937,
                      'momentum_start': 0.8,
                      'nesterov': True,
                      'warmup_steps': warmup_steps,
                      'weight_decay': 0.0005,
                  }
              },
              'learning_rate': {
                  'type': 'cosine',
                  'cosine': {
                      'initial_learning_rate': 0.01,
                      'alpha': 0.2,
                      'decay_steps': train_steps,
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': warmup_steps,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])

  return config