import torch

from .geometry import bbox_overlaps


class BBoxAssigner(object):
    """Assign a corresponding gt bbox or background to each bbox.

    Each proposal is assigned `-1`, `0`, or a positive integer indicating
    the ground truth index.

    - -1: don't care
    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        pos_iou_thr (float): IoU threshold for positive bboxes.
        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
            A float `t` marks bboxes with `0 <= max_iou < t` as negative;
            a 2-tuple `(lo, hi)` marks bboxes with `lo <= max_iou < hi`.
            Any other type results in no bbox being marked negative.
        min_pos_iou (float): Minimum iou for a bbox to be considered as a
            positive bbox. For RPN, it is usually set as 0.3, for Fast
            R-CNN, it is usually set as pos_iou_thr.
        ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
            `gt_bboxes_ignore` is specified). Values <= 0 mean not ignoring
            any bboxes.
    """

    def __init__(self,
                 pos_iou_thr,
                 neg_iou_thr,
                 min_pos_iou=.0,
                 ignore_iof_thr=-1):
        self.pos_iou_thr = pos_iou_thr
        self.neg_iou_thr = neg_iou_thr
        self.min_pos_iou = min_pos_iou
        self.ignore_iof_thr = ignore_iof_thr

    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None,
               gt_labels=None):
        """Assign gt to bboxes.

        This method assigns a gt bbox to every bbox (proposal/anchor). Each
        bbox is assigned -1, 0, or a positive number. -1 means don't care,
        0 means negative sample, and a positive number is the index
        (1-based) of the assigned gt.

        The assignment is done in the following steps; the order matters.

        1. assign every bbox to -1
        2. assign proposals whose iou with all gts < neg_iou_thr to 0
        3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
           assign it to that gt
        4. for each gt bbox, assign its nearest proposals (may be more than
           one) to itself

        Bboxes whose maximum overlap with `gt_bboxes_ignore` exceeds
        `ignore_iof_thr` get all their overlaps set to -1 beforehand, which
        keeps them out of the threshold-based steps 2 and 3.

        Args:
            bboxes (Tensor): Bounding boxes to be assigned, shape (n, 4).
                Only the first four columns are used.
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that
                are labelled as `ignored`, e.g., crowd boxes in COCO.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.

        Raises:
            ValueError: If `bboxes` or `gt_bboxes` has no rows.
        """
        if bboxes.shape[0] == 0 or gt_bboxes.shape[0] == 0:
            raise ValueError('No gt or bboxes')
        bboxes = bboxes[:, :4]
        overlaps = bbox_overlaps(bboxes, gt_bboxes)

        if (self.ignore_iof_thr > 0) and (gt_bboxes_ignore is not None) and (
                gt_bboxes_ignore.numel() > 0):
            ignore_overlaps = bbox_overlaps(
                bboxes, gt_bboxes_ignore, mode='iof')
            ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
            # Index rows with the boolean mask itself. The previous
            # `nonzero(...).squeeze()` followed by `[:, 0]` raised
            # IndexError, because the squeezed index tensor is at most 1-D.
            ignore_mask = ignore_max_overlaps > self.ignore_iof_thr
            overlaps[ignore_mask, :] = -1

        return self.assign_wrt_overlaps(overlaps, gt_labels)

    def assign_wrt_overlaps(self, overlaps, gt_labels=None):
        """Assign w.r.t. the overlaps of bboxes with gts.

        Args:
            overlaps (Tensor): Overlaps between n bboxes and k gt_bboxes,
                shape (n, k).
            gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result. Its `labels` is None
            when `gt_labels` is None; otherwise it holds the label of the
            assigned gt for positive bboxes and 0 for all others.

        Raises:
            ValueError: If `overlaps` has no elements.
        """
        if overlaps.numel() == 0:
            raise ValueError('No gt or proposals')

        num_bboxes, num_gts = overlaps.size(0), overlaps.size(1)

        # 1. assign -1 by default
        assigned_gt_inds = overlaps.new_full(
            (num_bboxes, ), -1, dtype=torch.long)

        # Guards against overlaps with more than two dimensions.
        assert overlaps.size() == (num_bboxes, num_gts)

        # for each anchor, which gt best overlaps with it
        # for each anchor, the max iou of all gts
        max_overlaps, argmax_overlaps = overlaps.max(dim=1)
        # for each gt, the max iou of all proposals
        gt_max_overlaps, _ = overlaps.max(dim=0)

        # 2. assign negative: below the negative IoU threshold. The lower
        # bound of 0 keeps ignored bboxes (overlap -1) at "don't care".
        if isinstance(self.neg_iou_thr, float):
            assigned_gt_inds[(max_overlaps >= 0)
                             & (max_overlaps < self.neg_iou_thr)] = 0
        elif isinstance(self.neg_iou_thr, tuple):
            assert len(self.neg_iou_thr) == 2
            assigned_gt_inds[(max_overlaps >= self.neg_iou_thr[0])
                             & (max_overlaps < self.neg_iou_thr[1])] = 0

        # 3. assign positive: above positive IoU threshold
        pos_inds = max_overlaps >= self.pos_iou_thr
        assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1

        # 4. assign fg: for each gt, proposals with highest IoU. This runs
        # last and in gt order, so it overrides steps 2-3 and a later gt
        # wins when several gts share the same best proposal.
        for i in range(num_gts):
            if gt_max_overlaps[i] >= self.min_pos_iou:
                assigned_gt_inds[overlaps[:, i] == gt_max_overlaps[i]] = i + 1

        if gt_labels is not None:
            assigned_labels = assigned_gt_inds.new_zeros((num_bboxes, ))
            # A boolean mask handles zero, one or many positives uniformly.
            pos_mask = assigned_gt_inds > 0
            assigned_labels[pos_mask] = gt_labels[
                assigned_gt_inds[pos_mask] - 1]
        else:
            assigned_labels = None

        return AssignResult(
            num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)


class AssignResult(object):
    """Container for the outcome of a bbox-to-gt assignment.

    Args:
        num_gts (int): Number of gt bboxes considered in the assignment.
        gt_inds (Tensor): Per-bbox assigned gt index (1-based); 0 means
            negative and -1 means don't care.
        max_overlaps (Tensor): Per-bbox maximum overlap with any gt.
        labels (Tensor, optional): Per-bbox label of the assigned gt.
    """

    def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
        self.num_gts = num_gts
        self.gt_inds = gt_inds
        self.max_overlaps = max_overlaps
        self.labels = labels

    def add_gt_(self, gt_labels):
        """Prepend the gt bboxes themselves to the result, in place.

        One entry per element of `gt_labels` is put in front of the
        existing entries: the i-th (0-based) added entry gets gt index
        `i + 1`, overlap 1 and, if labels are tracked, `gt_labels[i]`.

        Args:
            gt_labels (Tensor): Labels of the gt bboxes, shape (k, ).
        """
        num_added = len(gt_labels)
        self_inds = torch.arange(
            1, num_added + 1, dtype=torch.long, device=gt_labels.device)
        self.gt_inds = torch.cat([self_inds, self.gt_inds])
        # Size the prepended overlaps by the number of added entries (not
        # `self.num_gts`) so `gt_inds` and `max_overlaps` stay aligned.
        self.max_overlaps = torch.cat(
            [self.max_overlaps.new_ones(num_added), self.max_overlaps])
        if self.labels is not None:
            self.labels = torch.cat([gt_labels, self.labels])