# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Yolo loss utility functions."""

import numpy as np
import tensorflow as tf

from official.projects.yolo.ops import box_ops
from official.projects.yolo.ops import math_ops


@tf.custom_gradient
def sigmoid_bce(y, x_prime, label_smoothing):
  """Applies the Sigmoid Cross Entropy Loss.

  Implements the same derivative as that found in the Darknet C library.
  The derivative of this method is not the same as the standard binary cross
  entropy with logits function.

  The BCE with logits function equation is as follows:
    x = 1 / (1 + exp(-x_prime))
    bce = -y * log(x) - (1 - y) * log(1 - x)

  The standard BCE with logits function derivative is as follows:
    dloss = -y/x + (1-y)/(1-x)
    dsigmoid = x * (1 - x)
    dx = dloss * dsigmoid

  This derivative can be reduced simply to:
    dx = (-y + x)

  This simplification is used by the darknet library in order to improve
  training stability. The gradient is almost the same
  as tf.keras.losses.binary_crossentropy but varies slightly and
  yields different performance.

  Args:
    y: `Tensor` holding ground truth data.
    x_prime: `Tensor` holding the predictions prior to application of the
      sigmoid operation.
    label_smoothing: float value between 0.0 and 1.0 indicating the amount of
      smoothing to apply to the data.

  Returns:
    bce: `Tensor` of the computed loss values.
    delta: callable function indicating the custom gradient for this operation.
  """

  eps = 1e-9
  x = tf.math.sigmoid(x_prime)
  y = tf.stop_gradient(y * (1 - label_smoothing) + 0.5 * label_smoothing)
  bce = -y * tf.math.log(x + eps) - (1 - y) * tf.math.log(1 - x + eps)

  def delta(dpass):
    x = tf.math.sigmoid(x_prime)
    dx = (-y + x) * dpass
    dy = tf.zeros_like(y)
    return dy, dx, 0.0

  return bce, delta
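

# A minimal sketch (illustrative only, not part of the library API) showing
# that the custom gradient above reduces to (sigmoid(x_prime) - y), matching
# the Darknet derivative described in the docstring:
#
#   y = tf.constant([[1.0, 0.0]])
#   logits = tf.constant([[2.0, -1.0]])
#   with tf.GradientTape() as tape:
#     tape.watch(logits)
#     loss = sigmoid_bce(y, logits, 0.0)
#   grad = tape.gradient(loss, logits)  # equals tf.math.sigmoid(logits) - y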


def apply_mask(mask, x, value=0):
  """This function is used for gradient masking.

  The YOLO loss function makes extensive use of dynamically shaped tensors.
  To allow this use case on the TPU while preserving the gradient correctly
  for back propagation, this masking function uses a tf.where operation to
  hard set masked locations to a gradient and a value of zero.

  Args:
    mask: A `Tensor` with the same shape as x used to select values of
      importance.
    x: A `Tensor` with the same shape as mask that will be masked.
    value: A `float` constant value used to fill masked locations.

  Returns:
    x: A masked `Tensor` with the same shape as x.
  """
  mask = tf.cast(mask, tf.bool)
  masked = tf.where(mask, x, tf.zeros_like(x) + value)
  return masked
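

# A minimal usage sketch for `apply_mask` (illustrative only): locations where
# the mask is zero are replaced by `value` (zero by default) and receive no
# gradient with respect to `x`.
#
#   mask = tf.constant([1.0, 0.0, 1.0])
#   x = tf.constant([0.5, 7.0, -2.0])
#   apply_mask(mask, x)  # -> [0.5, 0.0, -2.0]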


def build_grid(indexes, truths, preds, ind_mask, update=False, grid=None):
  """This function is used to broadcast elements into the output shape.

  This function broadcasts a list of truths into the correct indexes in the
  output shape. This is used for the ground truth map construction in the
  scaled loss and the classification map in the darknet loss.

  Args:
    indexes: A `Tensor` for the indexes.
    truths: A `Tensor` for the ground truth.
    preds: A `Tensor` for the predictions.
    ind_mask: A `Tensor` for the index masks.
    update: A `bool` for updating the grid.
    grid: A `Tensor` for the grid.

  Returns:
    grid: A `Tensor` representing the augmented grid.
  """
  # this function is used to broadcast all the indexes into the correct
  # ground truth mask, used for the iou detection map in the scaled loss
  # and the classification mask in the darknet loss
  num_flatten = tf.shape(preds)[-1]

  # is there a way to verify that we are not on the CPU?
  ind_mask = tf.cast(ind_mask, indexes.dtype)

  # find all the batch indexes using the cumulative sum of a ones tensor
  # cumsum(ones) - 1 yields the zero indexed batches
  bhep = tf.reduce_max(tf.ones_like(indexes), axis=-1, keepdims=True)
  bhep = tf.math.cumsum(bhep, axis=0) - 1

  # concatenate the batch indexes to the indexes
  indexes = tf.concat([bhep, indexes], axis=-1)
  indexes = apply_mask(tf.cast(ind_mask, indexes.dtype), indexes)
  indexes = (indexes + (ind_mask - 1))

  # mask truths
  truths = apply_mask(tf.cast(ind_mask, truths.dtype), truths)
  truths = (truths + (tf.cast(ind_mask, truths.dtype) - 1))

  # reshape the indexes into the correct shape for the loss,
  # just flatten all indexes but the last
  indexes = tf.reshape(indexes, [-1, 4])

  # also flatten the ground truth value on all axes but the last
  truths = tf.reshape(truths, [-1, num_flatten])

  # build a zero grid in the same shape as the predictions
  if grid is None:
    grid = tf.zeros_like(preds)
  # remove invalid values from the truths that may have
  # come up from computation, invalid = nan and inf
  truths = math_ops.rm_nan_inf(truths)

  # scatter update the zero grid
  if update:
    grid = tf.tensor_scatter_nd_update(grid, indexes, truths)
  else:
    grid = tf.tensor_scatter_nd_max(grid, indexes, truths)

  # stop gradient and return to avoid TPU errors and save compute
  # resources
  return grid
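

# A minimal sketch of `build_grid` (illustrative shapes, hypothetical index
# ordering of [y, x, anchor]) scattering ground truth values into a
# prediction shaped grid:
#
#   preds = tf.zeros([2, 8, 8, 3, 1])           # [batch, h, w, anchors, c]
#   indexes = tf.zeros([2, 10, 3], tf.int32)    # [batch, num_inst, 3]
#   truths = tf.ones([2, 10, 1])                # [batch, num_inst, c]
#   ind_mask = tf.ones([2, 10, 1])              # 1 = valid index, 0 = padding
#   grid = build_grid(indexes, truths, preds, ind_mask, update=True)
#   # grid has the same shape as preds with truths scattered at the indexes.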


class GridGenerator:
  """Grid generator that generates anchor grids for box decoding."""

  def __init__(self, anchors, scale_anchors=None):
    """Initialize Grid Generator.

    Args:
      anchors: A `List[List[int]]` for the anchor boxes that are used in the
        model at all levels.
      scale_anchors: An `int` for how much to scale this level to get the
        original input shape.
    """
    self.dtype = tf.keras.backend.floatx()
    self._scale_anchors = scale_anchors
    self._anchors = tf.convert_to_tensor(anchors)
    return

  def _build_grid_points(self, lheight, lwidth, anchors, dtype):
    """Generate a grid of fixed grid edges for box center decoding."""
    with tf.name_scope('center_grid'):
      y = tf.range(0, lheight)
      x = tf.range(0, lwidth)
      x_left = tf.tile(
          tf.transpose(tf.expand_dims(x, axis=-1), perm=[1, 0]), [lheight, 1])
      y_left = tf.tile(tf.expand_dims(y, axis=-1), [1, lwidth])
      x_y = tf.stack([x_left, y_left], axis=-1)
      x_y = tf.cast(x_y, dtype=dtype)
      num = tf.shape(anchors)[0]
      x_y = tf.expand_dims(
          tf.tile(tf.expand_dims(x_y, axis=-2), [1, 1, num, 1]), axis=0)
    return x_y

  def _build_anchor_grid(self, height, width, anchors, dtype):
    """Get the transformed anchor boxes for each dimention."""
    with tf.name_scope('anchor_grid'):
      num = tf.shape(anchors)[0]
      anchors = tf.cast(anchors, dtype=dtype)
      anchors = tf.reshape(anchors, [1, 1, 1, num, 2])
      anchors = tf.tile(anchors, [1, tf.cast(height, tf.int32),
                                  tf.cast(width, tf.int32), 1, 1])
    return anchors

  def _extend_batch(self, grid, batch_size):
    return tf.tile(grid, [batch_size, 1, 1, 1, 1])

  def __call__(self, height, width, batch_size, dtype=None):
    if dtype is None:
      self.dtype = tf.keras.backend.floatx()
    else:
      self.dtype = dtype
    grid_points = self._build_grid_points(height, width, self._anchors,
                                          self.dtype)
    anchor_grid = self._build_anchor_grid(
        height, width,
        tf.cast(self._anchors, self.dtype) /
        tf.cast(self._scale_anchors, self.dtype), self.dtype)

    grid_points = self._extend_batch(grid_points, batch_size)
    anchor_grid = self._extend_batch(anchor_grid, batch_size)
    return grid_points, anchor_grid
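

# A minimal sketch (illustrative values) of `GridGenerator`: for an 8x8
# feature map with 3 anchors and a stride of 32, it yields the per pixel grid
# offsets and the anchors rescaled to feature map units.
#
#   generator = GridGenerator(anchors=[[10, 13], [16, 30], [33, 23]],
#                             scale_anchors=32)
#   grid_points, anchor_grid = generator(8, 8, batch_size=2)
#   # grid_points -> [2, 8, 8, 3, 2], anchor_grid -> [2, 8, 8, 3, 2]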


TILE_SIZE = 50


class PairWiseSearch:
  """Apply a pairwise search between the ground truth and the labels.

  The goal is to indicate the locations where the predictions overlap with
  ground truth for dynamic ground truth associations.
  """

  def __init__(self,
               iou_type='iou',
               any_match=True,
               min_conf=0.0,
               track_boxes=False,
               track_classes=False):
    """Initialization of Pair Wise Search.

    Args:
      iou_type: A `str` for the iou type to use.
      any_match: A `bool` for any match (no class match).
      min_conf: A `float` for the minimum confidence threshold.
      track_boxes: A `bool` for dynamic box assignment.
      track_classes: A `bool` for dynamic class assignment.
    """
    self.iou_type = iou_type
    self._any = any_match
    self._min_conf = min_conf
    self._track_boxes = track_boxes
    self._track_classes = track_classes
    return

  def box_iou(self, true_box, pred_box):
    # based on the type of loss, compute the iou loss for a box
    # compute_<name> indicates the type of iou to use
    if self.iou_type == 'giou':
      _, iou = box_ops.compute_giou(true_box, pred_box)
    elif self.iou_type == 'ciou':
      _, iou = box_ops.compute_ciou(true_box, pred_box)
    else:
      iou = box_ops.compute_iou(true_box, pred_box)
    return iou

  def _search_body(self, pred_box, pred_class, boxes, classes, running_boxes,
                   running_classes, max_iou, idx):
    """Main search fn."""

    # capture the batch size to be used, and gather a slice of
    # boxes from the ground truth. currently TILE_SIZE = 50, to
    # save memory
    batch_size = tf.shape(boxes)[0]
    box_slice = tf.slice(boxes, [0, idx * TILE_SIZE, 0],
                         [batch_size, TILE_SIZE, 4])

    # match the dimensions of the slice to the model predictions
    # shape: [batch_size, 1, 1, num, TILE_SIZE, 4]
    box_slice = tf.expand_dims(box_slice, axis=1)
    box_slice = tf.expand_dims(box_slice, axis=1)
    box_slice = tf.expand_dims(box_slice, axis=1)

    box_grid = tf.expand_dims(pred_box, axis=-2)

    # capture the classes
    class_slice = tf.slice(classes, [0, idx * TILE_SIZE],
                           [batch_size, TILE_SIZE])
    class_slice = tf.expand_dims(class_slice, axis=1)
    class_slice = tf.expand_dims(class_slice, axis=1)
    class_slice = tf.expand_dims(class_slice, axis=1)

    iou = self.box_iou(box_slice, box_grid)

    if self._min_conf > 0.0:
      if not self._any:
        class_grid = tf.expand_dims(pred_class, axis=-2)
        class_mask = tf.one_hot(
            tf.cast(class_slice, tf.int32),
            depth=tf.shape(pred_class)[-1],
            dtype=pred_class.dtype)
        class_mask = tf.reduce_any(tf.equal(class_mask, class_grid), axis=-1)
      else:
        class_mask = tf.reduce_max(pred_class, axis=-1, keepdims=True)
      class_mask = tf.cast(class_mask, iou.dtype)
      iou *= class_mask

    max_iou_ = tf.concat([max_iou, iou], axis=-1)
    max_iou = tf.reduce_max(max_iou_, axis=-1, keepdims=True)
    ind = tf.expand_dims(tf.argmax(max_iou_, axis=-1), axis=-1)

    if self._track_boxes:
      running_boxes = tf.expand_dims(running_boxes, axis=-2)
      box_slice = tf.zeros_like(running_boxes) + box_slice
      box_slice = tf.concat([running_boxes, box_slice], axis=-2)
      running_boxes = tf.gather_nd(box_slice, ind, batch_dims=4)

    if self._track_classes:
      running_classes = tf.expand_dims(running_classes, axis=-1)
      class_slice = tf.zeros_like(running_classes) + class_slice
      class_slice = tf.concat([running_classes, class_slice], axis=-1)
      running_classes = tf.gather_nd(class_slice, ind, batch_dims=4)

    return (pred_box, pred_class, boxes, classes, running_boxes,
            running_classes, max_iou, idx + 1)

  def __call__(self,
               pred_boxes,
               pred_classes,
               boxes,
               classes,
               clip_thresh=0.0):
    num_boxes = tf.shape(boxes)[-2]
    num_tiles = (num_boxes // TILE_SIZE) - 1

    if self._min_conf > 0.0:
      pred_classes = tf.cast(pred_classes > self._min_conf, pred_classes.dtype)

    def _loop_cond(unused_pred_box, unused_pred_class, boxes, unused_classes,
                   unused_running_boxes, unused_running_classes, unused_max_iou,
                   idx):

      # check that the slice has boxes that are not all zeros
      batch_size = tf.shape(boxes)[0]
      box_slice = tf.slice(boxes, [0, idx * TILE_SIZE, 0],
                           [batch_size, TILE_SIZE, 4])

      return tf.logical_and(idx < num_tiles,
                            tf.math.greater(tf.reduce_sum(box_slice), 0))

    running_boxes = tf.zeros_like(pred_boxes)
    running_classes = tf.zeros_like(tf.reduce_sum(running_boxes, axis=-1))
    max_iou = tf.zeros_like(tf.reduce_sum(running_boxes, axis=-1))
    max_iou = tf.expand_dims(max_iou, axis=-1)

    (pred_boxes, pred_classes, boxes, classes, running_boxes, running_classes,
     max_iou, _) = tf.while_loop(_loop_cond, self._search_body, [
         pred_boxes, pred_classes, boxes, classes, running_boxes,
         running_classes, max_iou,
         tf.constant(0)
     ])

    mask = tf.cast(max_iou > clip_thresh, running_boxes.dtype)
    running_boxes *= mask
    running_classes *= tf.squeeze(mask, axis=-1)
    max_iou *= mask
    max_iou = tf.squeeze(max_iou, axis=-1)
    mask = tf.squeeze(mask, axis=-1)

    return (tf.stop_gradient(running_boxes), tf.stop_gradient(running_classes),
            tf.stop_gradient(max_iou), tf.stop_gradient(mask))
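

# A minimal sketch of `PairWiseSearch` (shapes are hypothetical): for each
# prediction location it searches the padded ground truth for the best IoU
# match and returns the matched boxes, classes, max IoU, and a validity mask.
#
#   search = PairWiseSearch(iou_type='iou', any_match=True,
#                           track_boxes=True, track_classes=True)
#   # pred_boxes: [batch, h, w, anchors, 4], pred_classes: [..., num_classes]
#   # boxes: [batch, max_instances, 4], classes: [batch, max_instances]
#   matched_boxes, matched_classes, max_iou, mask = search(
#       pred_boxes, pred_classes, boxes, classes, clip_thresh=0.0)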


def average_iou(iou):
  """Computes the average intersection over union without counting locations
  where the iou is zero.

  Args:
    iou: A `Tensor` representing the iou values.

  Returns:
    tf.stop_gradient(avg_iou): A `Tensor` representing average
     intersection over union.
  """
  iou_sum = tf.reduce_sum(iou, axis=tf.range(1, tf.shape(tf.shape(iou))[0]))
  counts = tf.cast(
      tf.math.count_nonzero(iou, axis=tf.range(1,
                                               tf.shape(tf.shape(iou))[0])),
      iou.dtype)
  avg_iou = tf.reduce_mean(math_ops.divide_no_nan(iou_sum, counts))
  return tf.stop_gradient(avg_iou)
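

# A minimal worked example of `average_iou` (illustrative only): zero entries
# are excluded from the mean so that unmatched locations do not dilute it.
#
#   iou = tf.constant([[0.5, 0.0, 0.7],
#                      [0.0, 0.0, 0.9]])
#   average_iou(iou)  # mean([1.2 / 2, 0.9 / 1]) = 0.75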


def _scale_boxes(encoded_boxes, width, height, anchor_grid, grid_points,
                 scale_xy):
  """Decodes models boxes applying and exponential to width and height maps."""
  # split the boxes
  pred_xy = encoded_boxes[..., 0:2]
  pred_wh = encoded_boxes[..., 2:4]

  # build a scaling tensor to get the offset of the box relative to the image
  scaler = tf.convert_to_tensor([height, width, height, width])
  scale_xy = tf.cast(scale_xy, encoded_boxes.dtype)

  # apply the sigmoid
  pred_xy = tf.math.sigmoid(pred_xy)

  # scale the centers and find the offset of each box relative to
  # their center pixel
  pred_xy = pred_xy * scale_xy - 0.5 * (scale_xy - 1)

  # scale the offsets and add them to the grid points, a tensor that is
  # the relative location of each pixel
  box_xy = grid_points + pred_xy

  # scale the width and height of the predictions and correlate them
  # to anchor boxes
  box_wh = tf.math.exp(pred_wh) * anchor_grid

  # build the final predicted box
  scaled_box = tf.concat([box_xy, box_wh], axis=-1)
  pred_box = scaled_box / scaler

  # shift scaled boxes
  scaled_box = tf.concat([pred_xy, box_wh], axis=-1)
  return (scaler, scaled_box, pred_box)
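

# A worked sketch (scalar notation, illustrative) of the decoding above for a
# single prediction at grid offset g with anchor size a and scale_xy = 2.0:
#
#   xy = sigmoid(pred_xy) * 2.0 - 0.5 + g     # box center in grid cells
#   wh = exp(pred_wh) * a                     # box size in grid cells
#   pred_box = concat([xy, wh]) / scaler      # normalized by the layer size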


@tf.custom_gradient
def _darknet_boxes(encoded_boxes, width, height, anchor_grid, grid_points,
                   max_delta, scale_xy):
  """Wrapper for _scale_boxes to implement a custom gradient."""
  (scaler, scaled_box, pred_box) = _scale_boxes(encoded_boxes, width, height,
                                                anchor_grid, grid_points,
                                                scale_xy)

  def delta(unused_dy_scaler, dy_scaled, dy):
    dy_xy, dy_wh = tf.split(dy, 2, axis=-1)
    dy_xy_, dy_wh_ = tf.split(dy_scaled, 2, axis=-1)

    # add all the gradients that may have been applied to the
    # boxes and those that have been applied to the width and height
    dy_wh += dy_wh_
    dy_xy += dy_xy_

    # propagate the exponential applied to the width and height in
    # order to ensure the gradient propagated is of the correct
    # magnitude
    pred_wh = encoded_boxes[..., 2:4]
    dy_wh *= tf.math.exp(pred_wh)

    dbox = tf.concat([dy_xy, dy_wh], axis=-1)

    # apply the gradient clipping to xy and wh
    dbox = math_ops.rm_nan_inf(dbox)
    delta = tf.cast(max_delta, dbox.dtype)
    dbox = tf.clip_by_value(dbox, -delta, delta)
    return dbox, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0

  return (scaler, scaled_box, pred_box), delta


def _new_coord_scale_boxes(encoded_boxes, width, height, anchor_grid,
                           grid_points, scale_xy):
  """Decodes models boxes by squaring and scaling the width and height maps."""
  # split the boxes
  pred_xy = encoded_boxes[..., 0:2]
  pred_wh = encoded_boxes[..., 2:4]

  # build a scaling tensor to get the offset of the box relative to the image
  scaler = tf.convert_to_tensor([height, width, height, width])
  scale_xy = tf.cast(scale_xy, pred_xy.dtype)

  # apply the sigmoid
  pred_xy = tf.math.sigmoid(pred_xy)
  pred_wh = tf.math.sigmoid(pred_wh)

  # scale the xy offset predictions according to the config
  pred_xy = pred_xy * scale_xy - 0.5 * (scale_xy - 1)

  # find the true offset from the grid points and the scaler
  # where the grid points are the relative offset of each pixel within
  # the image
  box_xy = grid_points + pred_xy

  # decode the width and height of the boxes and correlate them
  # to the anchor boxes
  box_wh = (2 * pred_wh)**2 * anchor_grid

  # build the final boxes
  scaled_box = tf.concat([box_xy, box_wh], axis=-1)
  pred_box = scaled_box / scaler

  # shift scaled boxes
  scaled_box = tf.concat([pred_xy, box_wh], axis=-1)
  return (scaler, scaled_box, pred_box)


@tf.custom_gradient
def _darknet_new_coord_boxes(encoded_boxes, width, height, anchor_grid,
                             grid_points, max_delta, scale_xy):
  """Wrapper for _new_coord_scale_boxes to implement a custom gradient."""
  (scaler, scaled_box,
   pred_box) = _new_coord_scale_boxes(encoded_boxes, width, height, anchor_grid,
                                      grid_points, scale_xy)

  def delta(unused_dy_scaler, dy_scaled, dy):
    dy_xy, dy_wh = tf.split(dy, 2, axis=-1)
    dy_xy_, dy_wh_ = tf.split(dy_scaled, 2, axis=-1)

    # add all the gradients that may have been applied to the
    # boxes and those that have been applied to the width and height
    dy_wh += dy_wh_
    dy_xy += dy_xy_

    dbox = tf.concat([dy_xy, dy_wh], axis=-1)

    # apply the gradient clipping to xy and wh
    dbox = math_ops.rm_nan_inf(dbox)
    delta = tf.cast(max_delta, dbox.dtype)
    dbox = tf.clip_by_value(dbox, -delta, delta)
    return dbox, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0

  return (scaler, scaled_box, pred_box), delta


def _anchor_free_scale_boxes(encoded_boxes,
                             width,
                             height,
                             stride,
                             grid_points,
                             darknet=False):
  """Decode models boxes using FPN stride under anchor free conditions."""
  del darknet
  # split the boxes
  pred_xy = encoded_boxes[..., 0:2]
  pred_wh = encoded_boxes[..., 2:4]

  # build a scaling tensor to get the offset of the box relative to the image
  scaler = tf.convert_to_tensor([height, width, height, width])

  # scale the offsets and add them to the grid points, a tensor that is
  # the relative location of each pixel
  box_xy = (grid_points + pred_xy)

  # scale the width and height of the predictions and correlate them
  # to anchor boxes
  box_wh = tf.math.exp(pred_wh)

  # build the final predicted box
  scaled_box = tf.concat([box_xy, box_wh], axis=-1)

  # properly scale the box gradients
  scaled_box = scaled_box * tf.cast(stride, scaled_box.dtype)
  pred_box = scaled_box / tf.cast(scaler * stride, scaled_box.dtype)
  return (scaler, scaled_box, pred_box)


def get_predicted_box(width,
                      height,
                      encoded_boxes,
                      anchor_grid,
                      grid_points,
                      scale_xy,
                      stride,
                      darknet=False,
                      box_type='original',
                      max_delta=np.inf):
  """Decodes the predicted boxes from the model format to a usable format.

  This function decodes the model outputs into the [x, y, w, h] format for
  use in the loss function as well as for use within the detection generator.

  Args:
    width: A `float` scalar indicating the width of the prediction layer.
    height: A `float` scalar indicating the height of the prediction layer.
    encoded_boxes: A `Tensor` of shape [..., height, width, 4] holding encoded
      boxes.
    anchor_grid: A `Tensor` of shape [..., 1, 1, 2] holding the anchor boxes
      organized for box decoding, box width and height.
    grid_points: A `Tensor` of shape [..., height, width, 2] holding the anchor
      boxes for decoding the box centers.
    scale_xy: A `float` scalar used to indicate the range for each center
      outside of its given [..., i, j, 4] index, where i and j are indexing
      pixels along the width and height of the predicted output map.
    stride: An `int` defining the amount of down stride relative to the input
      image.
    darknet: A `bool` used to select between custom gradient and default
      autograd.
    box_type: A `str` indicating the type of box encoding that is being used.
    max_delta: A `float` scalar used for gradient clipping in back propagation.

  Returns:
    scaler: A `Tensor` of shape [4] returned to allow the scaling of the ground
      truth boxes to be of the same magnitude as the decoded predicted boxes.
    scaled_box: A `Tensor` of shape [..., height, width, 4] with the predicted
      boxes.
    pred_box: A `Tensor` of shape [..., height, width, 4] with the predicted
      boxes divided by the scaler parameter used to put all boxes in the [0, 1]
      range.
  """
  if box_type == 'anchor_free':
    (scaler, scaled_box, pred_box) = _anchor_free_scale_boxes(
        encoded_boxes, width, height, stride, grid_points, darknet=darknet)
  elif darknet:
    # pylint:disable=unbalanced-tuple-unpacking
    # if we are using the darknet loss we should not propagate the
    # decoding of the box
    if box_type == 'scaled':
      (scaler, scaled_box,
       pred_box) = _darknet_new_coord_boxes(encoded_boxes, width, height,
                                            anchor_grid, grid_points, max_delta,
                                            scale_xy)
    else:
      (scaler, scaled_box,
       pred_box) = _darknet_boxes(encoded_boxes, width, height, anchor_grid,
                                  grid_points, max_delta, scale_xy)
  else:
    # if we are using the scaled loss we should propagate the decoding of
    # the boxes
    if box_type == 'scaled':
      (scaler, scaled_box,
       pred_box) = _new_coord_scale_boxes(encoded_boxes, width, height,
                                          anchor_grid, grid_points, scale_xy)
    else:
      (scaler, scaled_box, pred_box) = _scale_boxes(encoded_boxes, width,
                                                    height, anchor_grid,
                                                    grid_points, scale_xy)

  return (scaler, scaled_box, pred_box)
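

# A minimal end to end sketch (shapes and values are hypothetical) of decoding
# a 16x16 prediction map with 3 anchors at stride 32 using the default
# 'original' box encoding:
#
#   generator = GridGenerator(anchors=[[10, 13], [16, 30], [33, 23]],
#                             scale_anchors=32)
#   grid_points, anchor_grid = generator(16, 16, batch_size=2)
#   encoded_boxes = tf.zeros([2, 16, 16, 3, 4])
#   scaler, scaled_box, pred_box = get_predicted_box(
#       16.0, 16.0, encoded_boxes, anchor_grid, grid_points,
#       scale_xy=1.05, stride=32, darknet=False, box_type='original')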