from __future__ import division import numpy as np import torch from mmdet3d.core import box_torch_ops, boxes3d_to_bev_torch_lidar from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu from mmdet.models import HEADS from .anchor3d_head import Anchor3DHead @HEADS.register_module() class PartA2RPNHead(Anchor3DHead): """RPN head for PartA2 Note: The main difference between the PartA2 RPN head and the Anchor3DHead lies in their output during inference. PartA2 RPN head further returns the original classification score for the second stage since the bbox head in RoI head does not do classification task. Different from RPN heads in 2D detectors, this RPN head does multi-class classification task and uses FocalLoss like the SECOND and PointPillars do. But this head uses class agnostic nms rather than multi-class nms. Args: num_classes (int): Number of classes. in_channels (int): Number of channels in the input feature map. train_cfg (dict): train configs test_cfg (dict): test configs feat_channels (int): Number of channels of the feature map. use_direction_classifier (bool): Whether to add a direction classifier. anchor_generator(dict): Config dict of anchor generator. assigner_per_size (bool): Whether to do assignment for each separate anchor size. assign_per_class (bool): Whether to do assignment for each class. diff_rad_by_sin (bool): Whether to change the difference into sin difference for box regression loss. dir_offset (float | int): The offset of BEV rotation angles (TODO: may be moved into box coder) dirlimit_offset (float | int): The limited range of BEV rotation angles (TODO: may be moved into box coder) box_coder (dict): Config dict of box coders. loss_cls (dict): Config of classification loss. loss_bbox (dict): Config of localization loss. loss_dir (dict): Config of direction classifier loss. """ def __init__(self, num_classes, in_channels, train_cfg, test_cfg, feat_channels=256, use_direction_classifier=True, anchor_generator=dict( type='Anchor3DRangeGenerator', range=[0, -39.68, -1.78, 69.12, 39.68, -1.78], strides=[2], sizes=[[1.6, 3.9, 1.56]], rotations=[0, 1.57], custom_values=[], reshape_out=False), assigner_per_size=False, assign_per_class=False, diff_rad_by_sin=True, dir_offset=0, dir_limit_offset=1, bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'), loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), loss_bbox=dict( type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0), loss_dir=dict(type='CrossEntropyLoss', loss_weight=0.2)): super().__init__(num_classes, in_channels, train_cfg, test_cfg, feat_channels, use_direction_classifier, anchor_generator, assigner_per_size, assign_per_class, diff_rad_by_sin, dir_offset, dir_limit_offset, bbox_coder, loss_cls, loss_bbox, loss_dir) def get_bboxes_single(self, cls_scores, bbox_preds, dir_cls_preds, mlvl_anchors, input_meta, cfg, rescale=False): assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors) mlvl_bboxes = [] mlvl_max_scores = [] mlvl_label_pred = [] mlvl_dir_scores = [] mlvl_cls_score = [] for cls_score, bbox_pred, dir_cls_pred, anchors in zip( cls_scores, bbox_preds, dir_cls_preds, mlvl_anchors): assert cls_score.size()[-2:] == bbox_pred.size()[-2:] assert cls_score.size()[-2:] == dir_cls_pred.size()[-2:] dir_cls_pred = dir_cls_pred.permute(1, 2, 0).reshape(-1, 2) dir_cls_score = torch.max(dir_cls_pred, dim=-1)[1] cls_score = cls_score.permute(1, 2, 0).reshape(-1, self.num_classes) if self.use_sigmoid_cls: scores = cls_score.sigmoid() else: scores = cls_score.softmax(-1) bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, self.box_code_size) nms_pre = cfg.get('nms_pre', -1) if self.use_sigmoid_cls: max_scores, pred_labels = scores.max(dim=1) else: max_scores, pred_labels = scores[:, :-1].max(dim=1) # get topk if nms_pre > 0 and scores.shape[0] > nms_pre: topk_scores, topk_inds = max_scores.topk(nms_pre) anchors = anchors[topk_inds, :] bbox_pred = bbox_pred[topk_inds, :] max_scores = topk_scores cls_score = scores[topk_inds, :] dir_cls_score = dir_cls_score[topk_inds] pred_labels = pred_labels[topk_inds] bboxes = self.bbox_coder.decode(anchors, bbox_pred) mlvl_bboxes.append(bboxes) mlvl_max_scores.append(max_scores) mlvl_cls_score.append(cls_score) mlvl_label_pred.append(pred_labels) mlvl_dir_scores.append(dir_cls_score) mlvl_bboxes = torch.cat(mlvl_bboxes) mlvl_bboxes_for_nms = boxes3d_to_bev_torch_lidar(mlvl_bboxes) mlvl_max_scores = torch.cat(mlvl_max_scores) mlvl_label_pred = torch.cat(mlvl_label_pred) mlvl_dir_scores = torch.cat(mlvl_dir_scores) # shape [k, num_class] before sigmoid # PartA2 need to keep raw classification score # becase the bbox head in the second stage does not have # classification branch, # roi head need this score as classification score mlvl_cls_score = torch.cat(mlvl_cls_score) score_thr = cfg.get('score_thr', 0) result = self.class_agnostic_nms(mlvl_bboxes, mlvl_bboxes_for_nms, mlvl_max_scores, mlvl_label_pred, mlvl_cls_score, mlvl_dir_scores, score_thr, cfg.nms_post, cfg, input_meta) return result def class_agnostic_nms(self, mlvl_bboxes, mlvl_bboxes_for_nms, mlvl_max_scores, mlvl_label_pred, mlvl_cls_score, mlvl_dir_scores, score_thr, max_num, cfg, input_meta): bboxes = [] scores = [] labels = [] dir_scores = [] cls_scores = [] score_thr_inds = mlvl_max_scores > score_thr _scores = mlvl_max_scores[score_thr_inds] _bboxes_for_nms = mlvl_bboxes_for_nms[score_thr_inds, :] if cfg.use_rotate_nms: nms_func = nms_gpu else: nms_func = nms_normal_gpu selected = nms_func(_bboxes_for_nms, _scores, cfg.nms_thr) _mlvl_bboxes = mlvl_bboxes[score_thr_inds, :] _mlvl_dir_scores = mlvl_dir_scores[score_thr_inds] _mlvl_label_pred = mlvl_label_pred[score_thr_inds] _mlvl_cls_score = mlvl_cls_score[score_thr_inds] if len(selected) > 0: bboxes.append(_mlvl_bboxes[selected]) scores.append(_scores[selected]) labels.append(_mlvl_label_pred[selected]) cls_scores.append(_mlvl_cls_score[selected]) dir_scores.append(_mlvl_dir_scores[selected]) dir_rot = box_torch_ops.limit_period( bboxes[-1][..., 6] - self.dir_offset, self.dir_limit_offset, np.pi) bboxes[-1][..., 6] = ( dir_rot + self.dir_offset + np.pi * dir_scores[-1].to(bboxes[-1].dtype)) if bboxes: bboxes = torch.cat(bboxes, dim=0) scores = torch.cat(scores, dim=0) cls_scores = torch.cat(cls_scores, dim=0) labels = torch.cat(labels, dim=0) dir_scores = torch.cat(dir_scores, dim=0) if bboxes.shape[0] > max_num: _, inds = scores.sort(descending=True) inds = inds[:max_num] bboxes = bboxes[inds, :] labels = labels[inds] scores = scores[inds] cls_scores = cls_scores[inds] bboxes = input_meta['box_type_3d']( bboxes, box_dim=self.box_code_size) return dict( boxes_3d=bboxes, scores_3d=scores, labels_3d=labels, cls_preds=cls_scores # raw scores [max_num, cls_num] ) else: return dict( boxes_3d=input_meta['box_type_3d']( mlvl_bboxes.new_zeros([0, self.box_code_size]), box_dim=self.box_code_size), scores_3d=mlvl_bboxes.new_zeros([0]), labels_3d=mlvl_bboxes.new_zeros([0]), cls_preds=mlvl_bboxes.new_zeros([0, mlvl_cls_score.shape[-1]]))