yolo.py 16.1 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Vishnu Banna's avatar
Vishnu Banna committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains classes used to train Yolo."""

import collections
18
19
20
21
from typing import Optional

from absl import logging
import tensorflow as tf
Vishnu Banna's avatar
Vishnu Banna committed
22

A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
23
from official.common import dataset_fn
Vishnu Banna's avatar
Vishnu Banna committed
24
from official.core import base_task
25
from official.core import config_definitions
Vishnu Banna's avatar
Vishnu Banna committed
26
27
28
from official.core import input_reader
from official.core import task_factory
from official.modeling import performance
Abdullah Rashwan's avatar
Abdullah Rashwan committed
29
30
31
32
33
34
35
36
37
from official.projects.yolo import optimization
from official.projects.yolo.configs import yolo as exp_cfg
from official.projects.yolo.dataloaders import tf_example_decoder
from official.projects.yolo.dataloaders import yolo_input
from official.projects.yolo.modeling import factory
from official.projects.yolo.ops import kmeans_anchors
from official.projects.yolo.ops import mosaic
from official.projects.yolo.ops import preprocessing_ops
from official.projects.yolo.tasks import task_utils
Abdullah Rashwan's avatar
Abdullah Rashwan committed
38
from official.vision.dataloaders import tfds_factory
Fan Yang's avatar
Fan Yang committed
39
from official.vision.dataloaders import tf_example_label_map_decoder
Abdullah Rashwan's avatar
Abdullah Rashwan committed
40
41
from official.vision.evaluation import coco_evaluator
from official.vision.ops import box_ops
Vishnu Banna's avatar
Vishnu Banna committed
42
43
44
45

OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig

46

Vishnu Banna's avatar
Vishnu Banna committed
47
48
49
50
51
52
53
54
55
@task_factory.register_task_cls(exp_cfg.YoloTask)
class YoloTask(base_task.Task):
  """A single-replica view of training procedure.

  YOLO task provides artifacts for training/evaluation procedures, including
  loading/iterating over Datasets, initializing the model, calculating the
  loss, post-processing, and customized metrics with reduction.
  """

  def __init__(self, params, logging_dir: Optional[str] = None):
    """Initializes task state and optionally runs anchor generation.

    Args:
      params: the task configuration (an `exp_cfg.YoloTask`).
      logging_dir: optional directory where task logs are written.
    """
    super().__init__(params, logging_dir)

    # Artifacts populated later by build_model/build_metrics.
    self.coco_metric = None
    self._loss_fn = None
    self._model = None
    self._coco_91_to_80 = False
    self._metrics = []

    # Globally set the random seed for reproducible preprocessing.
    preprocessing_ops.set_random_seeds(seed=params.seed)

    # Optionally derive anchor boxes from the training data via k-means.
    if self.task_config.model.anchor_boxes.generate_anchors:
      self.generate_anchors()

71
  def generate_anchors(self):
    """Generate anchor boxes for an arbitrary object detection dataset.

    Runs k-means over ground-truth boxes sampled from the training data,
    writes the resulting anchors to `anchors.txt` in the working directory,
    and installs them into the anchor-box config.

    Returns:
      The generated anchor boxes.
    """
    input_size = self.task_config.model.input_size
    anchor_cfg = self.task_config.model.anchor_boxes
    backbone = self.task_config.model.backbone.get()

    dataset = self.task_config.train_data
    decoder = self._get_data_decoder(dataset)

    # One anchor set per FPN level times the anchors used at each level.
    num_anchors = backbone.max_level - backbone.min_level + 1
    num_anchors *= anchor_cfg.anchors_per_scale

    # Temporarily force batch size 1 so the reader yields individual samples
    # for box collection; restored after reading.
    gbs = dataset.global_batch_size
    dataset.global_batch_size = 1
    box_reader = kmeans_anchors.BoxGenInputReader(
        dataset,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=decoder.decode)

    boxes = box_reader.read(
        k=num_anchors,
        anchors_per_scale=anchor_cfg.anchors_per_scale,
        image_resolution=input_size,
        scaling_mode=anchor_cfg.scaling_mode,
        box_generation_mode=anchor_cfg.box_generation_mode,
        num_samples=anchor_cfg.num_samples)

    dataset.global_batch_size = gbs

    # Persist the anchors so the user can copy them into the yaml config.
    # Fixed: the original message was garbled ("mack sure", "feild",
    # "you yaml") and the implicit string concatenation dropped a space.
    with open('anchors.txt', 'w') as f:
      f.write(f'input resolution: {input_size} \n boxes: \n {boxes}')
      logging.info('INFO: boxes will be saved to anchors.txt, make sure to '
                   'save them and update the boxes field in your yaml config '
                   'file.')

    anchor_cfg.set_boxes(boxes)
    return boxes

Vishnu Banna's avatar
Vishnu Banna committed
108
109
110
111
112
113
114
115
116
117
  def build_model(self):
    """Build an instance of Yolo.

    Returns:
      The constructed Keras model. The matching loss function and the model
      itself are also cached on the task for later use.
    """
    cfg = self.task_config.model

    # Weight decay is halved here — presumably to match the
    # 0.5 * wd * ||w||^2 convention of Keras' per-variable l2; confirm.
    l2_weight_decay = self.task_config.weight_decay / 2.0
    l2_regularizer = None
    if l2_weight_decay:
      l2_regularizer = tf.keras.regularizers.l2(l2_weight_decay)

    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + cfg.input_size.copy())

    model, losses = factory.build_yolo(input_specs, cfg, l2_regularizer)

    # Save for later usage within the task.
    self._loss_fn = losses
    self._model = model
    return model

Vishnu Banna's avatar
Vishnu Banna committed
126
  def _get_data_decoder(self, params):
    """Get a decoder object to decode the dataset.

    Args:
      params: the data config; either names a TFDS dataset or configures one
        of the TFRecord example decoders.

    Returns:
      A decoder exposing a `decode` method.

    Raises:
      ValueError: if the configured decoder type is unknown.
    """
    # TFDS datasets ship their own detection decoder.
    if params.tfds_name:
      return tfds_factory.get_detection_decoder(params.tfds_name)

    decoder_cfg = params.decoder.get()
    decoder_type = params.decoder.type

    if decoder_type == 'simple_decoder':
      # Remember the class-id remapping choice; build_metrics consults it.
      self._coco_91_to_80 = decoder_cfg.coco91_to_80
      return tf_example_decoder.TfExampleDecoder(
          coco91_to_80=decoder_cfg.coco91_to_80,
          regenerate_source_id=decoder_cfg.regenerate_source_id)

    if decoder_type == 'label_map_decoder':
      return tf_example_label_map_decoder.TfExampleDecoderLabelMap(
          label_map=decoder_cfg.label_map,
          regenerate_source_id=decoder_cfg.regenerate_source_id)

    raise ValueError(f'Unknown decoder type: {decoder_type}!')

  def build_inputs(self, params, input_context=None):
    """Build input dataset.

    Args:
      params: the data config describing the dataset split to load.
      input_context: optional `tf.distribute.InputContext` used to shard the
        dataset across replicas.

    Returns:
      A `tf.data.Dataset` yielding (image, label) pairs.
    """
    model = self.task_config.model

    # Get the anchor boxes dict keyed by the model's min and max levels.
    backbone = model.backbone.get()
    anchor_dict, level_limits = model.anchor_boxes.get(backbone.min_level,
                                                       backbone.max_level)

    # Propagate the task-level seed into the data pipeline.
    params.seed = self.task_config.seed
    # Set shared parameters between mosaic and yolo_input.
    base_config = dict(
        letter_box=params.parser.letter_box,
        aug_rand_translate=params.parser.aug_rand_translate,
        aug_rand_angle=params.parser.aug_rand_angle,
        aug_rand_perspective=params.parser.aug_rand_perspective,
        area_thresh=params.parser.area_thresh,
        random_flip=params.parser.random_flip,
        seed=params.seed,
    )

    # Get the decoder.
    decoder = self._get_data_decoder(params)

    # Init Mosaic (multi-image mixing augmentation applied at sample time).
    sample_fn = mosaic.Mosaic(
        output_size=model.input_size,
        mosaic_frequency=params.parser.mosaic.mosaic_frequency,
        mixup_frequency=params.parser.mosaic.mixup_frequency,
        jitter=params.parser.mosaic.jitter,
        mosaic_center=params.parser.mosaic.mosaic_center,
        mosaic_crop_mode=params.parser.mosaic.mosaic_crop_mode,
        aug_scale_min=params.parser.mosaic.aug_scale_min,
        aug_scale_max=params.parser.mosaic.aug_scale_max,
        **base_config)

    # Init Parser (per-example preprocessing and anchor matching).
    parser = yolo_input.Parser(
        output_size=model.input_size,
        anchors=anchor_dict,
        use_tie_breaker=params.parser.use_tie_breaker,
        jitter=params.parser.jitter,
        aug_scale_min=params.parser.aug_scale_min,
        aug_scale_max=params.parser.aug_scale_max,
        aug_rand_hue=params.parser.aug_rand_hue,
        aug_rand_saturation=params.parser.aug_rand_saturation,
        aug_rand_brightness=params.parser.aug_rand_brightness,
        max_num_instances=params.parser.max_num_instances,
        scale_xy=model.detection_generator.scale_xy.get(),
        expanded_strides=model.detection_generator.path_scales.get(),
        darknet=model.darknet_based_model,
        best_match_only=params.parser.best_match_only,
        anchor_t=params.parser.anchor_thresh,
        random_pad=params.parser.random_pad,
        level_limits=level_limits,
        dtype=params.dtype,
        **base_config)

    # Init the dataset reader.
    reader = input_reader.InputReader(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        sample_fn=sample_fn.mosaic_fn(is_training=params.is_training),
        parser_fn=parser.parse_fn(params.is_training))
    dataset = reader.read(input_context=input_context)
    return dataset

  def build_metrics(self, training=True):
    """Build detection metrics.

    Creates one `ListMetrics` tracker per backbone level plus a 'net' tracker
    for whole-network values; in evaluation mode a COCO evaluator is also
    constructed.

    Args:
      training: bool, whether metrics are built for training (True) or
        evaluation (False).

    Returns:
      A list of metric objects, one per level plus 'net'.
    """
    backbone = self.task_config.model.backbone.get()

    # Per-level loss metrics keyed by the stringified level index.
    metric_names = collections.defaultdict(list)
    for level in range(backbone.min_level, backbone.max_level + 1):
      metric_names[str(level)].extend(['loss', 'avg_iou', 'avg_obj'])

    # Whole-network metrics.
    metric_names['net'].extend(['box', 'class', 'conf'])

    # Idiom fix: iterate items() directly instead of
    # `for _, key in enumerate(metric_names.keys())`.
    metrics = [
        task_utils.ListMetrics(names, name=key)
        for key, names in metric_names.items()
    ]

    self._metrics = metrics
    if not training:
      annotation_file = self.task_config.annotation_file
      # When the decoder remaps COCO 91 -> 80 classes the annotation file no
      # longer matches the predicted ids, so it is dropped — presumably the
      # evaluator then builds ground truth from the dataset; confirm.
      if self._coco_91_to_80:
        annotation_file = None
      self.coco_metric = coco_evaluator.COCOEvaluator(
          annotation_file=annotation_file,
          include_mask=False,
          need_rescale_bboxes=False,
          per_category_metrics=self._task_config.per_category_metrics)

    return metrics

  def build_losses(self, outputs, labels, aux_losses=None):
    """Build YOLO losses.

    Args:
      outputs: raw model outputs to score.
      labels: ground-truth labels paired with `outputs`.
      aux_losses: unused; kept for base-task interface compatibility.

    Returns:
      Whatever the configured loss function produces for (labels, outputs).
    """
    loss_fn = self._loss_fn
    # Note the argument order: the loss callable takes labels first.
    return loss_fn(labels, outputs)

  def train_step(self, inputs, model, optimizer, metrics=None):
    """Train Step.

    Forward step and backwards propagate the model.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    image, label = inputs

    with tf.GradientTape(persistent=False) as tape:
      # Compute a prediction.
      y_pred = model(image, training=True)

      # Cast to float32 for gradient computation (outputs may be fp16 under
      # mixed precision).
      y_pred = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), y_pred)

      # Get the total loss.
      (scaled_loss, metric_loss,
       loss_metrics) = self.build_losses(y_pred['raw_output'], label)

      # Scale the loss for numerical stability under fp16; must happen inside
      # the tape so gradients are taken of the scaled value.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    # Compute the gradient.
    train_vars = model.trainable_variables
    gradients = tape.gradient(scaled_loss, train_vars)

    # Get unscaled gradients if we are using the loss scale optimizer on fp16.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      gradients = optimizer.get_unscaled_gradients(gradients)

    # Apply gradients to the model.
    optimizer.apply_gradients(zip(gradients, train_vars))
    # The logged loss is the unscaled metric loss, not the scaled one.
    logs = {self.loss: metric_loss}

    # Compute all metrics.
    if metrics:
      for m in metrics:
        m.update_state(loss_metrics[m.name])
        logs.update({m.name: m.result()})
    return logs

300
  def _reorg_boxes(self, boxes, info, num_detections):
    """Scale and clean boxes prior to evaluation.

    Maps predicted boxes from network-input coordinates back to
    original-image coordinates and marks invalid detection slots.

    Args:
      boxes: detection boxes; axis 1 is the detection slot dimension.
      info: per-image resize metadata; rows are read as original shape (0),
        network input shape (1), scale (2), and offset (3) — assumes the
        `image_info` layout produced by the dataloader, TODO confirm.
      num_detections: number of valid detections per image.

    Returns:
      Boxes in original-image coordinates; invalid slots are set to -1.
    """
    # Mask of valid detection slots, broadcast over the 4 box coordinates.
    mask = tf.sequence_mask(num_detections, maxlen=tf.shape(boxes)[1])
    mask = tf.cast(tf.expand_dims(mask, axis=-1), boxes.dtype)

    # Denormalize the boxes by the shape of the image.
    inshape = tf.expand_dims(info[:, 1, :], axis=1)
    ogshape = tf.expand_dims(info[:, 0, :], axis=1)
    scale = tf.expand_dims(info[:, 2, :], axis=1)
    offset = tf.expand_dims(info[:, 3, :], axis=1)

    boxes = box_ops.denormalize_boxes(boxes, inshape)
    boxes = box_ops.clip_boxes(boxes, inshape)
    # Undo the resize/letterbox: remove the pad offset, then rescale.
    # Offset/scale are tiled x2 to cover both (y, x) pairs of a box.
    boxes += tf.tile(offset, [1, 1, 2])
    boxes /= tf.tile(scale, [1, 1, 2])
    boxes = box_ops.clip_boxes(boxes, ogshape)

    # Zero out invalid boxes, then shift them to -1 so evaluators skip them.
    boxes *= mask
    boxes += (mask - 1)
    return boxes

  def validation_step(self, inputs, model, metrics=None):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    image, label = inputs

    # Step the model once.
    y_pred = model(image, training=False)
    y_pred = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), y_pred)
    (_, metric_loss, loss_metrics) = self.build_losses(y_pred['raw_output'],
                                                       label)
    logs = {self.loss: metric_loss}

    # Reorganize and rescale the boxes back to original-image coordinates.
    info = label['groundtruths']['image_info']
    boxes = self._reorg_boxes(y_pred['bbox'], info, y_pred['num_detections'])

    # Build the input for the COCO evaluation metric.
    coco_model_outputs = {
        'detection_boxes': boxes,
        'detection_scores': y_pred['confidence'],
        'detection_classes': y_pred['classes'],
        'num_detections': y_pred['num_detections'],
        'source_id': label['groundtruths']['source_id'],
        'image_info': label['groundtruths']['image_info']
    }

    # Compute all metrics; the COCO pair is stashed in the logs so
    # aggregate_logs can feed it to the evaluator later.
    if metrics:
      logs.update(
          {self.coco_metric.name: (label['groundtruths'], coco_model_outputs)})
      for m in metrics:
        m.update_state(loss_metrics[m.name])
        logs.update({m.name: m.result()})
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    """Accumulate one step's COCO inputs into the evaluator state.

    Args:
      state: the running aggregation state; None on the first call.
      step_outputs: per-step logs containing the (groundtruths, predictions)
        pair stored under the COCO metric's name.

    Returns:
      The COCO metric object, used as the aggregation state.
    """
    if not state:
      # First step of the eval loop: clear any stale evaluator state.
      self.coco_metric.reset_states()
      state = self.coco_metric

    groundtruths, predictions = step_outputs[self.coco_metric.name]
    self.coco_metric.update_state(groundtruths, predictions)
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Reduce logs and remove unneeded items. Update with COCO results.

    Args:
      aggregated_logs: unused; results come from the COCO evaluator state.
      global_step: unused; kept for base-task interface compatibility.

    Returns:
      The COCO evaluator's result dictionary.
    """
    return self.coco_metric.result()

  def initialize(self, model: tf.keras.Model):
    """Loads a pretrained checkpoint into `model`, when one is configured.

    Args:
      model: the model to restore; must expose `checkpoint_items`, and
        `backbone`/`decoder` attributes for partial restores.
    """
    ckpt_dir_or_file = self.task_config.init_checkpoint
    if not ckpt_dir_or_file:
      logging.info('Training from Scratch.')
      return

    # A directory means "use the newest checkpoint inside it".
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Select which sub-modules to restore.
    modules = self.task_config.init_checkpoint_modules
    if modules == 'all':
      ckpt_items = dict(model.checkpoint_items)
    else:
      ckpt_items = {}
      if 'backbone' in modules:
        ckpt_items['backbone'] = model.backbone
      if 'decoder' in modules:
        ckpt_items['decoder'] = model.decoder

    # Restoring checkpoint; partial restore is expected (e.g. no optimizer
    # slots), but every object we asked for must be matched.
    ckpt = tf.train.Checkpoint(**ckpt_items)
    status = ckpt.read(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def create_optimizer(self,
                       optimizer_config: OptimizationConfig,
                       runtime_config: Optional[RuntimeConfig] = None):
    """Creates a TF optimizer from configurations.

    Args:
      optimizer_config: the parameters of the Optimization settings.
      runtime_config: the parameters of the runtime.

    Returns:
      A tf.optimizers.Optimizer object.
    """
    opt_factory = optimization.YoloOptimizerFactory(optimizer_config)
    # pylint: disable=protected-access
    # Temporarily disable EMA inside the factory so the base optimizer is
    # built without it; EMA is re-applied explicitly via add_ema below.
    ema = opt_factory._use_ema
    opt_factory._use_ema = False

    opt_type = opt_factory._optimizer_type
    if opt_type == 'sgd_torch':
      optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
      # Bias variables get a dedicated learning-rate schedule.
      optimizer.set_bias_lr(
          opt_factory.get_bias_lr_schedule(self._task_config.smart_bias_lr))
      # Partition trainable variables into the optimizer's variable groups.
      optimizer.search_and_set_variable_groups(self._model.trainable_variables)
    else:
      optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
    opt_factory._use_ema = ema

    if ema:
      logging.info('EMA is enabled.')
    # add_ema is called unconditionally — presumably a no-op when the
    # restored _use_ema flag is False; confirm in the optimization factory.
    optimizer = opt_factory.add_ema(optimizer)
    # pylint: enable=protected-access

    # Wrap with loss scaling for float16 mixed precision when configured.
    if runtime_config and runtime_config.loss_scale:
      use_float16 = runtime_config.mixed_precision_dtype == 'float16'
      optimizer = performance.configure_optimizer(
          optimizer,
          use_float16=use_float16,
          loss_scale=runtime_config.loss_scale)

    return optimizer