# Copyright (c) OpenMMLab. All rights reserved. import warnings from typing import List, Optional, Tuple import numpy as np import torch from mmcv import ConfigDict from mmcv.runner import BaseModule, force_fp32 from mmengine.data import InstanceData from torch import Tensor from torch import nn as nn from mmdet3d.core import (Det3DDataSample, PseudoSampler, box3d_multiclass_nms, limit_period, merge_aug_bboxes_3d, xywhr2xyxyr) from mmdet3d.registry import MODELS, TASK_UTILS from mmdet.core import multi_apply from .train_mixins import AnchorTrainMixin @MODELS.register_module() class Anchor3DHead(BaseModule, AnchorTrainMixin): """Anchor head for SECOND/PointPillars/MVXNet/PartA2. Args: num_classes (int): Number of classes. in_channels (int): Number of channels in the input feature map. train_cfg (dict): Train configs. test_cfg (dict): Test configs. feat_channels (int): Number of channels of the feature map. use_direction_classifier (bool): Whether to add a direction classifier. anchor_generator(dict): Config dict of anchor generator. assigner_per_size (bool): Whether to do assignment for each separate anchor size. assign_per_class (bool): Whether to do assignment for each class. diff_rad_by_sin (bool): Whether to change the difference into sin difference for box regression loss. dir_offset (float | int): The offset of BEV rotation angles. (TODO: may be moved into box coder) dir_limit_offset (float | int): The limited range of BEV rotation angles. (TODO: may be moved into box coder) bbox_coder (dict): Config dict of box coders. loss_cls (dict): Config of classification loss. loss_bbox (dict): Config of localization loss. loss_dir (dict): Config of direction classifier loss. """ def __init__(self, num_classes: int, in_channels: int, train_cfg: dict, test_cfg: dict, feat_channels: int = 256, use_direction_classifier: bool = True, anchor_generator: dict = dict( type='Anchor3DRangeGenerator', range=[0, -39.68, -1.78, 69.12, 39.68, -1.78], strides=[2], sizes=[[3.9, 1.6, 1.56]], rotations=[0, 1.57], custom_values=[], reshape_out=False), assigner_per_size: bool = False, assign_per_class: bool = False, diff_rad_by_sin: bool = True, dir_offset: float = -np.pi / 2, dir_limit_offset: int = 0, bbox_coder: dict = dict(type='DeltaXYZWLHRBBoxCoder'), loss_cls: dict = dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), loss_bbox: dict = dict( type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0), loss_dir: dict = dict( type='CrossEntropyLoss', loss_weight=0.2), init_cfg: Optional[dict] = None) -> None: super().__init__(init_cfg=init_cfg) self.in_channels = in_channels self.num_classes = num_classes self.feat_channels = feat_channels self.diff_rad_by_sin = diff_rad_by_sin self.use_direction_classifier = use_direction_classifier self.train_cfg = train_cfg self.test_cfg = test_cfg self.assigner_per_size = assigner_per_size self.assign_per_class = assign_per_class self.dir_offset = dir_offset self.dir_limit_offset = dir_limit_offset warnings.warn( 'dir_offset and dir_limit_offset will be depressed and be ' 'incorporated into box coder in the future') self.fp16_enabled = False # build anchor generator self.prior_generator = TASK_UTILS.build(anchor_generator) # In 3D detection, the anchor stride is connected with anchor size self.num_anchors = self.prior_generator.num_base_anchors # build box coder self.bbox_coder = TASK_UTILS.build(bbox_coder) self.box_code_size = self.bbox_coder.code_size # build loss function self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) self.sampling = loss_cls['type'] not in [ 'mmdet.FocalLoss', 'mmdet.GHMC' ] if not self.use_sigmoid_cls: self.num_classes += 1 self.loss_cls = MODELS.build(loss_cls) self.loss_bbox = MODELS.build(loss_bbox) self.loss_dir = MODELS.build(loss_dir) self.fp16_enabled = False self._init_layers() self._init_assigner_sampler() if init_cfg is None: self.init_cfg = dict( type='Normal', layer='Conv2d', std=0.01, override=dict( type='Normal', name='conv_cls', std=0.01, bias_prob=0.01)) def _init_assigner_sampler(self): """Initialize the target assigner and sampler of the head.""" if self.train_cfg is None: return if self.sampling: self.bbox_sampler = TASK_UTILS.build(self.train_cfg.sampler) else: self.bbox_sampler = PseudoSampler() if isinstance(self.train_cfg.assigner, dict): self.bbox_assigner = TASK_UTILS.build(self.train_cfg.assigner) elif isinstance(self.train_cfg.assigner, list): self.bbox_assigner = [ TASK_UTILS.build(res) for res in self.train_cfg.assigner ] def _init_layers(self): """Initialize neural network layers of the head.""" self.cls_out_channels = self.num_anchors * self.num_classes self.conv_cls = nn.Conv2d(self.feat_channels, self.cls_out_channels, 1) self.conv_reg = nn.Conv2d(self.feat_channels, self.num_anchors * self.box_code_size, 1) if self.use_direction_classifier: self.conv_dir_cls = nn.Conv2d(self.feat_channels, self.num_anchors * 2, 1) def forward_single(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Forward function on a single-scale feature map. Args: x (torch.Tensor): Input features. Returns: tuple[torch.Tensor]: Contain score of each class, bbox regression and direction classification predictions. """ cls_score = self.conv_cls(x) bbox_pred = self.conv_reg(x) dir_cls_preds = None if self.use_direction_classifier: dir_cls_preds = self.conv_dir_cls(x) return cls_score, bbox_pred, dir_cls_preds def forward(self, feats: List[Tensor]) -> Tuple[list]: """Forward pass. Args: feats (list[torch.Tensor]): Multi-level features, e.g., features produced by FPN. Returns: tuple[list[torch.Tensor]]: Multi-level class score, bbox and direction predictions. """ return multi_apply(self.forward_single, feats) def forward_train(self, feats: List[Tensor], batch_data_samples: List[Det3DDataSample], proposal_cfg: Optional[ConfigDict] = None, **kwargs): """ Args: feats (list[Tensor]): Features from FPN. batch_data_samples (list[:obj:`Det3DDataSample`]): Each item contains the meta information of each sample and corresponding annotations. proposal_cfg (ConfigDict, optional): Test / postprocessing configuration, if None, test_cfg would be used. Defaults to None. Returns: tuple or Tensor: When `proposal_cfg` is None, the detector is a \ normal one-stage detector, The return value is the losses. - losses: (dict[str, Tensor]): A dictionary of loss components. When the `proposal_cfg` is not None, the head is used as a `rpn_head`, the return value is a tuple contains: - losses: (dict[str, Tensor]): A dictionary of loss components. - results_list (list[:obj:`InstanceData`]): Detection results of each input after the post process. Each item usually contains following keys.Det3DDataSample - scores_3d (Tensor): Classification scores, has a shape (num_instances, ) - labels_3d (Tensor): Labels of bboxes, has a shape (num_instances, ). - bboxes_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes, contains a tensor with shape (num_instances, 7). """ outs = self.forward(feats) batch_gt_instance_3d = [] batch_gt_instances_ignore = [] batch_input_metas = [] for data_sample in batch_data_samples: batch_input_metas.append(data_sample.metainfo) batch_gt_instance_3d.append(data_sample.gt_instances_3d) if 'ignored_instances' in data_sample: batch_gt_instances_ignore.append(data_sample.ignored_instances) else: batch_gt_instances_ignore.append(None) loss_inputs = outs + (batch_gt_instance_3d, batch_input_metas) losses = self.loss( *loss_inputs, batch_gt_instances_ignore=batch_gt_instances_ignore) if proposal_cfg is None: return losses else: batch_img_metas = [ data_sample.metainfo for data_sample in batch_data_samples ] results_list = self.get_results( *outs, batch_img_metas=batch_img_metas, cfg=proposal_cfg) return losses, results_list def simple_test(self, feats: Tuple[Tensor], batch_input_metas: List[dict], rescale: bool = False) -> List[InstanceData]: """Test function without test-time augmentation. Args: feats (tuple[torch.Tensor]): Multi-level features from the upstream network, each is a 4D-tensor. batch_input_metas (list[dict]): List of image information. rescale (bool, optional): Whether to rescale the results. Defaults to False. Returns: list[:obj:`InstanceData`]: Detection results of each input after the post process. Each item usually contains following keys. - scores_3d (Tensor): Classification scores, has a shape (num_instances, ) - labels_3d (Tensor): Labels of bboxes, has a shape (num_instances, ). - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes, contains a tensor with shape (num_instances, 7). """ outs = self.forward(feats) results_list = self.get_results( *outs, input_metas=batch_input_metas, rescale=rescale) return results_list def aug_test(self, aug_batch_feats, aug_batch_input_metas, rescale=False, **kwargs): aug_bboxes = [] # only support aug_test for one sample for x, input_meta in zip(aug_batch_feats, aug_batch_input_metas): outs = self.forward(x) bbox_list = self.get_results(*outs, [input_meta], rescale=rescale) bbox_dict = dict( bboxes_3d=bbox_list[0].bboxes_3d, scores_3d=bbox_list[0].scores_3d, labels_3d=bbox_list[0].labels_3d) aug_bboxes.append(bbox_dict) # after merging, bboxes will be rescaled to the original image size merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, aug_batch_input_metas, self.test_cfg) return [merged_bboxes] def get_anchors(self, featmap_sizes: List[tuple], input_metas: List[dict], device: str = 'cuda') -> list: """Get anchors according to feature map sizes. Args: featmap_sizes (list[tuple]): Multi-level feature map sizes. input_metas (list[dict]): contain pcd and img's meta info. device (str): device of current module. Returns: list[list[torch.Tensor]]: Anchors of each image, valid flags of each image. """ num_imgs = len(input_metas) # since feature map sizes of all images are the same, we only compute # anchors for one time multi_level_anchors = self.prior_generator.grid_anchors( featmap_sizes, device=device) anchor_list = [multi_level_anchors for _ in range(num_imgs)] return anchor_list def loss_single(self, cls_score, bbox_pred, dir_cls_preds, labels, label_weights, bbox_targets, bbox_weights, dir_targets, dir_weights, num_total_samples): """Calculate loss of Single-level results. Args: cls_score (torch.Tensor): Class score in single-level. bbox_pred (torch.Tensor): Bbox prediction in single-level. dir_cls_preds (torch.Tensor): Predictions of direction class in single-level. labels (torch.Tensor): Labels of class. label_weights (torch.Tensor): Weights of class loss. bbox_targets (torch.Tensor): Targets of bbox predictions. bbox_weights (torch.Tensor): Weights of bbox loss. dir_targets (torch.Tensor): Targets of direction predictions. dir_weights (torch.Tensor): Weights of direction loss. num_total_samples (int): The number of valid samples. Returns: tuple[torch.Tensor]: Losses of class, bbox and direction, respectively. """ # classification loss if num_total_samples is None: num_total_samples = int(cls_score.shape[0]) labels = labels.reshape(-1) label_weights = label_weights.reshape(-1) cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.num_classes) assert labels.max().item() <= self.num_classes loss_cls = self.loss_cls( cls_score, labels, label_weights, avg_factor=num_total_samples) # regression loss bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, self.box_code_size) bbox_targets = bbox_targets.reshape(-1, self.box_code_size) bbox_weights = bbox_weights.reshape(-1, self.box_code_size) bg_class_ind = self.num_classes pos_inds = ((labels >= 0) & (labels < bg_class_ind)).nonzero( as_tuple=False).reshape(-1) num_pos = len(pos_inds) pos_bbox_pred = bbox_pred[pos_inds] pos_bbox_targets = bbox_targets[pos_inds] pos_bbox_weights = bbox_weights[pos_inds] # dir loss if self.use_direction_classifier: dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).reshape(-1, 2) dir_targets = dir_targets.reshape(-1) dir_weights = dir_weights.reshape(-1) pos_dir_cls_preds = dir_cls_preds[pos_inds] pos_dir_targets = dir_targets[pos_inds] pos_dir_weights = dir_weights[pos_inds] if num_pos > 0: code_weight = self.train_cfg.get('code_weight', None) if code_weight: pos_bbox_weights = pos_bbox_weights * bbox_weights.new_tensor( code_weight) if self.diff_rad_by_sin: pos_bbox_pred, pos_bbox_targets = self.add_sin_difference( pos_bbox_pred, pos_bbox_targets) loss_bbox = self.loss_bbox( pos_bbox_pred, pos_bbox_targets, pos_bbox_weights, avg_factor=num_total_samples) # direction classification loss loss_dir = None if self.use_direction_classifier: loss_dir = self.loss_dir( pos_dir_cls_preds, pos_dir_targets, pos_dir_weights, avg_factor=num_total_samples) else: loss_bbox = pos_bbox_pred.sum() if self.use_direction_classifier: loss_dir = pos_dir_cls_preds.sum() return loss_cls, loss_bbox, loss_dir @staticmethod def add_sin_difference(boxes1: Tensor, boxes2: Tensor) -> tuple: """Convert the rotation difference to difference in sine function. Args: boxes1 (torch.Tensor): Original Boxes in shape (NxC), where C>=7 and the 7th dimension is rotation dimension. boxes2 (torch.Tensor): Target boxes in shape (NxC), where C>=7 and the 7th dimension is rotation dimension. Returns: tuple[torch.Tensor]: ``boxes1`` and ``boxes2`` whose 7th dimensions are changed. """ rad_pred_encoding = torch.sin(boxes1[..., 6:7]) * torch.cos( boxes2[..., 6:7]) rad_tg_encoding = torch.cos(boxes1[..., 6:7]) * torch.sin(boxes2[..., 6:7]) boxes1 = torch.cat( [boxes1[..., :6], rad_pred_encoding, boxes1[..., 7:]], dim=-1) boxes2 = torch.cat([boxes2[..., :6], rad_tg_encoding, boxes2[..., 7:]], dim=-1) return boxes1, boxes2 @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds')) def loss(self, cls_scores: List[Tensor], bbox_preds: List[Tensor], dir_cls_preds: List[Tensor], batch_gt_instances_3d: List[InstanceData], batch_input_metas: List[dict], batch_gt_instances_ignore: List[InstanceData] = None) -> dict: """Calculate losses. Args: cls_scores (list[torch.Tensor]): Multi-level class scores. bbox_preds (list[torch.Tensor]): Multi-level bbox predictions. dir_cls_preds (list[torch.Tensor]): Multi-level direction class predictions. batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of gt_instances. It usually includes ``bboxes`` and ``labels`` attributes. batch_input_metas (list[dict]): Contain pcd and img's meta info. batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): Batch of gt_instances_ignore. It includes ``bboxes`` attribute data that is ignored during training and testing. Defaults to None. Returns: dict[str, list[torch.Tensor]]: Classification, bbox, and direction losses of each level. - loss_cls (list[torch.Tensor]): Classification losses. - loss_bbox (list[torch.Tensor]): Box regression losses. - loss_dir (list[torch.Tensor]): Direction classification losses. """ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] assert len(featmap_sizes) == self.prior_generator.num_levels device = cls_scores[0].device anchor_list = self.get_anchors( featmap_sizes, batch_input_metas, device=device) label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 cls_reg_targets = self.anchor_target_3d( anchor_list, batch_gt_instances_3d, batch_input_metas, batch_gt_instances_ignore=batch_gt_instances_ignore, num_classes=self.num_classes, label_channels=label_channels, sampling=self.sampling) if cls_reg_targets is None: return None (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, dir_targets_list, dir_weights_list, num_total_pos, num_total_neg) = cls_reg_targets num_total_samples = ( num_total_pos + num_total_neg if self.sampling else num_total_pos) # num_total_samples = None losses_cls, losses_bbox, losses_dir = multi_apply( self.loss_single, cls_scores, bbox_preds, dir_cls_preds, labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, dir_targets_list, dir_weights_list, num_total_samples=num_total_samples) return dict( loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dir=losses_dir) def get_results(self, cls_scores: List[Tensor], bbox_preds: List[Tensor], dir_cls_preds: List[Tensor], input_metas: List[dict], cfg: ConfigDict = None, rescale: list = False) -> List[InstanceData]: """Get results of anchor head. Args: cls_scores (list[torch.Tensor]): Multi-level class scores. bbox_preds (list[torch.Tensor]): Multi-level bbox predictions. dir_cls_preds (list[torch.Tensor]): Multi-level direction class predictions. input_metas (list[dict]): Contain pcd and img's meta info. cfg (:obj:`ConfigDict`): Training or testing config. rescale (list[torch.Tensor]): Whether th rescale bbox. Returns: list[:obj:`InstanceData`]: Instance prediction results of each sample after the post process. Each item usually contains following keys. - scores_3d (Tensor): Classification scores, has a shape (num_instance, ) - labels_3d (Tensor): Labels of bboxes, has a shape (num_instances, ). - bboxes_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes, contains a tensor with shape (num_instances, 7). """ assert len(cls_scores) == len(bbox_preds) assert len(cls_scores) == len(dir_cls_preds) num_levels = len(cls_scores) featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] device = cls_scores[0].device mlvl_anchors = self.prior_generator.grid_anchors( featmap_sizes, device=device) mlvl_anchors = [ anchor.reshape(-1, self.box_code_size) for anchor in mlvl_anchors ] result_list = [] for img_id in range(len(input_metas)): cls_score_list = [ cls_scores[i][img_id].detach() for i in range(num_levels) ] bbox_pred_list = [ bbox_preds[i][img_id].detach() for i in range(num_levels) ] dir_cls_pred_list = [ dir_cls_preds[i][img_id].detach() for i in range(num_levels) ] input_meta = input_metas[img_id] proposals = self._get_results_single(cls_score_list, bbox_pred_list, dir_cls_pred_list, mlvl_anchors, input_meta, cfg, rescale) result_list.append(proposals) return result_list def _get_results_single(self, cls_scores: Tensor, bbox_preds: Tensor, dir_cls_preds: Tensor, mlvl_anchors: List[Tensor], input_meta: List[dict], cfg: ConfigDict = None, rescale: bool = False) -> InstanceData: """Get results of single branch. Args: cls_scores (torch.Tensor): Class score in single batch. bbox_preds (torch.Tensor): Bbox prediction in single batch. dir_cls_preds (torch.Tensor): Predictions of direction class in single batch. mlvl_anchors (List[torch.Tensor]): Multi-level anchors in single batch. input_meta (list[dict]): Contain pcd and img's meta info. cfg (:obj:`ConfigDict`): Training or testing config. rescale (list[torch.Tensor]): whether th rescale bbox. Returns: :obj:`InstanceData`: Detection results of each sample after the post process. Each item usually contains following keys. - scores_3d (Tensor): Classification scores, has a shape (num_instance, ) - labels_3d (Tensor): Labels of bboxes, has a shape (num_instances, ). - bboxes_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes, contains a tensor with shape (num_instances, 7). """ cfg = self.test_cfg if cfg is None else cfg assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors) mlvl_bboxes = [] mlvl_scores = [] mlvl_dir_scores = [] for cls_score, bbox_pred, dir_cls_pred, anchors in zip( cls_scores, bbox_preds, dir_cls_preds, mlvl_anchors): assert cls_score.size()[-2:] == bbox_pred.size()[-2:] assert cls_score.size()[-2:] == dir_cls_pred.size()[-2:] dir_cls_pred = dir_cls_pred.permute(1, 2, 0).reshape(-1, 2) dir_cls_score = torch.max(dir_cls_pred, dim=-1)[1] cls_score = cls_score.permute(1, 2, 0).reshape(-1, self.num_classes) if self.use_sigmoid_cls: scores = cls_score.sigmoid() else: scores = cls_score.softmax(-1) bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, self.box_code_size) nms_pre = cfg.get('nms_pre', -1) if nms_pre > 0 and scores.shape[0] > nms_pre: if self.use_sigmoid_cls: max_scores, _ = scores.max(dim=1) else: max_scores, _ = scores[:, :-1].max(dim=1) _, topk_inds = max_scores.topk(nms_pre) anchors = anchors[topk_inds, :] bbox_pred = bbox_pred[topk_inds, :] scores = scores[topk_inds, :] dir_cls_score = dir_cls_score[topk_inds] bboxes = self.bbox_coder.decode(anchors, bbox_pred) mlvl_bboxes.append(bboxes) mlvl_scores.append(scores) mlvl_dir_scores.append(dir_cls_score) mlvl_bboxes = torch.cat(mlvl_bboxes) mlvl_bboxes_for_nms = xywhr2xyxyr(input_meta['box_type_3d']( mlvl_bboxes, box_dim=self.box_code_size).bev) mlvl_scores = torch.cat(mlvl_scores) mlvl_dir_scores = torch.cat(mlvl_dir_scores) if self.use_sigmoid_cls: # Add a dummy background class to the front when using sigmoid padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) score_thr = cfg.get('score_thr', 0) results = box3d_multiclass_nms(mlvl_bboxes, mlvl_bboxes_for_nms, mlvl_scores, score_thr, cfg.max_num, cfg, mlvl_dir_scores) bboxes, scores, labels, dir_scores = results if bboxes.shape[0] > 0: dir_rot = limit_period(bboxes[..., 6] - self.dir_offset, self.dir_limit_offset, np.pi) bboxes[..., 6] = ( dir_rot + self.dir_offset + np.pi * dir_scores.to(bboxes.dtype)) bboxes = input_meta['box_type_3d'](bboxes, box_dim=self.box_code_size) results = InstanceData() results.bboxes_3d = bboxes results.scores_3d = scores results.labels_3d = labels return results