# ------------------------------------------------------------------------ # Copyright (c) 2022 megvii-model. All Rights Reserved. # ------------------------------------------------------------------------ # Modified from DETR3D (https://github.com/WangYueFt/detr3d) # Copyright (c) 2021 Wang, Yue # ------------------------------------------------------------------------ # Modified from mmdetection3d (https://github.com/open-mmlab/mmdetection3d) # Copyright (c) OpenMMLab. All rights reserved. # ------------------------------------------------------------------------ import numpy as np import torch from mmengine.structures import InstanceData import mmdet3d from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector from mmdet3d.registry import MODELS from mmdet3d.structures.bbox_3d import LiDARInstance3DBoxes, limit_period from mmdet3d.structures.ops import bbox3d2result from .grid_mask import GridMask @MODELS.register_module() class PETR(MVXTwoStageDetector): """PETR.""" def __init__(self, use_grid_mask=False, pts_voxel_layer=None, pts_middle_encoder=None, pts_fusion_layer=None, img_backbone=None, pts_backbone=None, img_neck=None, pts_neck=None, pts_bbox_head=None, img_roi_head=None, img_rpn_head=None, train_cfg=None, test_cfg=None, init_cfg=None, data_preprocessor=None, **kwargs): super(PETR, self).__init__(pts_voxel_layer, pts_middle_encoder, pts_fusion_layer, img_backbone, pts_backbone, img_neck, pts_neck, pts_bbox_head, img_roi_head, img_rpn_head, train_cfg, test_cfg, init_cfg, data_preprocessor) self.grid_mask = GridMask( True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7) self.use_grid_mask = use_grid_mask def extract_img_feat(self, img, img_metas): """Extract features of images.""" if isinstance(img, list): img = torch.stack(img, dim=0) B = img.size(0) if img is not None: input_shape = img.shape[-2:] # update real input shape of each single img for img_meta in img_metas: img_meta.update(input_shape=input_shape) if img.dim() == 5: if img.size(0) == 1 and img.size(1) != 1: img.squeeze_() else: B, N, C, H, W = img.size() img = img.view(B * N, C, H, W) if self.use_grid_mask: img = self.grid_mask(img) img_feats = self.img_backbone(img) if isinstance(img_feats, dict): img_feats = list(img_feats.values()) else: return None if self.with_img_neck: img_feats = self.img_neck(img_feats) img_feats_reshaped = [] for img_feat in img_feats: BN, C, H, W = img_feat.size() img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W)) return img_feats_reshaped # @auto_fp16(apply_to=('img'), out_fp32=True) def extract_feat(self, img, img_metas): """Extract features from images and points.""" img_feats = self.extract_img_feat(img, img_metas) return img_feats def forward_pts_train(self, pts_feats, gt_bboxes_3d, gt_labels_3d, img_metas, gt_bboxes_ignore=None): """Forward function for point cloud branch. Args: pts_feats (list[torch.Tensor]): Features of point cloud branch gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth boxes for each sample. gt_labels_3d (list[torch.Tensor]): Ground truth labels for boxes of each sampole img_metas (list[dict]): Meta information of samples. gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth boxes to be ignored. Defaults to None. Returns: dict: Losses of each branch. """ outs = self.pts_bbox_head(pts_feats, img_metas) loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs] losses = self.pts_bbox_head.loss_by_feat(*loss_inputs) return losses def _forward(self, mode='loss', **kwargs): """Calls either forward_train or forward_test depending on whether return_loss=True. Note this setting will change the expected inputs. When `return_loss=True`, img and img_metas are single-nested (i.e. torch.Tensor and list[dict]), and when `resturn_loss=False`, img and img_metas should be double nested (i.e. list[torch.Tensor], list[list[dict]]), with the outer list indicating test time augmentations. """ raise NotImplementedError('tensor mode is yet to add') def loss(self, inputs=None, data_samples=None, mode=None, points=None, img_metas=None, gt_bboxes_3d=None, gt_labels_3d=None, gt_labels=None, gt_bboxes=None, img=None, proposals=None, gt_bboxes_ignore=None, img_depth=None, img_mask=None): """Forward training function. Args: points (list[torch.Tensor], optional): Points of each sample. Defaults to None. img_metas (list[dict], optional): Meta information of each sample. Defaults to None. gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional): Ground truth 3D boxes. Defaults to None. gt_labels_3d (list[torch.Tensor], optional): Ground truth labels of 3D boxes. Defaults to None. gt_labels (list[torch.Tensor], optional): Ground truth labels of 2D boxes in images. Defaults to None. gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in images. Defaults to None. img (torch.Tensor optional): Images of each sample with shape (N, C, H, W). Defaults to None. proposals ([list[torch.Tensor], optional): Predicted proposals used for training Fast RCNN. Defaults to None. gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth 2D boxes in images to be ignored. Defaults to None. Returns: dict: Losses of different branches. """ img = inputs['imgs'] batch_img_metas = [ds.metainfo for ds in data_samples] batch_gt_instances_3d = [ds.gt_instances_3d for ds in data_samples] gt_bboxes_3d = [gt.bboxes_3d for gt in batch_gt_instances_3d] gt_labels_3d = [gt.labels_3d for gt in batch_gt_instances_3d] gt_bboxes_ignore = None gt_bboxes_3d = self.LidarBox3dVersionTransfrom(gt_bboxes_3d) batch_img_metas = self.add_lidar2img(img, batch_img_metas) img_feats = self.extract_feat(img=img, img_metas=batch_img_metas) losses = dict() losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d, gt_labels_3d, batch_img_metas, gt_bboxes_ignore) losses.update(losses_pts) return losses def predict(self, inputs=None, data_samples=None, mode=None, **kwargs): img = inputs['imgs'] batch_img_metas = [ds.metainfo for ds in data_samples] for var, name in [(batch_img_metas, 'img_metas')]: if not isinstance(var, list): raise TypeError('{} must be a list, but got {}'.format( name, type(var))) img = [img] if img is None else img batch_img_metas = self.add_lidar2img(img, batch_img_metas) results_list_3d = self.simple_test(batch_img_metas, img, **kwargs) for i, data_sample in enumerate(data_samples): results_list_3d_i = InstanceData( metainfo=results_list_3d[i]['pts_bbox']) data_sample.pred_instances_3d = results_list_3d_i data_sample.pred_instances = InstanceData() return data_samples def simple_test_pts(self, x, img_metas, rescale=False): """Test function of point cloud branch.""" outs = self.pts_bbox_head(x, img_metas) bbox_list = self.pts_bbox_head.get_bboxes( outs, img_metas, rescale=rescale) bbox_results = [ bbox3d2result(bboxes, scores, labels) for bboxes, scores, labels in bbox_list ] return bbox_results def simple_test(self, img_metas, img=None, rescale=False): """Test function without augmentaiton.""" img_feats = self.extract_feat(img=img, img_metas=img_metas) bbox_list = [dict() for i in range(len(img_metas))] bbox_pts = self.simple_test_pts(img_feats, img_metas, rescale=rescale) for result_dict, pts_bbox in zip(bbox_list, bbox_pts): result_dict['pts_bbox'] = pts_bbox return bbox_list def aug_test_pts(self, feats, img_metas, rescale=False): feats_list = [] for j in range(len(feats[0])): feats_list_level = [] for i in range(len(feats)): feats_list_level.append(feats[i][j]) feats_list.append(torch.stack(feats_list_level, -1).mean(-1)) outs = self.pts_bbox_head(feats_list, img_metas) bbox_list = self.pts_bbox_head.get_bboxes( outs, img_metas, rescale=rescale) bbox_results = [ bbox3d2result(bboxes, scores, labels) for bboxes, scores, labels in bbox_list ] return bbox_results def aug_test(self, img_metas, imgs=None, rescale=False): """Test function with augmentaiton.""" img_feats = self.extract_feats(img_metas, imgs) img_metas = img_metas[0] bbox_list = [dict() for i in range(len(img_metas))] bbox_pts = self.aug_test_pts(img_feats, img_metas, rescale) for result_dict, pts_bbox in zip(bbox_list, bbox_pts): result_dict['pts_bbox'] = pts_bbox return bbox_list # may need speed-up def add_lidar2img(self, img, batch_input_metas): """add 'lidar2img' transformation matrix into batch_input_metas. Args: batch_input_metas (list[dict]): Meta information of multiple inputs in a batch. Returns: batch_input_metas (list[dict]): Meta info with lidar2img added """ lidar2img_rts = [] for meta in batch_input_metas: # obtain lidar to image transformation matrix for i in range(len(meta['cam2img'])): lidar2cam_rt = torch.tensor(meta['lidar2cam'][i]).double() intrinsic = torch.tensor(meta['cam2img'][i]).double() viewpad = torch.eye(4).double() viewpad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic lidar2img_rt = (viewpad @ lidar2cam_rt) # The extrinsics mean the transformation from lidar to camera. # If anyone want to use the extrinsics as sensor to lidar, # please use np.linalg.inv(lidar2cam_rt.T) # and modify the ResizeCropFlipImage # and LoadMultiViewImageFromMultiSweepsFiles. lidar2img_rts.append(lidar2img_rt) meta['lidar2img'] = lidar2img_rts meta['img_shape'] = [i.shape for i in img[0]] return batch_input_metas def LidarBox3dVersionTransfrom(self, gt_bboxes_3d): if int(mmdet3d.__version__[0]) >= 1: # Begin hack adaptation to mmdet3d v1.0 #### gt_bboxes_3d = gt_bboxes_3d[0].tensor gt_bboxes_3d[:, [3, 4]] = gt_bboxes_3d[:, [4, 3]] gt_bboxes_3d[:, 6] = -gt_bboxes_3d[:, 6] - np.pi / 2 gt_bboxes_3d[:, 6] = limit_period( gt_bboxes_3d[:, 6], period=np.pi * 2) gt_bboxes_3d = LiDARInstance3DBoxes(gt_bboxes_3d, box_dim=9) gt_bboxes_3d = [gt_bboxes_3d] return gt_bboxes_3d