voxelnet.py 1.83 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
from typing import Tuple
3

4
from torch import Tensor
zhangwenwei's avatar
zhangwenwei committed
5

6
from mmdet3d.registry import MODELS
zhangshilong's avatar
zhangshilong committed
7
from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig
zhangwenwei's avatar
zhangwenwei committed
8
from .single_stage import SingleStage3DDetector
zhangwenwei's avatar
zhangwenwei committed
9
10


11
@MODELS.register_module()
class VoxelNet(SingleStage3DDetector):
    r"""Single-stage 3D detector following `VoxelNet
    <https://arxiv.org/abs/1711.06396>`_.

    In addition to whatever ``SingleStage3DDetector`` sets up, this class
    builds a voxel encoder and a middle encoder from their configs through
    ``MODELS.build``.

    Args:
        voxel_encoder (ConfigType): Config passed to ``MODELS.build`` to
            create ``self.voxel_encoder``.
        middle_encoder (ConfigType): Config passed to ``MODELS.build`` to
            create ``self.middle_encoder``.
        backbone (ConfigType): Forwarded unchanged to the parent
            constructor.
        neck (OptConfigType): Forwarded unchanged to the parent
            constructor. Defaults to None.
        bbox_head (OptConfigType): Forwarded unchanged to the parent
            constructor. Defaults to None.
        train_cfg (OptConfigType): Forwarded unchanged to the parent
            constructor. Defaults to None.
        test_cfg (OptConfigType): Forwarded unchanged to the parent
            constructor. Defaults to None.
        data_preprocessor (OptConfigType): Forwarded unchanged to the
            parent constructor. Defaults to None.
        init_cfg (OptMultiConfig): Forwarded unchanged to the parent
            constructor. Defaults to None.
    """

    def __init__(self,
                 voxel_encoder: ConfigType,
                 middle_encoder: ConfigType,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 bbox_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        # Everything except the two encoders is handled by the parent.
        parent_kwargs = dict(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)
        super().__init__(**parent_kwargs)

        # Build order is kept: voxel encoder first, then middle encoder.
        self.voxel_encoder = MODELS.build(voxel_encoder)
        self.middle_encoder = MODELS.build(middle_encoder)

    def extract_feat(self, batch_inputs_dict: dict) -> Tuple[Tensor]:
        """Run voxel encoder, middle encoder, backbone and optional neck.

        Args:
            batch_inputs_dict (dict): Must provide a ``'voxels'`` entry that
                is itself indexable by the keys ``'voxels'``,
                ``'num_points'`` and ``'coors'``.

        Returns:
            Tuple[Tensor]: The backbone output, or the neck output when
            ``self.with_neck`` is truthy.
        """
        voxel_inputs = batch_inputs_dict['voxels']
        voxels = voxel_inputs['voxels']
        num_points = voxel_inputs['num_points']
        coors = voxel_inputs['coors']

        encoded = self.voxel_encoder(voxels, num_points, coors)

        # Batch size is read from column 0 of the last ``coors`` row, plus
        # one. NOTE(review): presumably column 0 is the sample index and the
        # rows are ordered by sample -- confirm against the data
        # preprocessor that produces ``coors``.
        batch_size = coors[-1, 0].item() + 1

        feats = self.middle_encoder(encoded, coors, batch_size)
        feats = self.backbone(feats)
        return self.neck(feats) if self.with_neck else feats