# Copyright (c) OpenMMLab. All rights reserved. from typing import List, Tuple, Union import torch from mmdet3d.models.layers.fusion_layers.point_fusion import point_sample from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.structures import Det3DDataSample from mmdet3d.structures.det3d_data_sample import SampleList from mmdet3d.utils import ConfigType, InstanceList, OptConfigType from mmdet.models.detectors import BaseDetector @MODELS.register_module() class ImVoxelNet(BaseDetector): r"""`ImVoxelNet `_. Args: backbone (:obj:`ConfigDict` or dict): The backbone config. neck (:obj:`ConfigDict` or dict): The neck config. neck_3d (:obj:`ConfigDict` or dict): The 3D neck config. bbox_head (:obj:`ConfigDict` or dict): The bbox head config. n_voxels (list): Number of voxels along x, y, z axis. anchor_generator (:obj:`ConfigDict` or dict): The anchor generator config. train_cfg (:obj:`ConfigDict` or dict, optional): Config dict of training hyper-parameters. Defaults to None. test_cfg (:obj:`ConfigDict` or dict, optional): Config dict of test hyper-parameters. Defaults to None. data_preprocessor (dict or ConfigDict, optional): The pre-process config of :class:`BaseDataPreprocessor`. it usually includes, ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``. init_cfg (:obj:`ConfigDict` or dict, optional): The initialization config. Defaults to None. """ def __init__(self, backbone: ConfigType, neck: ConfigType, neck_3d: ConfigType, bbox_head: ConfigType, n_voxels: List, anchor_generator: ConfigType, train_cfg: OptConfigType = None, test_cfg: OptConfigType = None, data_preprocessor: OptConfigType = None, init_cfg: OptConfigType = None): super().__init__( data_preprocessor=data_preprocessor, init_cfg=init_cfg) self.backbone = MODELS.build(backbone) self.neck = MODELS.build(neck) self.neck_3d = MODELS.build(neck_3d) bbox_head.update(train_cfg=train_cfg) bbox_head.update(test_cfg=test_cfg) self.bbox_head = MODELS.build(bbox_head) self.n_voxels = n_voxels self.anchor_generator = TASK_UTILS.build(anchor_generator) self.train_cfg = train_cfg self.test_cfg = test_cfg def convert_to_datasample(self, results_list: InstanceList) -> SampleList: """Convert results list to `Det3DDataSample`. Args: results_list (list[:obj:`InstanceData`]): 3D Detection results of each image. Returns: list[:obj:`Det3DDataSample`]: 3D Detection results of the input images. Each Det3DDataSample usually contain 'pred_instances_3d'. And the ``pred_instances_3d`` usually contains following keys. - scores_3d (Tensor): Classification scores, has a shape (num_instance, ) - labels_3d (Tensor): Labels of bboxes, has a shape (num_instances, ). - bboxes_3d (Tensor): Contains a tensor with shape (num_instances, C) where C >=7. """ out_results_list = [] for i in range(len(results_list)): result = Det3DDataSample() result.pred_instances_3d = results_list[i] out_results_list.append(result) return out_results_list def extract_feat(self, batch_inputs_dict: dict, batch_data_samples: SampleList): """Extract 3d features from the backbone -> fpn -> 3d projection. Args: batch_inputs_dict (dict): The model input dict which include the 'imgs' key. - imgs (torch.Tensor, optional): Image of each sample. batch_data_samples (list[:obj:`DetDataSample`]): The batch data samples. It usually includes information such as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. Returns: torch.Tensor: of shape (N, C_out, N_x, N_y, N_z) """ img = batch_inputs_dict['imgs'] batch_img_metas = [ data_samples.metainfo for data_samples in batch_data_samples ] x = self.backbone(img) x = self.neck(x)[0] points = self.anchor_generator.grid_anchors( [self.n_voxels[::-1]], device=img.device)[0][:, :3] volumes = [] for feature, img_meta in zip(x, batch_img_metas): img_scale_factor = ( points.new_tensor(img_meta['scale_factor'][:2]) if 'scale_factor' in img_meta.keys() else 1) img_flip = img_meta['flip'] if 'flip' in img_meta.keys() else False img_crop_offset = ( points.new_tensor(img_meta['img_crop_offset']) if 'img_crop_offset' in img_meta.keys() else 0) lidar2img = points.new_tensor(img_meta['lidar2img']) volume = point_sample( img_meta, img_features=feature[None, ...], points=points, proj_mat=lidar2img, coord_type='LIDAR', img_scale_factor=img_scale_factor, img_crop_offset=img_crop_offset, img_flip=img_flip, img_pad_shape=img.shape[-2:], img_shape=img_meta['img_shape'][:2], aligned=False) volumes.append( volume.reshape(self.n_voxels[::-1] + [-1]).permute(3, 2, 1, 0)) x = torch.stack(volumes) x = self.neck_3d(x) return x def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList, **kwargs) -> Union[dict, list]: """Calculate losses from a batch of inputs and data samples. Args: batch_inputs_dict (dict): The model input dict which include the 'imgs' key. - imgs (torch.Tensor, optional): Image of each sample. batch_data_samples (list[:obj:`DetDataSample`]): The batch data samples. It usually includes information such as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. Returns: dict: A dictionary of loss components. """ x = self.extract_feat(batch_inputs_dict, batch_data_samples) losses = self.bbox_head.loss(x, batch_data_samples, **kwargs) return losses def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList, **kwargs) -> SampleList: """Predict results from a batch of inputs and data samples with post- processing. Args: batch_inputs_dict (dict): The model input dict which include the 'imgs' key. - imgs (torch.Tensor, optional): Image of each sample. batch_data_samples (List[:obj:`Det3DDataSample`]): The Data Samples. It usually includes information such as `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`. Returns: list[:obj:`Det3DDataSample`]: Detection results of the input images. Each Det3DDataSample usually contain 'pred_instances_3d'. And the ``pred_instances_3d`` usually contains following keys. - scores_3d (Tensor): Classification scores, has a shape (num_instance, ) - labels_3d (Tensor): Labels of bboxes, has a shape (num_instances, ). - bboxes_3d (Tensor): Contains a tensor with shape (num_instances, C) where C >=7. """ x = self.extract_feat(batch_inputs_dict, batch_data_samples) results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs) predictions = self.convert_to_datasample(results_list) return predictions def _forward(self, batch_inputs_dict: dict, batch_data_samples: SampleList, *args, **kwargs) -> Tuple[List[torch.Tensor]]: """Network forward process. Usually includes backbone, neck and head forward without any post-processing. Args: batch_inputs_dict (dict): The model input dict which include the 'imgs' key. - imgs (torch.Tensor, optional): Image of each sample. batch_data_samples (List[:obj:`Det3DDataSample`]): The Data Samples. It usually includes information such as `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`. Returns: tuple[list]: A tuple of features from ``bbox_head`` forward. """ x = self.extract_feat(batch_inputs_dict, batch_data_samples) results = self.bbox_head.forward(x) return results