yolo_loss.py 31 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Vishnu Banna's avatar
Vishnu Banna committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

Vishnu Banna's avatar
Vishnu Banna committed
15
"""Yolo Loss function."""
Vishnu Banna's avatar
Vishnu Banna committed
16
import abc
17
18
19
20
import collections
import functools

import tensorflow as tf
Vishnu Banna's avatar
Vishnu Banna committed
21

Abdullah Rashwan's avatar
Abdullah Rashwan committed
22
23
24
from official.projects.yolo.ops import box_ops
from official.projects.yolo.ops import loss_utils
from official.projects.yolo.ops import math_ops
Vishnu Banna's avatar
Vishnu Banna committed
25
26
27


class YoloLossBase(object, metaclass=abc.ABCMeta):
28
29
30
  """Parameters for the YOLO loss functions used at each detection generator.

  This base class implements the base functionality required to implement a Yolo
Vishnu Banna's avatar
Vishnu Banna committed
31
32
  Loss function.
  """
Vishnu Banna's avatar
Vishnu Banna committed
33
34
35
36
37
38
39

  def __init__(self,
               classes,
               anchors,
               path_stride=1,
               ignore_thresh=0.7,
               truth_thresh=1.0,
               loss_type='ciou',
               iou_normalizer=1.0,
               cls_normalizer=1.0,
               object_normalizer=1.0,
               label_smoothing=0.0,
               objectness_smooth=True,
               update_on_repeat=False,
               box_type='original',
               scale_x_y=1.0,
               max_delta=10):
    """Stores the loss settings for one path and builds the box decoder.

    Args:
      classes: `int` for the number of classes.
      anchors: `List[List[int]]` for the anchor boxes that are used in the model
        at all levels. For anchor free prediction set the anchor list to be the
        same as the image resolution.
      path_stride: `int` for how much to scale this level to get the original
        input shape.
      ignore_thresh: `float` for the IOU value over which the loss is not
        propagated, and a detection is assumed to have been made.
      truth_thresh: `float` for the IOU value over which the loss is propagated
        despite a detection being made.
      loss_type: `str` for the type of iou loss to use. `box_loss` dispatches
        on 'giou' and 'ciou'; any other value is handled as plain iou.
      iou_normalizer: `float` for how much to scale the loss on the IOU or the
        boxes.
      cls_normalizer: `float` for how much to scale the loss on the classes.
      object_normalizer: `float` for how much to scale loss on the detection
        map.
      label_smoothing: `float` for how much to smooth the loss on the classes.
        Stored as a float32 tensor.
      objectness_smooth: `float` (or `bool`) for how much to smooth the loss on
        the detection map. The value is converted with `float()`, so `True`
        becomes 1.0 and `False` becomes 0.0.
      update_on_repeat: `bool` for whether to replace with the newest or the
        best value when an index is consumed by multiple objects.
      box_type: `str` naming which box scaling type to use; forwarded to the
        box decoder.
      scale_x_y: `float` indicating how far each pixel can see outside of its
        containment of 1.0. A value of 1.2 indicates there is a 20% extended
        radius around each pixel that this specific pixel can predict a center
        at. Forwarded to the box decoder as `scale_xy`.
      max_delta: gradient clipping to apply to the box loss; forwarded to the
        box decoder.
    """
    # Model description.
    self._classes = classes
    self._anchors = anchors
    self._num = tf.cast(len(anchors), dtype=tf.int32)
    self._path_stride = path_stride
    self._box_type = box_type
    self._loss_type = loss_type

    # Thresholds used when matching predictions against ground truths.
    self._ignore_thresh = ignore_thresh
    self._truth_thresh = truth_thresh

    # Weights applied to the individual loss terms.
    self._iou_normalizer = iou_normalizer
    self._cls_normalizer = cls_normalizer
    self._object_normalizer = object_normalizer

    # Smoothing and ground-truth collision handling.
    self._label_smoothing = tf.cast(label_smoothing, tf.float32)
    self._objectness_smooth = float(objectness_smooth)
    self._update_on_repeat = update_on_repeat

    # Box decoding configuration, bound once so that the loss only has to
    # supply the per-call tensors.
    self._scale_x_y = scale_x_y
    self._max_delta = max_delta
    self._decode_boxes = functools.partial(
        loss_utils.get_predicted_box,
        stride=path_stride,
        scale_xy=scale_x_y,
        box_type=box_type,
        max_delta=max_delta)

    # Placeholder search that reports "no result"; it must be in place before
    # the subclass hook below runs, because the hook may replace it.
    self._search_pairs = lambda *_: (None,) * 4
    self._build_per_path_attributes()

  def box_loss(self, true_box, pred_box, darknet=False):
    """Computes the IoU-based box loss for the configured loss type.

    Args:
      true_box: ground truth boxes, passed to the `box_ops` iou function.
      pred_box: predicted boxes, passed to the `box_ops` iou function.
      darknet: `bool` forwarded to `box_ops.compute_ciou`; ignored by the other
        loss types.

    Returns:
      iou: the first value returned by the selected iou function (or the plain
        iou itself for the fallback).
      liou: the iou variant that the loss is built from.
      loss_box: `1 - liou`.
    """
    loss_type = self._loss_type
    if loss_type == 'ciou':
      iou, liou = box_ops.compute_ciou(true_box, pred_box, darknet=darknet)
    elif loss_type == 'giou':
      iou, liou = box_ops.compute_giou(true_box, pred_box)
    else:
      # Every other loss type (including 'iou' and 'diou') uses the plain iou
      # both as the metric and as the loss term.
      iou = box_ops.compute_iou(true_box, pred_box)
      liou = iou
    return iou, liou, 1 - liou

  def _tiled_global_box_search(self,
                               pred_boxes,
                               pred_classes,
                               boxes,
                               classes,
                               true_conf,
                               smoothed,
                               scale=None):
    """Matches predictions against all ground truths to build the object mask.

    Args:
      pred_boxes: decoded predicted boxes handed to the pair-wise search.
      pred_classes: predicted class scores handed to the pair-wise search.
      boxes: ground truth boxes in yxyx format; converted to xcycwh here.
      classes: ground truth classes handed to the pair-wise search.
      true_conf: the ground truth confidence map.
      smoothed: `bool`, when True the confidence map is overwritten with the
        smoothed max iou wherever the max iou exceeds the ignore threshold;
        when False the map is left as is and only the mask is computed.
      scale: optional multiplier applied to the converted ground truth boxes.

    Returns:
      true_conf: the (possibly updated) confidence map.
      obj_mask: mask to apply to the confidence loss.
    """
    gt_boxes = box_ops.yxyx_to_xcycwh(boxes)
    if scale is not None:
      gt_boxes = gt_boxes * tf.cast(tf.stop_gradient(scale), gt_boxes.dtype)

    # Search all predictions against the ground truths to find the best
    # matching box for each pixel.
    _, _, best_iou, _ = self._search_pairs(pred_boxes, pred_classes, gt_boxes,
                                           classes)

    # The default search returns None, in which case nothing is masked.
    if best_iou is None:
      return true_conf, tf.ones_like(true_conf)

    if smoothed:
      # Replace pixels in the true confidence map with the max iou predicted
      # within that cell; every cell contributes to the loss.
      obj_mask = tf.ones_like(true_conf)
      above_thresh = best_iou > self._ignore_thresh
      soft_iou = ((1 - self._objectness_smooth) +
                  self._objectness_smooth * best_iou)
      soft_iou = tf.where(best_iou > 0, soft_iou, tf.zeros_like(soft_iou))
      true_conf = tf.where(above_thresh, soft_iou, true_conf)
    else:
      # Ignore all pixels where a box was not supposed to be predicted but a
      # high iou box was predicted anyway.
      below_thresh = tf.cast(best_iou < self._ignore_thresh, pred_boxes.dtype)
      obj_mask = true_conf + (1 - true_conf) * below_thresh

    # Neither output should carry gradients back into the search.
    return tf.stop_gradient(true_conf), tf.stop_gradient(obj_mask)

  def __call__(self, true_counts, inds, y_true, boxes, classes, y_pred):
    """Computes the loss and a set of metrics for one FPN level.

    Args:
      true_counts: `Tensor` of shape [batchsize, height, width, num_anchors]
        representing how many boxes are in a given pixel [j, i] in the output
        map.
      inds: `Tensor` of shape [batchsize, None, 3] indicating the location [j,
        i] that a given box is associated with in the FPN prediction map.
      y_true: `Tensor` of shape [batchsize, None, 8] indicating the actual box
        associated with each index in the inds tensor list.
      boxes: `Tensor` of shape [batchsize, None, 4] indicating the original
        ground truth boxes for each image as they came from the decoder used for
        bounding box search.
      classes: `Tensor` of shape [batchsize, None, 1] indicating the original
        ground truth classes for each image as they came from the decoder used
        for bounding box search.
      y_pred: `Tensor` of shape [batchsize, height, width, output_depth] holding
        the models output at a specific FPN level.

    Returns:
      loss: `float` for the actual loss.
      box_loss: `float` loss on the boxes used for metrics.
      conf_loss: `float` loss on the confidence used for metrics.
      class_loss: `float` loss on the classes used for metrics.
      mean_loss: the fifth value produced by `_compute_loss`, passed through
        unchanged.
      avg_iou: `float` metric for the average iou between predictions and ground
        truth.
      avg_obj: `float` metric for the average confidence of the model for
        predictions.
    """
    (total, box_term, conf_term, class_term, mean_term, raw_iou, conf_logits,
     index_mask, cell_mask) = self._compute_loss(true_counts, inds, y_true,
                                                 boxes, classes, y_pred)

    # The metrics are derived here from tensors the loss already produced, and
    # are detached so that they never contribute gradients.
    confidence = tf.stop_gradient(tf.sigmoid(conf_logits))
    detached_iou = tf.stop_gradient(raw_iou)

    masked_iou = loss_utils.apply_mask(
        tf.squeeze(index_mask, axis=-1), detached_iou)
    avg_iou = loss_utils.average_iou(masked_iou)

    masked_conf = tf.squeeze(confidence, axis=-1) * cell_mask
    avg_obj = loss_utils.average_iou(masked_conf)

    return (total, box_term, conf_term, class_term, mean_term,
            tf.stop_gradient(avg_iou), tf.stop_gradient(avg_obj))

  @abc.abstractmethod
  def _build_per_path_attributes(self):
Vishnu Banna's avatar
Vishnu Banna committed
216
    """Additional initialization required for each YOLO loss version."""
Vishnu Banna's avatar
Vishnu Banna committed
217
218
219
    ...

  @abc.abstractmethod
Vishnu Banna's avatar
Vishnu Banna committed
220
  def _compute_loss(self, true_counts, inds, y_true, boxes, classes, y_pred):
Vishnu Banna's avatar
Vishnu Banna committed
221
222
223
    """The actual logic to apply to the raw model for optimization."""
    ...

224
225
  def post_path_aggregation(self, loss, box_loss, conf_loss, class_loss,
                            ground_truths, predictions):  # pylint:disable=unused-argument
    """Post-processes the loss after it was aggregated over all FPN levels.

    Some losses need to be rescaled once every path has contributed. The
    individual losses and the raw tensors are accepted so that subclasses can
    aggregate across paths; this default implementation ignores them and
    passes the loss through with no alterations.

    Args:
      loss: `tf.float` scalar for the actual loss.
      box_loss: `tf.float` for the loss on the boxes only.
      conf_loss: `tf.float` for the loss on the confidences only.
      class_loss: `tf.float` for the loss on the classes only.
      ground_truths: `Dict` holding all the ground truth tensors.
      predictions: `Dict` holding all the predicted values.

    Returns:
      loss: the unchanged input loss.
      scale: a tensor of ones shaped like `loss`, as no scaling was applied.
    """
    del box_loss, conf_loss, class_loss, ground_truths, predictions
    unit_scale = tf.ones_like(loss)
    return loss, unit_scale
Vishnu Banna's avatar
Vishnu Banna committed
252
253
254
255
256
257
258
259
260

  @abc.abstractmethod
  def cross_replica_aggregation(self, loss, num_replicas_in_sync):
    """Defines how the loss is combined across replicas for a loss type."""
    ...


@tf.custom_gradient
def grad_sigmoid(values):
  """Identity op whose gradient is scaled by the sigmoid derivative.

  The forward pass returns the input untouched. On the backward pass the
  incoming gradient is multiplied by `sigmoid(values) * (1 - sigmoid(values))`.
  The Darknet loss uses this with the scaled box type so that the propagated
  gradient matches that of the Darknet Yolov4 model.

  Args:
    values: A tensor of any shape.

  Returns:
    values: The unaltered input tensor.
    delta: The gradient function that adds the sigmoid step to the
      backpropagation.
  """

  def _sigmoid_scaled_grad(upstream):
    activation = tf.math.sigmoid(values)
    return upstream * activation * (1 - activation)

  return values, _sigmoid_scaled_grad


class DarknetLoss(YoloLossBase):
Vishnu Banna's avatar
Vishnu Banna committed
285
  """This class implements the full logic for the standard Yolo models."""
Vishnu Banna's avatar
Vishnu Banna committed
286
287

  def _build_per_path_attributes(self):
    """Creates the grid generator and, if enabled, the pair-wise search.

    The grid generator is built from the anchors and the path stride. The
    pair-wise search replaces the default no-op search only when the ignore
    threshold is positive.
    """
    self._anchor_generator = loss_utils.GridGenerator(
        anchors=self._anchors, scale_anchors=self._path_stride)

    search_enabled = self._ignore_thresh > 0.0
    if search_enabled:
      self._search_pairs = loss_utils.PairWiseSearch(
          iou_type='iou', any_match=True, min_conf=0.25)

Vishnu Banna's avatar
Vishnu Banna committed
302
  def _compute_loss(self, true_counts, inds, y_true, boxes, classes, y_pred):
    """Per FPN path loss logic used for Yolov3, Yolov4, and Yolo-Tiny.

    Args:
      true_counts: `Tensor` whose first four dimensions give the batch size,
        the two grid dimensions and the number of anchors; clipped to [0, 1]
        to form the ground truth confidence map.
      inds: `Tensor` of indexes used to gather per-object values out of the
        prediction grid (`tf.gather_nd` with `batch_dims=1`).
      y_true: `Tensor` whose last axis is split into 4 box values, 1 index
        mask value and 1 class value.
      boxes: `Tensor` of ground truth boxes used for the global box search.
      classes: `Tensor` of ground truth classes used for the global box search.
      y_pred: `Tensor` holding the raw model output for this path.

    Returns:
      A tuple of (loss, box_loss, conf_loss, class_loss, loss, iou, pred_conf,
      ind_mask, grid_mask). The first and the fifth entries are the same
      batch-averaged total loss.
    """
    if self._box_type == 'scaled':
      # The Darknet model propagates a sigmoid once in back prop, so that
      # behaviour is replicated here.
      y_pred = grad_sigmoid(y_pred)

    # Generate and store constants and format output.
    # NOTE(review): shape[1] is bound to `width` and shape[2] to `height`; the
    # two names are used consistently with each other below.
    shape = tf.shape(true_counts)
    batch_size, width, height, num = shape[0], shape[1], shape[2], shape[3]
    fwidth = tf.cast(width, tf.float32)
    fheight = tf.cast(height, tf.float32)
    grid_points, anchor_grid = self._anchor_generator(
        width, height, batch_size, dtype=tf.float32)

    # Cast all input components to float32 and stop gradient to save memory.
    boxes = tf.stop_gradient(tf.cast(boxes, tf.float32))
    classes = tf.stop_gradient(tf.cast(classes, tf.float32))
    y_true = tf.stop_gradient(tf.cast(y_true, tf.float32))
    true_counts = tf.stop_gradient(tf.cast(true_counts, tf.float32))
    true_conf = tf.stop_gradient(tf.clip_by_value(true_counts, 0.0, 1.0))
    grid_points = tf.stop_gradient(grid_points)
    anchor_grid = tf.stop_gradient(anchor_grid)

    # Split all the ground truths to use as separate items in loss computation.
    (true_box, ind_mask, true_class) = tf.split(y_true, [4, 1, 1], axis=-1)
    true_conf = tf.squeeze(true_conf, axis=-1)
    true_class = tf.squeeze(true_class, axis=-1)
    # Keep the unmodified confidence map; `true_conf` may be replaced by the
    # box search below.
    grid_mask = true_conf

    # Split all predictions into box, confidence and class parts.
    y_pred = tf.cast(
        tf.reshape(y_pred, [batch_size, width, height, num, -1]), tf.float32)
    pred_box, pred_conf, pred_class = tf.split(y_pred, [4, 1, -1], axis=-1)

    # Decode the boxes to be used for loss compute.
    _, _, pred_box = self._decode_boxes(
        fwidth, fheight, pred_box, anchor_grid, grid_points, darknet=True)

    # If the ignore threshold is enabled, search all boxes and ignore all
    # IOU values larger than the ignore threshold that are not in the
    # noted ground truth list.
    if self._ignore_thresh != 0.0:
      (true_conf, obj_mask) = self._tiled_global_box_search(
          pred_box,
          tf.stop_gradient(tf.sigmoid(pred_class)),
          boxes,
          classes,
          true_conf,
          smoothed=self._objectness_smooth > 0)

    # Build the one hot class list that is used for the class loss.
    true_class = tf.one_hot(
        tf.cast(true_class, tf.int32),
        depth=tf.shape(pred_class)[-1],
        dtype=pred_class.dtype)
    true_class = tf.stop_gradient(loss_utils.apply_mask(ind_mask, true_class))

    # Reorganize the one hot class list as a grid.
    true_class_grid = loss_utils.build_grid(
        inds, true_class, pred_class, ind_mask, update=False)
    true_class_grid = tf.stop_gradient(true_class_grid)

    # Use the class mask to find the number of objects located in
    # each predicted grid cell/pixel.
    counts = true_class_grid
    counts = tf.reduce_sum(counts, axis=-1, keepdims=True)
    reps = tf.gather_nd(counts, inds, batch_dims=1)
    reps = tf.squeeze(reps, axis=-1)
    # Zero counts are replaced by one so that `reps` can be used as a divisor.
    reps = tf.stop_gradient(tf.where(reps == 0.0, tf.ones_like(reps), reps))

    # Compute the loss for only the cells in which the boxes are located.
    pred_box = loss_utils.apply_mask(ind_mask,
                                     tf.gather_nd(pred_box, inds, batch_dims=1))
    iou, _, box_loss = self.box_loss(true_box, pred_box, darknet=True)
    box_loss = loss_utils.apply_mask(tf.squeeze(ind_mask, axis=-1), box_loss)
    box_loss = math_ops.divide_no_nan(box_loss, reps)
    box_loss = tf.cast(tf.reduce_sum(box_loss, axis=1), dtype=y_pred.dtype)

    if self._update_on_repeat:
      # Converts the list of ground truths into a grid where repeated values
      # are replaced by the most recent value. So some class identities may
      # get lost but the loss computation will be more stable. Results are
      # more consistent.

      # Compute the sigmoid binary cross entropy for the class maps.
      class_loss = tf.reduce_mean(
          loss_utils.sigmoid_bce(
              tf.expand_dims(true_class_grid, axis=-1),
              tf.expand_dims(pred_class, axis=-1), self._label_smoothing),
          axis=-1)

      # Apply normalization to the class losses.
      # NOTE(review): in this branch the class normalizer is only applied when
      # it is below 1.0, and only at the positive class entries; the other
      # branch always multiplies by it. Confirm this asymmetry is intended.
      if self._cls_normalizer < 1.0:
        # Build a mask based on the true class locations.
        cls_norm_mask = true_class_grid
        # Apply the class weight to the class indexes where one_hot is one.
        class_loss *= ((1 - cls_norm_mask) +
                       cls_norm_mask * self._cls_normalizer)

      # Mask the class loss and compute the sum over all the objects.
      class_loss = tf.reduce_sum(class_loss, axis=-1)
      class_loss = loss_utils.apply_mask(grid_mask, class_loss)
      class_loss = math_ops.rm_nan_inf(class_loss, val=0.0)
      class_loss = tf.cast(
          tf.reduce_sum(class_loss, axis=(1, 2, 3)), dtype=y_pred.dtype)
    else:
      # Computes the loss while keeping the structure as a list in
      # order to ensure all objects are considered. In some cases this can
      # make training more unstable but may also return higher APs.
      pred_class = loss_utils.apply_mask(
          ind_mask, tf.gather_nd(pred_class, inds, batch_dims=1))
      class_loss = tf.keras.losses.binary_crossentropy(
          tf.expand_dims(true_class, axis=-1),
          tf.expand_dims(pred_class, axis=-1),
          label_smoothing=self._label_smoothing,
          from_logits=True)
      class_loss = loss_utils.apply_mask(ind_mask, class_loss)
      class_loss = math_ops.divide_no_nan(class_loss,
                                          tf.expand_dims(reps, axis=-1))
      class_loss = tf.cast(
          tf.reduce_sum(class_loss, axis=(1, 2)), dtype=y_pred.dtype)
      class_loss *= self._cls_normalizer

    # Compute the sigmoid binary cross entropy for the confidence maps.
    bce = tf.reduce_mean(
        loss_utils.sigmoid_bce(
            tf.expand_dims(true_conf, axis=-1), pred_conf, 0.0),
        axis=-1)

    # Mask the confidence loss and take the sum across all the grid cells.
    # `obj_mask` only exists when the box search above ran, hence the same
    # condition guards its use.
    if self._ignore_thresh != 0.0:
      bce = loss_utils.apply_mask(obj_mask, bce)
    conf_loss = tf.cast(tf.reduce_sum(bce, axis=(1, 2, 3)), dtype=y_pred.dtype)

    # Apply the weights to each loss.
    box_loss *= self._iou_normalizer
    conf_loss *= self._object_normalizer

    # Add all the losses together then take the mean over the batches.
    loss = box_loss + class_loss + conf_loss
    loss = tf.reduce_mean(loss)

    # Reduce the mean of the losses to use as a metric.
    box_loss = tf.reduce_mean(box_loss)
    conf_loss = tf.reduce_mean(conf_loss)
    class_loss = tf.reduce_mean(class_loss)

    # The batch-averaged loss fills both the `loss` and the `mean_loss` slots.
    return (loss, box_loss, conf_loss, class_loss, loss, iou, pred_conf,
            ind_mask, grid_mask)

  def cross_replica_aggregation(self, loss, num_replicas_in_sync):
    """Divides the loss by the number of replicas in sync.

    This is a property of the loss type rather than of a single loss path.
    """
    per_replica_loss = loss / num_replicas_in_sync
    return per_replica_loss


class ScaledLoss(YoloLossBase):
Vishnu Banna's avatar
Vishnu Banna committed
459
  """This class implements the full logic for the scaled Yolo models."""
Vishnu Banna's avatar
Vishnu Banna committed
460
461

  def _build_per_path_attributes(self):
    """Creates the grid generator, the pair-wise search and class weighting.

    The grid generator is built from the anchors and the path stride. The
    pair-wise search replaces the default no-op search only when the ignore
    threshold is positive, and uses the configured loss type as its iou type.
    The class normalizer is rescaled by the number of classes over 80.
    """
    self._anchor_generator = loss_utils.GridGenerator(
        anchors=self._anchors, scale_anchors=self._path_stride)

    search_enabled = self._ignore_thresh > 0.0
    if search_enabled:
      self._search_pairs = loss_utils.PairWiseSearch(
          iou_type=self._loss_type, any_match=False, min_conf=0.25)

    rescaled_cls_normalizer = self._cls_normalizer * self._classes / 80
    self._cls_normalizer = rescaled_cls_normalizer

Vishnu Banna's avatar
Vishnu Banna committed
478
  def _compute_loss(self, true_counts, inds, y_true, boxes, classes, y_pred):
    """Per FPN path loss logic for Yolov4-csp, Yolov4-Large, and Yolov5.

    Args:
      true_counts: `Tensor` whose first four dimensions give the batch size,
        the two grid dimensions and the number of anchors; clipped to [0, 1]
        to form the ground truth confidence map.
      inds: `Tensor` of indexes used to gather per-object values out of the
        prediction grid (`tf.gather_nd` with `batch_dims=1`).
      y_true: `Tensor` whose last axis is split into 4 box values, 1 index
        mask value and 1 class value.
      boxes: `Tensor` of ground truth boxes used for the global box search.
      classes: `Tensor` of ground truth classes used for the global box search.
      y_pred: `Tensor` holding the raw model output for this path.

    Returns:
      A tuple of (loss, box_loss, conf_loss, class_loss, mean_loss, iou,
      pred_conf, ind_mask, grid_mask), where `loss` is `mean_loss` multiplied
      by the batch size.
    """
    # Generate shape constants.
    # NOTE(review): shape[1] is bound to `width` and shape[2] to `height`; the
    # two names are used consistently with each other below.
    shape = tf.shape(true_counts)
    batch_size, width, height, num = shape[0], shape[1], shape[2], shape[3]
    fwidth = tf.cast(width, tf.float32)
    fheight = tf.cast(height, tf.float32)

    # Cast the ground truth inputs to float32 and build the grid tensors.
    y_true = tf.cast(y_true, tf.float32)
    true_counts = tf.cast(true_counts, tf.float32)
    true_conf = tf.clip_by_value(true_counts, 0.0, 1.0)
    grid_points, anchor_grid = self._anchor_generator(
        width, height, batch_size, dtype=tf.float32)

    # Split the y_true list.
    (true_box, ind_mask, true_class) = tf.split(y_true, [4, 1, 1], axis=-1)
    # `grid_mask` keeps the clipped counts; `true_conf` is rebuilt from the
    # iou further down.
    grid_mask = true_conf = tf.squeeze(true_conf, axis=-1)
    true_class = tf.squeeze(true_class, axis=-1)
    # Sum of the index mask over the whole batch, used to average the box and
    # class losses.
    num_objs = tf.cast(tf.reduce_sum(ind_mask), dtype=y_pred.dtype)

    # Split up the predictions.
    y_pred = tf.cast(
        tf.reshape(y_pred, [batch_size, width, height, num, -1]), tf.float32)
    pred_box, pred_conf, pred_class = tf.split(y_pred, [4, 1, -1], axis=-1)

    # Decode the boxes for loss compute.
    scale, pred_box, pbg = self._decode_boxes(
        fwidth, fheight, pred_box, anchor_grid, grid_points, darknet=False)

    # If the ignore threshold is enabled, search all boxes and ignore all
    # IOU values larger than the ignore threshold that are not in the
    # noted ground truth list. Only the mask is kept; the confidence map
    # returned by the search is discarded.
    if self._ignore_thresh != 0.0:
      (_, obj_mask) = self._tiled_global_box_search(
          pbg,
          tf.stop_gradient(tf.sigmoid(pred_class)),
          boxes,
          classes,
          true_conf,
          smoothed=False,
          scale=None)

    # Scale and shift and select the ground truth boxes
    # and predictions to the prediction domain.
    if self._box_type == 'anchor_free':
      true_box = loss_utils.apply_mask(ind_mask,
                                       (scale * self._path_stride * true_box))
    else:
      offset = tf.cast(
          tf.gather_nd(grid_points, inds, batch_dims=1), true_box.dtype)
      # The offset is padded with zeros so that only the first half of the
      # box values is shifted.
      offset = tf.concat([offset, tf.zeros_like(offset)], axis=-1)
      true_box = loss_utils.apply_mask(ind_mask, (scale * true_box) - offset)
    pred_box = loss_utils.apply_mask(ind_mask,
                                     tf.gather_nd(pred_box, inds, batch_dims=1))

    # Select the correct/used prediction classes.
    true_class = tf.one_hot(
        tf.cast(true_class, tf.int32),
        depth=tf.shape(pred_class)[-1],
        dtype=pred_class.dtype)
    true_class = loss_utils.apply_mask(ind_mask, true_class)
    pred_class = loss_utils.apply_mask(
        ind_mask, tf.gather_nd(pred_class, inds, batch_dims=1))

    # Compute the box loss. Note that `iou` receives the second value returned
    # by `box_loss`, i.e. the iou variant the loss is built from.
    _, iou, box_loss = self.box_loss(true_box, pred_box, darknet=False)
    box_loss = loss_utils.apply_mask(tf.squeeze(ind_mask, axis=-1), box_loss)
    box_loss = math_ops.divide_no_nan(tf.reduce_sum(box_loss), num_objs)

    # Use the box IOU to build the map for confidence loss computation.
    iou = tf.maximum(tf.stop_gradient(iou), 0.0)
    smoothed_iou = ((
        (1 - self._objectness_smooth) * tf.cast(ind_mask, iou.dtype)) +
                    self._objectness_smooth * tf.expand_dims(iou, axis=-1))
    smoothed_iou = loss_utils.apply_mask(ind_mask, smoothed_iou)
    true_conf = loss_utils.build_grid(
        inds, smoothed_iou, pred_conf, ind_mask, update=self._update_on_repeat)
    true_conf = tf.squeeze(true_conf, axis=-1)

    # Compute the cross entropy loss for the confidence map.
    bce = tf.keras.losses.binary_crossentropy(
        tf.expand_dims(true_conf, axis=-1), pred_conf, from_logits=True)
    if self._ignore_thresh != 0.0:
      bce = loss_utils.apply_mask(obj_mask, bce)
      # NOTE(review): this is a plain division, unlike the box and class
      # losses which use `divide_no_nan`; confirm that `obj_mask` can never
      # sum to zero.
      conf_loss = tf.reduce_sum(bce) / tf.reduce_sum(obj_mask)
    else:
      conf_loss = tf.reduce_mean(bce)

    # Compute the cross entropy loss for the class maps.
    class_loss = tf.keras.losses.binary_crossentropy(
        true_class,
        pred_class,
        label_smoothing=self._label_smoothing,
        from_logits=True)
    class_loss = loss_utils.apply_mask(
        tf.squeeze(ind_mask, axis=-1), class_loss)
    class_loss = math_ops.divide_no_nan(tf.reduce_sum(class_loss), num_objs)

    # Apply the weights to each loss.
    box_loss *= self._iou_normalizer
    class_loss *= self._cls_normalizer
    conf_loss *= self._object_normalizer

    # Add all the losses together, then multiply by the batch size so that the
    # returned loss behaves like a sum over the batch.
    mean_loss = box_loss + class_loss + conf_loss
    loss = mean_loss * tf.cast(batch_size, mean_loss.dtype)

    return (loss, box_loss, conf_loss, class_loss, mean_loss, iou, pred_conf,
            ind_mask, grid_mask)

589
590
  def post_path_aggregation(self, loss, box_loss, conf_loss, class_loss,
                            ground_truths, predictions):
    """Rescales the aggregated loss by the number of FPN levels.

    By default the model will have about 3 FPN levels {3, 4, 5}. On larger
    models that have 4 or 5 FPN levels the loss needs to be scaled such that
    the total update has the same effective magnitude as for a model with 3
    FPN levels. This helps to prevent gradient explosions.

    Args:
      loss: `tf.float` scalar for the actual loss.
      box_loss: `tf.float` for the loss on the boxes only (unused).
      conf_loss: `tf.float` for the loss on the confidences only (unused).
      class_loss: `tf.float` for the loss on the classes only (unused).
      ground_truths: `Dict` holding all the ground truth tensors (unused).
      predictions: `Dict` holding all the predicted values; only the number of
        keys is used.

    Returns:
      loss: `tf.float` scalar for the loss multiplied by `3 / num_paths`.
      scale: `tf.float` reciprocal of the factor the loss was multiplied by.
    """
    num_paths = len(list(predictions.keys()))
    factor = tf.stop_gradient(3 / num_paths)
    scaled_loss = loss * factor
    return scaled_loss, 1 / factor
Vishnu Banna's avatar
Vishnu Banna committed
612
613

  def cross_replica_aggregation(self, loss, num_replicas_in_sync):
    """Returns the loss unchanged; the replica count is not used.

    This is a property of the loss type rather than of a single loss path.
    """
    del num_replicas_in_sync  # No cross replica scaling for this loss type.
    return loss


Vishnu Banna's avatar
Vishnu Banna committed
618
619
class YoloLoss:
  """Aggregates the per-path YOLO losses across the model's FPN levels."""

  def __init__(self,
               keys,
               classes,
               anchors,
               path_strides=None,
               truth_thresholds=None,
               ignore_thresholds=None,
               loss_types=None,
               iou_normalizers=None,
               cls_normalizers=None,
               object_normalizers=None,
               objectness_smooths=None,
               box_types=None,
               scale_xys=None,
               max_deltas=None,
               label_smoothing=0.0,
               use_scaled_loss=False,
               update_on_repeat=True):
    """Builds one loss object per FPN path.

    Every `Dict` argument below is indexed with each entry of `keys`, so it
    must contain a value for every FPN path that is optimized.

    Args:
      keys: `List[str]` indicating the name of the FPN paths that need to be
        optimized.
      classes: `int` for the number of classes.
      anchors: `Dict` of `List[List[int]]` holding, for each FPN path, the
        anchor boxes used at that level. For anchor free prediction set the
        anchor list to be the same as the image resolution.
      path_strides: `Dict[int]` for how much to scale this level to get the
        original input shape for each FPN path.
      truth_thresholds: `Dict[float]` for the IOU value over which the loss is
        propagated despite a detection being made for each FPN path.
      ignore_thresholds: `Dict[float]` for the IOU value over which the loss is
        not propagated, and a detection is assumed to have been made for each
        FPN path.
      loss_types: `Dict[str]` for the type of iou loss to use, within {ciou,
        diou, giou, iou}, for each FPN path.
      iou_normalizers: `Dict[float]` for how much to scale the loss on the IOU
        or the boxes for each FPN path.
      cls_normalizers: `Dict[float]` for how much to scale the loss on the
        classes for each FPN path.
      object_normalizers: `Dict[float]` for how much to scale loss on the
        detection map for each FPN path.
      objectness_smooths: `Dict[float]` for how much to smooth the loss on the
        detection map for each FPN path.
      box_types: `Dict[bool]` for which scaling type to use for each FPN path.
      scale_xys: `Dict[float]` values indicating how far each pixel can see
        outside of its containment of 1.0. A value of 1.2 indicates there is a
        20% extended radius around each pixel that this specific pixel can
        predict values for a center at. The center can range from 0 - value/2
        to 1 + value/2. This value is set in the yolo filter, and reused here.
        There should be one value for scale_xy for each level from min_level
        to max_level. One for each FPN path.
      max_deltas: `Dict[float]` for gradient clipping to apply to the box loss
        for each FPN path.
      label_smoothing: `float` for how much to smooth the loss on the classes.
        The same value is passed to the loss of every FPN path.
      use_scaled_loss: `bool` for whether to use the scaled loss or the
        traditional loss.
      update_on_repeat: `bool` for whether to replace with the newest or the
        best value when an index is consumed by multiple objects.
    """
    # A single loss implementation is shared by all of the FPN paths.
    loss_cls = ScaledLoss if use_scaled_loss else DarknetLoss

    self._loss_dict = {
        key: loss_cls(
            classes=classes,
            anchors=anchors[key],
            truth_thresh=truth_thresholds[key],
            ignore_thresh=ignore_thresholds[key],
            loss_type=loss_types[key],
            iou_normalizer=iou_normalizers[key],
            cls_normalizer=cls_normalizers[key],
            object_normalizer=object_normalizers[key],
            box_type=box_types[key],
            objectness_smooth=objectness_smooths[key],
            max_delta=max_deltas[key],
            path_stride=path_strides[key],
            scale_x_y=scale_xys[key],
            update_on_repeat=update_on_repeat,
            label_smoothing=label_smoothing) for key in keys
    }

  def __call__(self, ground_truth, predictions):
    """Computes the total loss and the metrics over all predicted paths.

    Args:
      ground_truth: `Dict` of ground truth values. The entries 'true_conf',
        'inds' and 'upds' are each indexed by FPN path key, while 'bbox' and
        'classes' are passed as-is to every path's loss.
      predictions: `Dict` mapping each FPN path key to that path's predictions.

    Returns:
      loss_val: the sum over all paths of the loss after per-path and
        cross-replica aggregation. This is the value that carries gradients.
      metric_loss: the sum over all paths of the detached `mean_loss / scale`.
      metric_dict: `Dict` of detached metrics. The 'net' entry accumulates the
        'box', 'class' and 'conf' losses over all paths, and each path key
        holds that path's 'loss', 'avg_iou' and 'avg_obj'.
    """
    metric_dict = collections.defaultdict(dict)
    metric_dict['net'] = {'box': 0, 'class': 0, 'conf': 0}
    net_metrics = metric_dict['net']

    loss_val = 0
    metric_loss = 0
    replica_count = tf.distribute.get_strategy().num_replicas_in_sync

    for key, prediction in predictions.items():
      path_loss = self._loss_dict[key]

      (loss, loss_box, loss_conf, loss_class, mean_loss, avg_iou,
       avg_obj) = path_loss(
           ground_truth['true_conf'][key],
           ground_truth['inds'][key],
           ground_truth['upds'][key],
           ground_truth['bbox'],
           ground_truth['classes'],
           prediction)

      # After computing the loss, scale it as needed for aggregation across
      # the FPN levels.
      loss, scale = path_loss.post_path_aggregation(
          loss, loss_box, loss_conf, loss_class, ground_truth, predictions)

      # After the loss is scaled on each replica, handle any scaling that is
      # needed for merging the loss across replicas.
      loss = path_loss.cross_replica_aggregation(loss, replica_count)
      loss_val += loss

      # Detach everything below: none of these values should contribute to
      # the gradient from this point forwards.
      path_metric_loss = tf.stop_gradient(mean_loss / scale)
      metric_loss += path_metric_loss

      path_metrics = metric_dict[key]
      path_metrics['loss'] = path_metric_loss
      path_metrics['avg_iou'] = tf.stop_gradient(avg_iou)
      path_metrics['avg_obj'] = tf.stop_gradient(avg_obj)

      net_metrics['box'] += tf.stop_gradient(loss_box / scale)
      net_metrics['class'] += tf.stop_gradient(loss_class / scale)
      net_metrics['conf'] += tf.stop_gradient(loss_conf / scale)

    return loss_val, metric_loss, metric_dict