# Copyright 2022 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Yolo Anchor labler.""" import numpy as np import tensorflow as tf from official.projects.yolo.ops import box_ops from official.projects.yolo.ops import loss_utils from official.projects.yolo.ops import preprocessing_ops INF = 10000000 def get_best_anchor(y_true, anchors, stride, width=1, height=1, iou_thresh=0.25, best_match_only=False, use_tie_breaker=True): """Get the correct anchor that is assoiciated with each box using IOU. Args: y_true: tf.Tensor[] for the list of bounding boxes in the yolo format. anchors: list or tensor for the anchor boxes to be used in prediction found via Kmeans. stride: `int` stride for the anchors. width: int for the image width. height: int for the image height. iou_thresh: `float` the minimum iou threshold to use for selecting boxes for each level. best_match_only: `bool` if the box only has one match and it is less than the iou threshold, when set to True, this match will be dropped as no anchors can be linked to it. use_tie_breaker: `bool` if there is many anchors for a given box, then attempt to use all of them, if False, only the first matching box will be used. Returns: tf.Tensor: y_true with the anchor associated with each ground truth box known """ with tf.name_scope('get_best_anchor'): width = tf.cast(width, dtype=tf.float32) height = tf.cast(height, dtype=tf.float32) scaler = tf.convert_to_tensor([width, height]) # scale to levels houts width and height true_wh = tf.cast(y_true[..., 2:4], dtype=tf.float32) * scaler # scale down from large anchor to small anchor type anchors = tf.cast(anchors, dtype=tf.float32) / stride k = tf.shape(anchors)[0] anchors = tf.concat([tf.zeros_like(anchors), anchors], axis=-1) truth_comp = tf.concat([tf.zeros_like(true_wh), true_wh], axis=-1) if iou_thresh >= 1.0: anchors = tf.expand_dims(anchors, axis=-2) truth_comp = tf.expand_dims(truth_comp, axis=-3) aspect = truth_comp[..., 2:4] / anchors[..., 2:4] aspect = tf.where(tf.math.is_nan(aspect), tf.zeros_like(aspect), aspect) aspect = tf.maximum(aspect, 1 / aspect) aspect = tf.where(tf.math.is_nan(aspect), tf.zeros_like(aspect), aspect) aspect = tf.reduce_max(aspect, axis=-1) values, indexes = tf.math.top_k( tf.transpose(-aspect, perm=[1, 0]), k=tf.cast(k, dtype=tf.int32), sorted=True) values = -values ind_mask = tf.cast(values < iou_thresh, dtype=indexes.dtype) else: truth_comp = box_ops.xcycwh_to_yxyx(truth_comp) anchors = box_ops.xcycwh_to_yxyx(anchors) iou_raw = box_ops.aggregated_comparitive_iou( truth_comp, anchors, iou_type=3, ) values, indexes = tf.math.top_k( iou_raw, k=tf.cast(k, dtype=tf.int32), sorted=True) ind_mask = tf.cast(values >= iou_thresh, dtype=indexes.dtype) # pad the indexs such that all values less than the thresh are -1 # add one, multiply the mask to zeros all the bad locations # subtract 1 makeing all the bad locations 0. if best_match_only: iou_index = ((indexes[..., 0:] + 1) * ind_mask[..., 0:]) - 1 elif use_tie_breaker: iou_index = tf.concat([ tf.expand_dims(indexes[..., 0], axis=-1), ((indexes[..., 1:] + 1) * ind_mask[..., 1:]) - 1 ], axis=-1) else: iou_index = tf.concat([ tf.expand_dims(indexes[..., 0], axis=-1), tf.zeros_like(indexes[..., 1:]) - 1 ], axis=-1) return tf.cast(iou_index, dtype=tf.float32), tf.cast(values, dtype=tf.float32) class YoloAnchorLabeler: """Anchor labeler for the Yolo Models.""" def __init__(self, anchors=None, anchor_free_level_limits=None, level_strides=None, center_radius=None, max_num_instances=200, match_threshold=0.25, best_matches_only=False, use_tie_breaker=True, darknet=False, dtype='float32'): """Initialization for anchor labler. Args: anchors: `Dict[List[Union[int, float]]]` values for each anchor box. anchor_free_level_limits: `List` the box sizes that will be allowed at each FPN level as is done in the FCOS and YOLOX paper for anchor free box assignment. level_strides: `Dict[int]` for how much the model scales down the images at the each level. center_radius: `Dict[float]` for radius around each box center to search for extra centers in each level. max_num_instances: `int` for the number of boxes to compute loss on. match_threshold: `float` indicating the threshold over which an anchor will be considered for prediction, at zero, all the anchors will be used and at 1.0 only the best will be used. for anchor thresholds larger than 1.0 we stop using the IOU for anchor comparison and resort directly to comparing the width and height, this is used for the scaled models. best_matches_only: `boolean` indicating how boxes are selected for optimization. use_tie_breaker: `boolean` indicating whether to use the anchor threshold value. darknet: `boolean` indicating which data pipeline to use. Setting to True swaps the pipeline to output images realtive to Yolov4 and older. dtype: `str` indicating the output datatype of the datapipeline selecting from {"float32", "float16", "bfloat16"}. """ self.anchors = anchors self.masks = self._get_mask() self.anchor_free_level_limits = self._get_level_limits( anchor_free_level_limits) if darknet and self.anchor_free_level_limits is None: center_radius = None self.keys = self.anchors.keys() if self.anchor_free_level_limits is not None: maxim = 2000 match_threshold = -0.01 self.num_instances = {key: maxim for key in self.keys} elif not darknet: self.num_instances = { key: (6 - i) * max_num_instances for i, key in enumerate(self.keys) } else: self.num_instances = {key: max_num_instances for key in self.keys} self.center_radius = center_radius self.level_strides = level_strides self.match_threshold = match_threshold self.best_matches_only = best_matches_only self.use_tie_breaker = use_tie_breaker self.dtype = dtype def _get_mask(self): """For each level get indexs of each anchor for box search across levels.""" masks = {} start = 0 minimum = int(min(self.anchors.keys())) maximum = int(max(self.anchors.keys())) for i in range(minimum, maximum + 1): per_scale = len(self.anchors[str(i)]) masks[str(i)] = list(range(start, per_scale + start)) start += per_scale return masks def _get_level_limits(self, level_limits): """For each level receptive feild range for anchor free box placement.""" if level_limits is not None: level_limits_dict = {} level_limits = [0.0] + level_limits + [np.inf] for i, key in enumerate(self.anchors.keys()): level_limits_dict[key] = level_limits[i:i + 2] else: level_limits_dict = None return level_limits_dict def _tie_breaking_search(self, anchors, mask, boxes, classes): """After search, link each anchor ind to the correct map in ground truth.""" mask = tf.cast(tf.reshape(mask, [1, 1, 1, -1]), anchors.dtype) anchors = tf.expand_dims(anchors, axis=-1) viable = tf.where(tf.squeeze(anchors == mask, axis=0)) gather_id, _, anchor_id = tf.split(viable, 3, axis=-1) boxes = tf.gather_nd(boxes, gather_id) classes = tf.gather_nd(classes, gather_id) classes = tf.expand_dims(classes, axis=-1) classes = tf.cast(classes, boxes.dtype) anchor_id = tf.cast(anchor_id, boxes.dtype) return boxes, classes, anchor_id def _get_anchor_id(self, key, boxes, classes, width, height, stride, iou_index=None): """Find the object anchor assignments in an anchor based paradigm.""" # find the best anchor anchors = self.anchors[key] num_anchors = len(anchors) if self.best_matches_only: # get the best anchor for each box iou_index, _ = get_best_anchor( boxes, anchors, stride, width=width, height=height, best_match_only=True, iou_thresh=self.match_threshold) mask = range(num_anchors) else: # search is done across FPN levels, get the mask of anchor indexes # corralated to this level. mask = self.masks[key] # search for the correct box to use (boxes, classes, anchors) = self._tie_breaking_search(iou_index, mask, boxes, classes) return boxes, classes, anchors, num_anchors def _get_centers(self, boxes, classes, anchors, width, height, scale_xy): """Find the object center assignments in an anchor based paradigm.""" offset = tf.cast(0.5 * (scale_xy - 1), boxes.dtype) grid_xy, _ = tf.split(boxes, 2, axis=-1) wh_scale = tf.cast(tf.convert_to_tensor([width, height]), boxes.dtype) grid_xy = grid_xy * wh_scale centers = tf.math.floor(grid_xy) if offset != 0.0: clamp = lambda x, ma: tf.maximum( # pylint:disable=g-long-lambda tf.minimum(x, tf.cast(ma, x.dtype)), tf.zeros_like(x)) grid_xy_index = grid_xy - centers positive_shift = ((grid_xy_index < offset) & (grid_xy > 1.)) negative_shift = ((grid_xy_index > (1 - offset)) & (grid_xy < (wh_scale - 1.))) zero, _ = tf.split(tf.ones_like(positive_shift), 2, axis=-1) shift_mask = tf.concat([zero, positive_shift, negative_shift], axis=-1) offset = tf.cast([[0, 0], [1, 0], [0, 1], [-1, 0], [0, -1]], offset.dtype) * offset num_shifts = tf.shape(shift_mask) num_shifts = num_shifts[-1] boxes = tf.tile(tf.expand_dims(boxes, axis=-2), [1, num_shifts, 1]) classes = tf.tile(tf.expand_dims(classes, axis=-2), [1, num_shifts, 1]) anchors = tf.tile(tf.expand_dims(anchors, axis=-2), [1, num_shifts, 1]) shift_mask = tf.cast(shift_mask, boxes.dtype) shift_ind = shift_mask * tf.range(0, num_shifts, dtype=boxes.dtype) shift_ind = shift_ind - (1 - shift_mask) shift_ind = tf.expand_dims(shift_ind, axis=-1) boxes_and_centers = tf.concat([boxes, classes, anchors, shift_ind], axis=-1) boxes_and_centers = tf.reshape(boxes_and_centers, [-1, 7]) _, center_ids = tf.split(boxes_and_centers, [6, 1], axis=-1) select = tf.where(center_ids >= 0) select, _ = tf.split(select, 2, axis=-1) boxes_and_centers = tf.gather_nd(boxes_and_centers, select) center_ids = tf.gather_nd(center_ids, select) center_ids = tf.cast(center_ids, tf.int32) shifts = tf.gather_nd(offset, center_ids) boxes, classes, anchors, _ = tf.split( boxes_and_centers, [4, 1, 1, 1], axis=-1) grid_xy, _ = tf.split(boxes, 2, axis=-1) centers = tf.math.floor(grid_xy * wh_scale - shifts) centers = clamp(centers, wh_scale - 1) x, y = tf.split(centers, 2, axis=-1) centers = tf.cast(tf.concat([y, x, anchors], axis=-1), tf.int32) return boxes, classes, centers def _get_anchor_free(self, key, boxes, classes, height, width, stride, center_radius): """Find the box assignements in an anchor free paradigm.""" level_limits = self.anchor_free_level_limits[key] gen = loss_utils.GridGenerator(anchors=[[1, 1]], scale_anchors=stride) grid_points = gen(width, height, 1, boxes.dtype)[0] grid_points = tf.squeeze(grid_points, axis=0) box_list = boxes class_list = classes grid_points = (grid_points + 0.5) * stride x_centers, y_centers = grid_points[..., 0], grid_points[..., 1] boxes *= (tf.convert_to_tensor([width, height, width, height]) * stride) tlbr_boxes = box_ops.xcycwh_to_yxyx(boxes) boxes = tf.reshape(boxes, [1, 1, -1, 4]) tlbr_boxes = tf.reshape(tlbr_boxes, [1, 1, -1, 4]) if self.use_tie_breaker: area = tf.reduce_prod(boxes[..., 2:], axis=-1) # check if the box is in the receptive feild of the this fpn level b_t = y_centers - tlbr_boxes[..., 0] b_l = x_centers - tlbr_boxes[..., 1] b_b = tlbr_boxes[..., 2] - y_centers b_r = tlbr_boxes[..., 3] - x_centers box_delta = tf.stack([b_t, b_l, b_b, b_r], axis=-1) if level_limits is not None: max_reg_targets_per_im = tf.reduce_max(box_delta, axis=-1) gt_min = max_reg_targets_per_im >= level_limits[0] gt_max = max_reg_targets_per_im <= level_limits[1] is_in_boxes = tf.logical_and(gt_min, gt_max) else: is_in_boxes = tf.reduce_min(box_delta, axis=-1) > 0.0 is_in_boxes_all = tf.reduce_any(is_in_boxes, axis=(0, 1), keepdims=True) # check if the center is in the receptive feild of the this fpn level c_t = y_centers - (boxes[..., 1] - center_radius * stride) c_l = x_centers - (boxes[..., 0] - center_radius * stride) c_b = (boxes[..., 1] + center_radius * stride) - y_centers c_r = (boxes[..., 0] + center_radius * stride) - x_centers centers_delta = tf.stack([c_t, c_l, c_b, c_r], axis=-1) is_in_centers = tf.reduce_min(centers_delta, axis=-1) > 0.0 is_in_centers_all = tf.reduce_any(is_in_centers, axis=(0, 1), keepdims=True) # colate all masks to get the final locations is_in_index = tf.logical_or(is_in_boxes_all, is_in_centers_all) is_in_boxes_and_center = tf.logical_and(is_in_boxes, is_in_centers) is_in_boxes_and_center = tf.logical_and(is_in_index, is_in_boxes_and_center) if self.use_tie_breaker: boxes_all = tf.cast(is_in_boxes_and_center, area.dtype) boxes_all = ((boxes_all * area) + ((1 - boxes_all) * INF)) boxes_min = tf.reduce_min(boxes_all, axis=-1, keepdims=True) boxes_min = tf.where(boxes_min == INF, -1.0, boxes_min) is_in_boxes_and_center = boxes_all == boxes_min # construct the index update grid reps = tf.reduce_sum(tf.cast(is_in_boxes_and_center, tf.int16), axis=-1) indexes = tf.cast(tf.where(is_in_boxes_and_center), tf.int32) y, x, t = tf.split(indexes, 3, axis=-1) boxes = tf.gather_nd(box_list, t) classes = tf.cast(tf.gather_nd(class_list, t), boxes.dtype) reps = tf.gather_nd(reps, tf.concat([y, x], axis=-1)) reps = tf.cast(tf.expand_dims(reps, axis=-1), boxes.dtype) classes = tf.cast(tf.expand_dims(classes, axis=-1), boxes.dtype) conf = tf.ones_like(classes) # return the samples and the indexes samples = tf.concat([boxes, conf, classes], axis=-1) indexes = tf.concat([y, x, tf.zeros_like(t)], axis=-1) return indexes, samples def build_label_per_path(self, key, boxes, classes, width, height, iou_index=None): """Builds the labels for one path.""" stride = self.level_strides[key] scale_xy = self.center_radius[key] if self.center_radius is not None else 1 width = tf.cast(width // stride, boxes.dtype) height = tf.cast(height // stride, boxes.dtype) if self.anchor_free_level_limits is None: (boxes, classes, anchors, num_anchors) = self._get_anchor_id( key, boxes, classes, width, height, stride, iou_index=iou_index) boxes, classes, centers = self._get_centers(boxes, classes, anchors, width, height, scale_xy) ind_mask = tf.ones_like(classes) updates = tf.concat([boxes, ind_mask, classes], axis=-1) else: num_anchors = 1 (centers, updates) = self._get_anchor_free(key, boxes, classes, height, width, stride, scale_xy) boxes, ind_mask, classes = tf.split(updates, [4, 1, 1], axis=-1) width = tf.cast(width, tf.int32) height = tf.cast(height, tf.int32) full = tf.zeros([height, width, num_anchors, 1], dtype=classes.dtype) full = tf.tensor_scatter_nd_add(full, centers, ind_mask) num_instances = int(self.num_instances[key]) centers = preprocessing_ops.pad_max_instances( centers, num_instances, pad_value=0, pad_axis=0) updates = preprocessing_ops.pad_max_instances( updates, num_instances, pad_value=0, pad_axis=0) updates = tf.cast(updates, self.dtype) full = tf.cast(full, self.dtype) return centers, updates, full def __call__(self, boxes, classes, width, height): """Builds the labels for a single image, not functional in batch mode. Args: boxes: `Tensor` of shape [None, 4] indicating the object locations in an image. classes: `Tensor` of shape [None] indicating the each objects classes. width: `int` for the images width. height: `int` for the images height. Returns: centers: `Tensor` of shape [None, 3] of indexes in the final grid where boxes are located. updates: `Tensor` of shape [None, 8] the value to place in the final grid. full: `Tensor` of [width/stride, height/stride, num_anchors, 1] holding a mask of where boxes are locates for confidence losses. """ indexes = {} updates = {} true_grids = {} iou_index = None boxes = box_ops.yxyx_to_xcycwh(boxes) if not self.best_matches_only and self.anchor_free_level_limits is None: # stitch and search boxes across fpn levels anchorsvec = [] for stitch in self.anchors: anchorsvec.extend(self.anchors[stitch]) stride = tf.cast([width, height], boxes.dtype) # get the best anchor for each box iou_index, _ = get_best_anchor( boxes, anchorsvec, stride, width=1.0, height=1.0, best_match_only=False, use_tie_breaker=self.use_tie_breaker, iou_thresh=self.match_threshold) for key in self.keys: indexes[key], updates[key], true_grids[key] = self.build_label_per_path( key, boxes, classes, width, height, iou_index=iou_index) return indexes, updates, true_grids