# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch.nn import functional as F
from torch import nn, Tensor
import torchvision
from torchvision.ops import boxes as box_ops

from . import _utils as det_utils
from .image_list import ImageList

from torch.jit.annotations import List, Optional, Dict, Tuple


@torch.jit.unused
def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
    # type: (Tensor, int) -> Tuple[int, int]
    from torch.onnx import operators
    num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
    pre_nms_top_n = torch.min(torch.cat(
        (torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype),
         num_anchors), 0))

    return num_anchors, pre_nms_top_n


class AnchorGenerator(nn.Module):
    """
    Module that generates anchors for a set of feature maps and
    image sizes.

    The module supports computing anchors at multiple sizes and aspect ratios
    per feature map. This module assumes aspect ratio = height / width for
    each anchor.

    sizes and aspect_ratios should have the same number of elements, and that
    number should correspond to the number of feature maps.

    sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
    and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
    per spatial location for feature map i.

    Arguments:
        sizes (Tuple[Tuple[int]]): anchor sizes to generate for each feature map
        aspect_ratios (Tuple[Tuple[float]]): anchor aspect ratios (height / width)
            to generate for each feature map
    """

    __annotations__ = {
        "cell_anchors": Optional[List[torch.Tensor]],
        "_cache": Dict[str, List[torch.Tensor]]
    }

    def __init__(
        self,
        sizes=(128, 256, 512),
        aspect_ratios=(0.5, 1.0, 2.0),
    ):
        super(AnchorGenerator, self).__init__()

        if not isinstance(sizes[0], (list, tuple)):
            # TODO change this
            sizes = tuple((s,) for s in sizes)
        if not isinstance(aspect_ratios[0], (list, tuple)):
            aspect_ratios = (aspect_ratios,) * len(sizes)

        assert len(sizes) == len(aspect_ratios)

        self.sizes = sizes
        self.aspect_ratios = aspect_ratios
        self.cell_anchors = None
        self._cache = {}

    # TODO: https://github.com/pytorch/pytorch/issues/26792
    # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
    # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
    # This method assumes aspect ratio = height / width for an anchor.
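    # As an illustration (arbitrary values, not defaults of this module):
    # scales=[32] with aspect_ratios=[0.5, 1.0, 2.0] yields the three
    # zero-centered, rounded boxes
    #   [-23., -11., 23., 11.], [-16., -16., 16., 16.], [-11., -23., 11., 23.]
    # i.e. one wide, one square and one tall anchor, each with area roughly 32 * 32.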
    def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
        # type: (List[int], List[float], int, Device) -> Tensor  # noqa: F821
        scales = torch.as_tensor(scales, dtype=dtype, device=device)
        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
        h_ratios = torch.sqrt(aspect_ratios)
        w_ratios = 1 / h_ratios

        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h_ratios[:, None] * scales[None, :]).view(-1)

        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
        return base_anchors.round()

    def set_cell_anchors(self, dtype, device):
        # type: (int, Device) -> None  # noqa: F821
        if self.cell_anchors is not None:
            cell_anchors = self.cell_anchors
            assert cell_anchors is not None
            # suppose that all anchors have the same device
            # which is a valid assumption in the current state of the codebase
            if cell_anchors[0].device == device:
                return

        cell_anchors = [
            self.generate_anchors(
                sizes,
                aspect_ratios,
                dtype,
                device
            )
            for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
        ]
        self.cell_anchors = cell_anchors

    def num_anchors_per_location(self):
        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

    # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
    # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
    def grid_anchors(self, grid_sizes, strides):
        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
        anchors = []
        cell_anchors = self.cell_anchors
        assert cell_anchors is not None

        for size, stride, base_anchors in zip(
            grid_sizes, strides, cell_anchors
        ):
            grid_height, grid_width = size
            stride_height, stride_width = stride
            device = base_anchors.device

            # For every output anchor, compute [x_center, y_center, x_center, y_center]
            shifts_x = torch.arange(
                0, grid_width, dtype=torch.float32, device=device
            ) * stride_width
            shifts_y = torch.arange(
                0, grid_height, dtype=torch.float32, device=device
            ) * stride_height
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            # For every (base anchor, output anchor) pair,
            # offset each zero-centered base anchor by the center of the output anchor.
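            # shifts has shape (grid_height * grid_width, 4) and base_anchors has
            # shape (A, 4); the views below broadcast their sum to shape
            # (grid_height * grid_width, A, 4), which is then flattened into a
            # single (grid_height * grid_width * A, 4) tensor of anchors.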
            anchors.append(
                (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
            )

        return anchors

    def cached_grid_anchors(self, grid_sizes, strides):
        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
        key = str(grid_sizes) + str(strides)
        if key in self._cache:
            return self._cache[key]
        anchors = self.grid_anchors(grid_sizes, strides)
        self._cache[key] = anchors
        return anchors

    def forward(self, image_list, feature_maps):
        # type: (ImageList, List[Tensor]) -> List[Tensor]
        grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
                    torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
        self.set_cell_anchors(dtype, device)
        anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
        anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
        for i, (image_height, image_width) in enumerate(image_list.image_sizes):
            anchors_in_image = []
            for anchors_per_feature_map in anchors_over_all_feature_maps:
                anchors_in_image.append(anchors_per_feature_map)
            anchors.append(anchors_in_image)
        anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
        # Clear the cache to avoid memory leaks.
        self._cache.clear()
        return anchors


class RPNHead(nn.Module):
    """
    Adds a simple RPN Head with classification and regression heads

    Arguments:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
    """

    def __init__(self, in_channels, num_anchors):
        super(RPNHead, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(
            in_channels, num_anchors * 4, kernel_size=1, stride=1
        )

        for layer in self.children():
            torch.nn.init.normal_(layer.weight, std=0.01)
            torch.nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
        logits = []
        bbox_reg = []
        for feature in x:
            t = F.relu(self.conv(feature))
            logits.append(self.cls_logits(t))
            bbox_reg.append(self.bbox_pred(t))
        return logits, bbox_reg


def permute_and_flatten(layer, N, A, C, H, W):
    # type: (Tensor, int, int, int, int, int) -> Tensor
    layer = layer.view(N, -1, C, H, W)
    layer = layer.permute(0, 3, 4, 1, 2)
    layer = layer.reshape(N, -1, C)
    return layer


def concat_box_prediction_layers(box_cls, box_regression):
    # type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    box_cls_flattened = []
    box_regression_flattened = []
    # for each feature level, permute the outputs to make them be in the
    # same format as the labels.
    # Note that the labels are computed for all feature levels concatenated,
    # so we keep the same representation for the objectness and the box_regression
    for box_cls_per_level, box_regression_per_level in zip(
        box_cls, box_regression
    ):
        N, AxC, H, W = box_cls_per_level.shape
        Ax4 = box_regression_per_level.shape[1]
        A = Ax4 // 4
        C = AxC // A
        box_cls_per_level = permute_and_flatten(
            box_cls_per_level, N, A, C, H, W
        )
        box_cls_flattened.append(box_cls_per_level)

        box_regression_per_level = permute_and_flatten(
            box_regression_per_level, N, A, 4, H, W
        )
        box_regression_flattened.append(box_regression_per_level)
    # concatenate along the box dimension (dim=1), so the layout matches the way
    # the labels were generated (with all feature maps concatenated as well)
    box_cls = torch.cat(box_cls_flattened, dim=1).flatten(0, -2)
    box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4)
    return box_cls, box_regression


class RegionProposalNetwork(torch.nn.Module):
    """
    Implements Region Proposal Network (RPN).

    Arguments:
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of
            feature maps.
        head (nn.Module): module that computes the objectness and regression deltas
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they
            can be considered as positive during training of the RPN.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they
            can be considered as negative during training of the RPN.
        batch_size_per_image (int): number of anchors that are sampled during training of
            the RPN for computing the loss
        positive_fraction (float): proportion of positive anchors in a mini-batch during
            training of the RPN
        pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS.
            It should contain two fields: training and testing, to allow for different
            values depending on training or evaluation
        post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS.
            It should contain two fields: training and testing, to allow for different
            values depending on training or evaluation
        nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
    """
    __annotations__ = {
        'box_coder': det_utils.BoxCoder,
        'proposal_matcher': det_utils.Matcher,
        'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler,
        'pre_nms_top_n': Dict[str, int],
        'post_nms_top_n': Dict[str, int],
    }

    def __init__(self,
                 anchor_generator,
                 head,
                 #
                 fg_iou_thresh, bg_iou_thresh,
                 batch_size_per_image, positive_fraction,
                 #
                 pre_nms_top_n, post_nms_top_n, nms_thresh):
        super(RegionProposalNetwork, self).__init__()
        self.anchor_generator = anchor_generator
        self.head = head
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        # used during training
        self.box_similarity = box_ops.box_iou

        self.proposal_matcher = det_utils.Matcher(
            fg_iou_thresh,
            bg_iou_thresh,
            allow_low_quality_matches=True,
        )

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(
            batch_size_per_image, positive_fraction
        )
        # used during testing
        self._pre_nms_top_n = pre_nms_top_n
        self._post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.min_size = 1e-3

    def pre_nms_top_n(self):
        if self.training:
            return self._pre_nms_top_n['training']
        return self._pre_nms_top_n['testing']

    def post_nms_top_n(self):
        if self.training:
            return self._post_nms_top_n['training']
        return self._post_nms_top_n['testing']

    def assign_targets_to_anchors(self, anchors, targets):
        # type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]]
        labels = []
        matched_gt_boxes = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            gt_boxes = targets_per_image["boxes"]

            if gt_boxes.numel() == 0:
                # Background image (negative example)
                device = anchors_per_image.device
                matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
                labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
            else:
                match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image)
                matched_idxs = self.proposal_matcher(match_quality_matrix)
                # get the matched GT box for each anchor
                # NB: need to clamp the indices because we can have a single
                # GT in the image, and matched_idxs can be -2, which goes
                # out of bounds
                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

                labels_per_image = matched_idxs >= 0
                labels_per_image = labels_per_image.to(dtype=torch.float32)

                # Background (negative examples)
                bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_per_image[bg_indices] = torch.tensor(0.0)

                # discard indices that are between thresholds
                inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_per_image[inds_to_discard] = torch.tensor(-1.0)

            labels.append(labels_per_image)
            matched_gt_boxes.append(matched_gt_boxes_per_image)
        return labels, matched_gt_boxes

    def _get_top_n_idx(self, objectness, num_anchors_per_level):
        # type: (Tensor, List[int]) -> Tensor
        r = []
        offset = 0
        for ob in objectness.split(num_anchors_per_level, 1):
            if torchvision._is_tracing():
                num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n())
            else:
                num_anchors = ob.shape[1]
                pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
            r.append(top_n_idx + offset)
            offset += num_anchors
        return torch.cat(r, dim=1)

    def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
        # type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> Tuple[List[Tensor], List[Tensor]]
        num_images = proposals.shape[0]
        device = proposals.device
        # do not backprop through objectness
        objectness = objectness.detach()
        objectness = objectness.reshape(num_images, -1)

        levels = [
            torch.full((n,), idx, dtype=torch.int64, device=device)
            for idx, n in enumerate(num_anchors_per_level)
        ]
        levels = torch.cat(levels, 0)
        levels = levels.reshape(1, -1).expand_as(objectness)

        # select top_n boxes independently per level before applying nms
        top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)

        image_range = torch.arange(num_images, device=device)
        batch_idx = image_range[:, None]

        objectness = objectness[batch_idx, top_n_idx]
        levels = levels[batch_idx, top_n_idx]
        proposals = proposals[batch_idx, top_n_idx]

        final_boxes = []
        final_scores = []
        for boxes, scores, lvl, img_shape in zip(proposals, objectness, levels, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, img_shape)
            keep = box_ops.remove_small_boxes(boxes, self.min_size)
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
            # non-maximum suppression, independently done per level
            keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
            # keep only topk scoring predictions
            keep = keep[:self.post_nms_top_n()]
            boxes, scores = boxes[keep], scores[keep]
            final_boxes.append(boxes)
            final_scores.append(scores)
        return final_boxes, final_scores

    def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
        # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
        """
        Arguments:
            objectness (Tensor)
            pred_bbox_deltas (Tensor)
            labels (List[Tensor])
            regression_targets (List[Tensor])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """

        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
        sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness = objectness.flatten()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        box_loss = det_utils.smooth_l1_loss(
            pred_bbox_deltas[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1 / 9,
            size_average=False,
        ) / (sampled_inds.numel())

        objectness_loss = F.binary_cross_entropy_with_logits(
            objectness[sampled_inds], labels[sampled_inds]
        )

        return objectness_loss, box_loss

    def forward(self,
                images,       # type: ImageList
                features,     # type: Dict[str, Tensor]
                targets=None  # type: Optional[List[Dict[str, Tensor]]]
                ):
        # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
        """
        Arguments:
            images (ImageList): images for which we want to compute the predictions
            features (Dict[str, Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the dict
                corresponds to a different feature level
            targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image
                (optional). If provided, each element in the dict should contain a field
                `boxes`, with the locations of the ground-truth boxes.

        Returns:
            boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
                image.
            losses (Dict[str, Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
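
        Example (a minimal usage sketch; the channel count, feature map size and
        thresholds below are arbitrary illustrative values, not required defaults)::

            >>> anchor_generator = AnchorGenerator(((32, 64, 128),), ((0.5, 1.0, 2.0),))
            >>> head = RPNHead(in_channels=256,
            ...                num_anchors=anchor_generator.num_anchors_per_location()[0])
            >>> rpn = RegionProposalNetwork(
            ...     anchor_generator, head,
            ...     fg_iou_thresh=0.7, bg_iou_thresh=0.3,
            ...     batch_size_per_image=256, positive_fraction=0.5,
            ...     pre_nms_top_n=dict(training=2000, testing=1000),
            ...     post_nms_top_n=dict(training=2000, testing=1000),
            ...     nms_thresh=0.7).eval()
            >>> images = ImageList(torch.rand(2, 3, 224, 224), [(224, 224), (224, 224)])
            >>> features = {"0": torch.rand(2, 256, 28, 28)}
            >>> boxes, losses = rpn(images, features)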
""" # RPN uses all feature maps that are available features = list(features.values()) objectness, pred_bbox_deltas = self.head(features) anchors = self.anchor_generator(images, features) num_images = len(anchors) num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness] num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors] objectness, pred_bbox_deltas = \ concat_box_prediction_layers(objectness, pred_bbox_deltas) # apply pred_bbox_deltas to anchors to obtain the decoded proposals # note that we detach the deltas because Faster R-CNN do not backprop through # the proposals proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors) proposals = proposals.view(num_images, -1, 4) boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level) losses = {} if self.training: assert targets is not None labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets) regression_targets = self.box_coder.encode(matched_gt_boxes, anchors) loss_objectness, loss_rpn_box_reg = self.compute_loss( objectness, pred_bbox_deltas, labels, regression_targets) losses = { "loss_objectness": loss_objectness, "loss_rpn_box_reg": loss_rpn_box_reg, } return boxes, losses