# Copyright 2022 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Yolo loss utility functions.""" import numpy as np import tensorflow as tf from official.projects.yolo.ops import box_ops from official.projects.yolo.ops import math_ops @tf.custom_gradient def sigmoid_bce(y, x_prime, label_smoothing): """Applies the Sigmoid Cross Entropy Loss. Implements the same derivative as that found in the Darknet C library. The derivative of this method is not the same as the standard binary cross entropy with logits function. The BCE with logits function equation is as follows: x = 1 / (1 + exp(-x_prime)) bce = -ylog(x) - (1 - y)log(1 - x) The standard BCE with logits function derivative is as follows: dloss = -y/x + (1-y)/(1-x) dsigmoid = x * (1 - x) dx = dloss * dsigmoid This derivative can be reduced simply to: dx = (-y + x) This simplification is used by the darknet library in order to improve training stability. The gradient is almost the same as tf.keras.losses.binary_crossentropy but varies slightly and yields different performance. Args: y: `Tensor` holding ground truth data. x_prime: `Tensor` holding the predictions prior to application of the sigmoid operation. label_smoothing: float value between 0.0 and 1.0 indicating the amount of smoothing to apply to the data. Returns: bce: Tensor of the be applied loss values. delta: callable function indicating the custom gradient for this operation. """ eps = 1e-9 x = tf.math.sigmoid(x_prime) y = tf.stop_gradient(y * (1 - label_smoothing) + 0.5 * label_smoothing) bce = -y * tf.math.log(x + eps) - (1 - y) * tf.math.log(1 - x + eps) def delta(dpass): x = tf.math.sigmoid(x_prime) dx = (-y + x) * dpass dy = tf.zeros_like(y) return dy, dx, 0.0 return bce, delta def apply_mask(mask, x, value=0): """This function is used for gradient masking. The YOLO loss function makes extensive use of dynamically shaped tensors. To allow this use case on the TPU while preserving the gradient correctly for back propagation we use this masking function to use a tf.where operation to hard set masked location to have a gradient and a value of zero. Args: mask: A `Tensor` with the same shape as x used to select values of importance. x: A `Tensor` with the same shape as mask that will be getting masked. value: `float` constant additive value. Returns: x: A masked `Tensor` with the same shape as x. """ mask = tf.cast(mask, tf.bool) masked = tf.where(mask, x, tf.zeros_like(x) + value) return masked def build_grid(indexes, truths, preds, ind_mask, update=False, grid=None): """This function is used to broadcast elements into the output shape. This function is used to broadcasts a list of truths into the correct index in the output shape. This is used for the ground truth map construction in the scaled loss and the classification map in the darknet loss. Args: indexes: A `Tensor` for the indexes truths: A `Tensor` for the ground truth. preds: A `Tensor` for the predictions. ind_mask: A `Tensor` for the index masks. update: A `bool` for updating the grid. grid: A `Tensor` for the grid. Returns: grid: A `Tensor` representing the augmented grid. """ # this function is used to broadcast all the indexes to the correct # into the correct ground truth mask, used for iou detection map # in the scaled loss and the classification mask in the darknet loss num_flatten = tf.shape(preds)[-1] # is there a way to verify that we are not on the CPU? ind_mask = tf.cast(ind_mask, indexes.dtype) # find all the batch indexes using the cumulated sum of a ones tensor # cumsum(ones) - 1 yeild the zero indexed batches bhep = tf.reduce_max(tf.ones_like(indexes), axis=-1, keepdims=True) bhep = tf.math.cumsum(bhep, axis=0) - 1 # concatnate the batch sizes to the indexes indexes = tf.concat([bhep, indexes], axis=-1) indexes = apply_mask(tf.cast(ind_mask, indexes.dtype), indexes) indexes = (indexes + (ind_mask - 1)) # mask truths truths = apply_mask(tf.cast(ind_mask, truths.dtype), truths) truths = (truths + (tf.cast(ind_mask, truths.dtype) - 1)) # reshape the indexes into the correct shape for the loss, # just flatten all indexes but the last indexes = tf.reshape(indexes, [-1, 4]) # also flatten the ground truth value on all axis but the last truths = tf.reshape(truths, [-1, num_flatten]) # build a zero grid in the samve shape as the predicitons if grid is None: grid = tf.zeros_like(preds) # remove invalid values from the truths that may have # come up from computation, invalid = nan and inf truths = math_ops.rm_nan_inf(truths) # scatter update the zero grid if update: grid = tf.tensor_scatter_nd_update(grid, indexes, truths) else: grid = tf.tensor_scatter_nd_max(grid, indexes, truths) # stop gradient and return to avoid TPU errors and save compute # resources return grid class GridGenerator: """Grid generator that generates anchor grids for box decoding.""" def __init__(self, anchors, scale_anchors=None): """Initialize Grid Generator. Args: anchors: A `List[List[int]]` for the anchor boxes that are used in the model at all levels. scale_anchors: An `int` for how much to scale this level to get the original input shape. """ self.dtype = tf.keras.backend.floatx() self._scale_anchors = scale_anchors self._anchors = tf.convert_to_tensor(anchors) return def _build_grid_points(self, lheight, lwidth, anchors, dtype): """Generate a grid of fixed grid edges for box center decoding.""" with tf.name_scope('center_grid'): y = tf.range(0, lheight) x = tf.range(0, lwidth) x_left = tf.tile( tf.transpose(tf.expand_dims(x, axis=-1), perm=[1, 0]), [lheight, 1]) y_left = tf.tile(tf.expand_dims(y, axis=-1), [1, lwidth]) x_y = tf.stack([x_left, y_left], axis=-1) x_y = tf.cast(x_y, dtype=dtype) num = tf.shape(anchors)[0] x_y = tf.expand_dims( tf.tile(tf.expand_dims(x_y, axis=-2), [1, 1, num, 1]), axis=0) return x_y def _build_anchor_grid(self, height, width, anchors, dtype): """Get the transformed anchor boxes for each dimention.""" with tf.name_scope('anchor_grid'): num = tf.shape(anchors)[0] anchors = tf.cast(anchors, dtype=dtype) anchors = tf.reshape(anchors, [1, 1, 1, num, 2]) anchors = tf.tile(anchors, [1, tf.cast(height, tf.int32), tf.cast(width, tf.int32), 1, 1]) return anchors def _extend_batch(self, grid, batch_size): return tf.tile(grid, [batch_size, 1, 1, 1, 1]) def __call__(self, height, width, batch_size, dtype=None): if dtype is None: self.dtype = tf.keras.backend.floatx() else: self.dtype = dtype grid_points = self._build_grid_points(height, width, self._anchors, self.dtype) anchor_grid = self._build_anchor_grid( height, width, tf.cast(self._anchors, self.dtype) / tf.cast(self._scale_anchors, self.dtype), self.dtype) grid_points = self._extend_batch(grid_points, batch_size) anchor_grid = self._extend_batch(anchor_grid, batch_size) return grid_points, anchor_grid TILE_SIZE = 50 class PairWiseSearch: """Apply a pairwise search between the ground truth and the labels. The goal is to indicate the locations where the predictions overlap with ground truth for dynamic ground truth associations. """ def __init__(self, iou_type='iou', any_match=True, min_conf=0.0, track_boxes=False, track_classes=False): """Initialization of Pair Wise Search. Args: iou_type: An `str` for the iou type to use. any_match: A `bool` for any match(no class match). min_conf: An `int` for minimum confidence threshold. track_boxes: A `bool` dynamic box assignment. track_classes: A `bool` dynamic class assignment. """ self.iou_type = iou_type self._any = any_match self._min_conf = min_conf self._track_boxes = track_boxes self._track_classes = track_classes return def box_iou(self, true_box, pred_box): # based on the type of loss, compute the iou loss for a box # compute_ indicated the type of iou to use if self.iou_type == 'giou': _, iou = box_ops.compute_giou(true_box, pred_box) elif self.iou_type == 'ciou': _, iou = box_ops.compute_ciou(true_box, pred_box) else: iou = box_ops.compute_iou(true_box, pred_box) return iou def _search_body(self, pred_box, pred_class, boxes, classes, running_boxes, running_classes, max_iou, idx): """Main search fn.""" # capture the batch size to be used, and gather a slice of # boxes from the ground truth. currently TILE_SIZE = 50, to # save memory batch_size = tf.shape(boxes)[0] box_slice = tf.slice(boxes, [0, idx * TILE_SIZE, 0], [batch_size, TILE_SIZE, 4]) # match the dimentions of the slice to the model predictions # shape: [batch_size, 1, 1, num, TILE_SIZE, 4] box_slice = tf.expand_dims(box_slice, axis=1) box_slice = tf.expand_dims(box_slice, axis=1) box_slice = tf.expand_dims(box_slice, axis=1) box_grid = tf.expand_dims(pred_box, axis=-2) # capture the classes class_slice = tf.slice(classes, [0, idx * TILE_SIZE], [batch_size, TILE_SIZE]) class_slice = tf.expand_dims(class_slice, axis=1) class_slice = tf.expand_dims(class_slice, axis=1) class_slice = tf.expand_dims(class_slice, axis=1) iou = self.box_iou(box_slice, box_grid) if self._min_conf > 0.0: if not self._any: class_grid = tf.expand_dims(pred_class, axis=-2) class_mask = tf.one_hot( tf.cast(class_slice, tf.int32), depth=tf.shape(pred_class)[-1], dtype=pred_class.dtype) class_mask = tf.reduce_any(tf.equal(class_mask, class_grid), axis=-1) else: class_mask = tf.reduce_max(pred_class, axis=-1, keepdims=True) class_mask = tf.cast(class_mask, iou.dtype) iou *= class_mask max_iou_ = tf.concat([max_iou, iou], axis=-1) max_iou = tf.reduce_max(max_iou_, axis=-1, keepdims=True) ind = tf.expand_dims(tf.argmax(max_iou_, axis=-1), axis=-1) if self._track_boxes: running_boxes = tf.expand_dims(running_boxes, axis=-2) box_slice = tf.zeros_like(running_boxes) + box_slice box_slice = tf.concat([running_boxes, box_slice], axis=-2) running_boxes = tf.gather_nd(box_slice, ind, batch_dims=4) if self._track_classes: running_classes = tf.expand_dims(running_classes, axis=-1) class_slice = tf.zeros_like(running_classes) + class_slice class_slice = tf.concat([running_classes, class_slice], axis=-1) running_classes = tf.gather_nd(class_slice, ind, batch_dims=4) return (pred_box, pred_class, boxes, classes, running_boxes, running_classes, max_iou, idx + 1) def __call__(self, pred_boxes, pred_classes, boxes, classes, clip_thresh=0.0): num_boxes = tf.shape(boxes)[-2] num_tiles = (num_boxes // TILE_SIZE) - 1 if self._min_conf > 0.0: pred_classes = tf.cast(pred_classes > self._min_conf, pred_classes.dtype) def _loop_cond(unused_pred_box, unused_pred_class, boxes, unused_classes, unused_running_boxes, unused_running_classes, unused_max_iou, idx): # check that the slice has boxes that all zeros batch_size = tf.shape(boxes)[0] box_slice = tf.slice(boxes, [0, idx * TILE_SIZE, 0], [batch_size, TILE_SIZE, 4]) return tf.logical_and(idx < num_tiles, tf.math.greater(tf.reduce_sum(box_slice), 0)) running_boxes = tf.zeros_like(pred_boxes) running_classes = tf.zeros_like(tf.reduce_sum(running_boxes, axis=-1)) max_iou = tf.zeros_like(tf.reduce_sum(running_boxes, axis=-1)) max_iou = tf.expand_dims(max_iou, axis=-1) (pred_boxes, pred_classes, boxes, classes, running_boxes, running_classes, max_iou, _) = tf.while_loop(_loop_cond, self._search_body, [ pred_boxes, pred_classes, boxes, classes, running_boxes, running_classes, max_iou, tf.constant(0) ]) mask = tf.cast(max_iou > clip_thresh, running_boxes.dtype) running_boxes *= mask running_classes *= tf.squeeze(mask, axis=-1) max_iou *= mask max_iou = tf.squeeze(max_iou, axis=-1) mask = tf.squeeze(mask, axis=-1) return (tf.stop_gradient(running_boxes), tf.stop_gradient(running_classes), tf.stop_gradient(max_iou), tf.stop_gradient(mask)) def average_iou(iou): """Computes the average intersection over union without counting locations. where the iou is zero. Args: iou: A `Tensor` representing the iou values. Returns: tf.stop_gradient(avg_iou): A `Tensor` representing average intersection over union. """ iou_sum = tf.reduce_sum(iou, axis=tf.range(1, tf.shape(tf.shape(iou))[0])) counts = tf.cast( tf.math.count_nonzero(iou, axis=tf.range(1, tf.shape(tf.shape(iou))[0])), iou.dtype) avg_iou = tf.reduce_mean(math_ops.divide_no_nan(iou_sum, counts)) return tf.stop_gradient(avg_iou) def _scale_boxes(encoded_boxes, width, height, anchor_grid, grid_points, scale_xy): """Decodes models boxes applying and exponential to width and height maps.""" # split the boxes pred_xy = encoded_boxes[..., 0:2] pred_wh = encoded_boxes[..., 2:4] # build a scaling tensor to get the offset of th ebox relative to the image scaler = tf.convert_to_tensor([height, width, height, width]) scale_xy = tf.cast(scale_xy, encoded_boxes.dtype) # apply the sigmoid pred_xy = tf.math.sigmoid(pred_xy) # scale the centers and find the offset of each box relative to # their center pixel pred_xy = pred_xy * scale_xy - 0.5 * (scale_xy - 1) # scale the offsets and add them to the grid points or a tensor that is # the realtive location of each pixel box_xy = grid_points + pred_xy # scale the width and height of the predictions and corlate them # to anchor boxes box_wh = tf.math.exp(pred_wh) * anchor_grid # build the final predicted box scaled_box = tf.concat([box_xy, box_wh], axis=-1) pred_box = scaled_box / scaler # shift scaled boxes scaled_box = tf.concat([pred_xy, box_wh], axis=-1) return (scaler, scaled_box, pred_box) @tf.custom_gradient def _darknet_boxes(encoded_boxes, width, height, anchor_grid, grid_points, max_delta, scale_xy): """Wrapper for _scale_boxes to implement a custom gradient.""" (scaler, scaled_box, pred_box) = _scale_boxes(encoded_boxes, width, height, anchor_grid, grid_points, scale_xy) def delta(unused_dy_scaler, dy_scaled, dy): dy_xy, dy_wh = tf.split(dy, 2, axis=-1) dy_xy_, dy_wh_ = tf.split(dy_scaled, 2, axis=-1) # add all the gradients that may have been applied to the # boxes and those that have been applied to the width and height dy_wh += dy_wh_ dy_xy += dy_xy_ # propagate the exponential applied to the width and height in # order to ensure the gradient propagated is of the correct # magnitude pred_wh = encoded_boxes[..., 2:4] dy_wh *= tf.math.exp(pred_wh) dbox = tf.concat([dy_xy, dy_wh], axis=-1) # apply the gradient clipping to xy and wh dbox = math_ops.rm_nan_inf(dbox) delta = tf.cast(max_delta, dbox.dtype) dbox = tf.clip_by_value(dbox, -delta, delta) return dbox, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 return (scaler, scaled_box, pred_box), delta def _new_coord_scale_boxes(encoded_boxes, width, height, anchor_grid, grid_points, scale_xy): """Decodes models boxes by squaring and scaling the width and height maps.""" # split the boxes pred_xy = encoded_boxes[..., 0:2] pred_wh = encoded_boxes[..., 2:4] # build a scaling tensor to get the offset of th ebox relative to the image scaler = tf.convert_to_tensor([height, width, height, width]) scale_xy = tf.cast(scale_xy, pred_xy.dtype) # apply the sigmoid pred_xy = tf.math.sigmoid(pred_xy) pred_wh = tf.math.sigmoid(pred_wh) # scale the xy offset predictions according to the config pred_xy = pred_xy * scale_xy - 0.5 * (scale_xy - 1) # find the true offset from the grid points and the scaler # where the grid points are the relative offset of each pixel with # in the image box_xy = grid_points + pred_xy # decode the widht and height of the boxes and correlate them # to the anchor boxes box_wh = (2 * pred_wh)**2 * anchor_grid # build the final boxes scaled_box = tf.concat([box_xy, box_wh], axis=-1) pred_box = scaled_box / scaler # shift scaled boxes scaled_box = tf.concat([pred_xy, box_wh], axis=-1) return (scaler, scaled_box, pred_box) @tf.custom_gradient def _darknet_new_coord_boxes(encoded_boxes, width, height, anchor_grid, grid_points, max_delta, scale_xy): """Wrapper for _new_coord_scale_boxes to implement a custom gradient.""" (scaler, scaled_box, pred_box) = _new_coord_scale_boxes(encoded_boxes, width, height, anchor_grid, grid_points, scale_xy) def delta(unused_dy_scaler, dy_scaled, dy): dy_xy, dy_wh = tf.split(dy, 2, axis=-1) dy_xy_, dy_wh_ = tf.split(dy_scaled, 2, axis=-1) # add all the gradients that may have been applied to the # boxes and those that have been applied to the width and height dy_wh += dy_wh_ dy_xy += dy_xy_ dbox = tf.concat([dy_xy, dy_wh], axis=-1) # apply the gradient clipping to xy and wh dbox = math_ops.rm_nan_inf(dbox) delta = tf.cast(max_delta, dbox.dtype) dbox = tf.clip_by_value(dbox, -delta, delta) return dbox, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 return (scaler, scaled_box, pred_box), delta def _anchor_free_scale_boxes(encoded_boxes, width, height, stride, grid_points, darknet=False): """Decode models boxes using FPN stride under anchor free conditions.""" del darknet # split the boxes pred_xy = encoded_boxes[..., 0:2] pred_wh = encoded_boxes[..., 2:4] # build a scaling tensor to get the offset of th ebox relative to the image scaler = tf.convert_to_tensor([height, width, height, width]) # scale the offsets and add them to the grid points or a tensor that is # the realtive location of each pixel box_xy = (grid_points + pred_xy) # scale the width and height of the predictions and corlate them # to anchor boxes box_wh = tf.math.exp(pred_wh) # build the final predicted box scaled_box = tf.concat([box_xy, box_wh], axis=-1) # properly scaling boxes gradeints scaled_box = scaled_box * tf.cast(stride, scaled_box.dtype) pred_box = scaled_box / tf.cast(scaler * stride, scaled_box.dtype) return (scaler, scaled_box, pred_box) def get_predicted_box(width, height, encoded_boxes, anchor_grid, grid_points, scale_xy, stride, darknet=False, box_type='original', max_delta=np.inf): """Decodes the predicted boxes from the model format to a usable format. This function decodes the model outputs into the [x, y, w, h] format for use in the loss function as well as for use within the detection generator. Args: width: A `float` scalar indicating the width of the prediction layer. height: A `float` scalar indicating the height of the prediction layer encoded_boxes: A `Tensor` of shape [..., height, width, 4] holding encoded boxes. anchor_grid: A `Tensor` of shape [..., 1, 1, 2] holding the anchor boxes organized for box decoding, box width and height. grid_points: A `Tensor` of shape [..., height, width, 2] holding the anchor boxes for decoding the box centers. scale_xy: A `float` scaler used to indicate the range for each center outside of its given [..., i, j, 4] index, where i and j are indexing pixels along the width and height of the predicted output map. stride: An `int` defining the amount of down stride realtive to the input image. darknet: A `bool` used to select between custom gradient and default autograd. box_type: An `str` indicating the type of box encoding that is being used. max_delta: A `float` scaler used for gradient clipping in back propagation. Returns: scaler: A `Tensor` of shape [4] returned to allow the scaling of the ground truth boxes to be of the same magnitude as the decoded predicted boxes. scaled_box: A `Tensor` of shape [..., height, width, 4] with the predicted boxes. pred_box: A `Tensor` of shape [..., height, width, 4] with the predicted boxes divided by the scaler parameter used to put all boxes in the [0, 1] range. """ if box_type == 'anchor_free': (scaler, scaled_box, pred_box) = _anchor_free_scale_boxes( encoded_boxes, width, height, stride, grid_points, darknet=darknet) elif darknet: # pylint:disable=unbalanced-tuple-unpacking # if we are using the darknet loss we shoud nto propagate the # decoding of the box if box_type == 'scaled': (scaler, scaled_box, pred_box) = _darknet_new_coord_boxes(encoded_boxes, width, height, anchor_grid, grid_points, max_delta, scale_xy) else: (scaler, scaled_box, pred_box) = _darknet_boxes(encoded_boxes, width, height, anchor_grid, grid_points, max_delta, scale_xy) else: # if we are using the scaled loss we should propagate the decoding of # the boxes if box_type == 'scaled': (scaler, scaled_box, pred_box) = _new_coord_scale_boxes(encoded_boxes, width, height, anchor_grid, grid_points, scale_xy) else: (scaler, scaled_box, pred_box) = _scale_boxes(encoded_boxes, width, height, anchor_grid, grid_points, scale_xy) return (scaler, scaled_box, pred_box)