# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

from torch import Tensor

from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig
from .voxelnet import VoxelNet


@MODELS.register_module()
class DynamicVoxelNet(VoxelNet):
    r"""VoxelNet using `dynamic voxelization
    <https://arxiv.org/abs/1910.06528>`_.

    The constructor only forwards its configuration to :class:`VoxelNet`;
    this subclass differs in how :meth:`extract_feat` feeds the voxel
    encoder and derives the batch size from voxel coordinates.
    """

    def __init__(self,
                 voxel_encoder: ConfigType,
                 middle_encoder: ConfigType,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 bbox_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        """Build the detector.

        All arguments are forwarded unchanged, by keyword, to
        :class:`VoxelNet`.

        Args:
            voxel_encoder (dict): Config of the voxel encoder.
            middle_encoder (dict): Config of the middle encoder.
            backbone (dict): Config of the backbone.
            neck (dict, optional): Config of the neck. Defaults to None.
            bbox_head (dict, optional): Config of the bbox head.
                Defaults to None.
            train_cfg (dict, optional): Training config. Defaults to None.
            test_cfg (dict, optional): Testing config. Defaults to None.
            data_preprocessor (dict, optional): Config of the data
                preprocessor. Defaults to None.
            init_cfg (dict or list[dict], optional): Initialization config.
                Defaults to None.
        """
        super().__init__(
            voxel_encoder=voxel_encoder,
            middle_encoder=middle_encoder,
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

    def extract_feat(self, batch_inputs_dict: dict) -> Tuple[Tensor]:
        """Extract features from points.

        Args:
            batch_inputs_dict (dict): Batched model inputs. Must contain a
                ``'voxels'`` entry, itself a dict holding ``'voxels'`` (the
                first input of the voxel encoder) and ``'coors'`` (a 2-D
                tensor whose first column is read as the sample index
                within the batch).

        Returns:
            tuple[Tensor]: Features produced by the backbone, further
            processed by the neck when one is configured.
        """
        voxel_dict = batch_inputs_dict['voxels']
        voxel_features, feature_coors = self.voxel_encoder(
            voxel_dict['voxels'], voxel_dict['coors'])
        # Take the largest sample index instead of the last row's so the
        # result does not rely on ``coors`` being sorted by sample; for
        # sorted coordinates both give the same value.
        # NOTE(review): trailing samples that contribute no coordinates at
        # all cannot be detected from ``coors`` alone.
        batch_size = voxel_dict['coors'][:, 0].max().item() + 1
        x = self.middle_encoder(voxel_features, feature_coors, batch_size)
        x = self.backbone(x)
        if self.with_neck:
            x = self.neck(x)
        return x