# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
from mmcv.ops import Voxelization
from mmcv.runner import force_fp32
from torch import Tensor
from torch.nn import functional as F

from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig
from .single_stage import SingleStage3DDetector


@MODELS.register_module()
class VoxelNet(SingleStage3DDetector):
    r"""`VoxelNet <https://arxiv.org/abs/1711.06396>`_ for 3D detection.

    Each input point cloud is hard-voxelized with
    :class:`mmcv.ops.Voxelization`. The voxels are encoded by
    ``voxel_encoder``, then passed to ``middle_encoder`` together with the
    batched voxel coordinates and the batch size, and finally fed through
    ``backbone`` (and ``neck`` when one is configured).

    Args:
        voxel_layer (dict): Keyword arguments forwarded to
            :class:`mmcv.ops.Voxelization`.
        voxel_encoder (dict): Config built with ``MODELS.build``.
        middle_encoder (dict): Config built with ``MODELS.build``.
        backbone (dict): Forwarded to :class:`SingleStage3DDetector`.
        neck (dict, optional): Forwarded to :class:`SingleStage3DDetector`.
        bbox_head (dict, optional): Forwarded to
            :class:`SingleStage3DDetector`.
        train_cfg (dict, optional): Forwarded to
            :class:`SingleStage3DDetector`.
        test_cfg (dict, optional): Forwarded to
            :class:`SingleStage3DDetector`.
        data_preprocessor (dict, optional): Forwarded to
            :class:`SingleStage3DDetector`.
        init_cfg (dict or list[dict], optional): Forwarded to
            :class:`SingleStage3DDetector`.
    """

    def __init__(self,
                 voxel_layer: ConfigType,
                 voxel_encoder: ConfigType,
                 middle_encoder: ConfigType,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 bbox_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)
        self.voxel_layer = Voxelization(**voxel_layer)
        self.voxel_encoder = MODELS.build(voxel_encoder)
        self.middle_encoder = MODELS.build(middle_encoder)

    @torch.no_grad()
    @force_fp32()
    def voxelize(self, points: List[torch.Tensor]) -> tuple:
        """Apply hard voxelization to points.

        Each sample is voxelized independently by ``self.voxel_layer``. The
        per-sample results are then concatenated along dim 0. Runs without
        gradient tracking.

        Args:
            points (list[Tensor]): Per-sample point clouds.

        Returns:
            tuple[Tensor, Tensor, Tensor]: ``(voxels, num_points, coors)``.
            ``coors`` has the index of the originating sample prepended as
            its first column, so voxels from different samples remain
            distinguishable after concatenation.
        """
        voxels, coors, num_points = [], [], []
        for res in points:
            res_voxels, res_coors, res_num_points = self.voxel_layer(res)
            voxels.append(res_voxels)
            coors.append(res_coors)
            num_points.append(res_num_points)
        voxels = torch.cat(voxels, dim=0)
        num_points = torch.cat(num_points, dim=0)
        # F.pad with (1, 0) pads the last dimension by one on the left,
        # i.e. prepends a column filled with the sample index ``i``.
        coors_batch = torch.cat(
            [
                F.pad(coor, (1, 0), mode='constant', value=i)
                for i, coor in enumerate(coors)
            ],
            dim=0)
        return voxels, num_points, coors_batch

    def extract_feat(self, batch_inputs_dict: dict) -> Tuple[Tensor]:
        """Extract features from points.

        Args:
            batch_inputs_dict (dict): Must contain ``'points'``, the list of
                per-sample point clouds.

        Returns:
            tuple[Tensor]: Output of ``self.backbone``, passed through
            ``self.neck`` when ``self.with_neck`` is true.
        """
        # TODO: Remove voxelization to datapreprocessor
        points = batch_inputs_dict['points']
        voxels, num_points, coors = self.voxelize(points)
        voxel_features = self.voxel_encoder(voxels, num_points, coors)
        # ``voxelize`` tags voxels with sample indices 0..len(points)-1, so the
        # batch size is the number of input samples. Deriving it from
        # ``coors[-1, 0] + 1`` undercounts when the trailing sample(s) yield
        # no voxels, and raises IndexError when no sample yields any.
        batch_size = len(points)
        x = self.middle_encoder(voxel_features, coors, batch_size)
        x = self.backbone(x)
        if self.with_neck:
            x = self.neck(x)
        return x