# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
from torch import Tensor
from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig
from .voxelnet import VoxelNet


@MODELS.register_module()
class DynamicVoxelNet(VoxelNet):
    r"""VoxelNet using `dynamic voxelization
    <https://arxiv.org/abs/1910.06528>`_.

    All constructor arguments are forwarded unchanged to :class:`VoxelNet`.
    This subclass only overrides :meth:`extract_feat`: the voxel encoder is
    called with the voxels and their coordinates, and it returns both the
    voxel features and the coordinates those features belong to.

    Args:
        voxel_encoder (:obj:`ConfigType`): Config of the voxel encoder.
        middle_encoder (:obj:`ConfigType`): Config of the middle encoder.
        backbone (:obj:`ConfigType`): Config of the backbone.
        neck (:obj:`OptConfigType`): Config of the neck. Defaults to None.
        bbox_head (:obj:`OptConfigType`): Config of the bbox head.
            Defaults to None.
        train_cfg (:obj:`OptConfigType`): Training config. Defaults to None.
        test_cfg (:obj:`OptConfigType`): Testing config. Defaults to None.
        data_preprocessor (:obj:`OptConfigType`): Config of the data
            preprocessor. Defaults to None.
        init_cfg (:obj:`OptMultiConfig`): Initialization config.
            Defaults to None.
    """

    def __init__(self,
                 voxel_encoder: ConfigType,
                 middle_encoder: ConfigType,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 bbox_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            voxel_encoder=voxel_encoder,
            middle_encoder=middle_encoder,
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

    def extract_feat(self, batch_inputs_dict: dict) -> Tuple[Tensor, ...]:
        """Extract features from dynamically voxelized points.

        Args:
            batch_inputs_dict (dict): Batched model inputs. Only the
                ``'voxels'`` entry is read. It must be a dict holding
                ``'voxels'`` and ``'coors'``. Column 0 of ``'coors'`` is
                treated as the index of the sample within the batch.

        Returns:
            tuple[Tensor]: Backbone outputs, passed through the neck when
            one is configured.
        """
        voxel_dict = batch_inputs_dict['voxels']
        coors = voxel_dict['coors']
        voxel_features, feature_coors = self.voxel_encoder(
            voxel_dict['voxels'], coors)
        # Use the largest batch index rather than the last row's index so the
        # result does not depend on ``coors`` being sorted by sample.
        # NOTE(review): a trailing sample that produced no points is still
        # not counted here; confirm upstream guarantees non-empty samples.
        batch_size = int(coors[:, 0].max().item()) + 1
        x = self.middle_encoder(voxel_features, feature_coors, batch_size)
        x = self.backbone(x)
        if self.with_neck:
            x = self.neck(x)
        return x