voxelnet.py 2.73 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
from typing import List, Tuple
3

zhangwenwei's avatar
zhangwenwei committed
4
import torch
5
from mmcv.ops import Voxelization
6
from mmcv.runner import force_fp32
7
from torch import Tensor
zhangwenwei's avatar
zhangwenwei committed
8
from torch.nn import functional as F
zhangwenwei's avatar
zhangwenwei committed
9

10
from mmdet3d.registry import MODELS
zhangshilong's avatar
zhangshilong committed
11
from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig
zhangwenwei's avatar
zhangwenwei committed
12
from .single_stage import SingleStage3DDetector
zhangwenwei's avatar
zhangwenwei committed
13
14


15
@MODELS.register_module()
class VoxelNet(SingleStage3DDetector):
    r"""`VoxelNet <https://arxiv.org/abs/1711.06396>`_ for 3D detection.

    Points are hard-voxelized per sample, encoded into per-voxel features,
    passed through a middle encoder, then through the backbone and the
    optional neck inherited from :class:`SingleStage3DDetector`.

    Args:
        voxel_layer (dict): Keyword arguments for :class:`Voxelization`.
        voxel_encoder (dict): Config of the voxel feature encoder, built
            through ``MODELS``.
        middle_encoder (dict): Config of the middle encoder, built through
            ``MODELS``.
        backbone (dict): Config of the backbone (handled by the base class).
        neck (dict, optional): Config of the neck. Defaults to None.
        bbox_head (dict, optional): Config of the bbox head.
            Defaults to None.
        train_cfg (dict, optional): Training config. Defaults to None.
        test_cfg (dict, optional): Testing config. Defaults to None.
        data_preprocessor (dict, optional): Config of the data
            preprocessor. Defaults to None.
        init_cfg (dict or list[dict], optional): Initialization config.
            Defaults to None.
    """

    def __init__(self,
                 voxel_layer: ConfigType,
                 voxel_encoder: ConfigType,
                 middle_encoder: ConfigType,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 bbox_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)
        self.voxel_layer = Voxelization(**voxel_layer)
        self.voxel_encoder = MODELS.build(voxel_encoder)
        self.middle_encoder = MODELS.build(middle_encoder)

    @torch.no_grad()
    @force_fp32()
    def voxelize(self,
                 points: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
        """Apply hard voxelization to each sample and merge the batch.

        Args:
            points (list[Tensor]): One point-cloud tensor per sample.

        Returns:
            tuple[Tensor, Tensor, Tensor]: ``(voxels, num_points, coors)``
            concatenated over all samples. Each row of ``coors`` has the
            index of its sample (position in ``points``) prepended as an
            extra leading column.
        """
        voxels, coors, num_points = [], [], []
        for batch_idx, res in enumerate(points):
            res_voxels, res_coors, res_num_points = self.voxel_layer(res)
            voxels.append(res_voxels)
            # Prepend the sample index so voxels from different samples
            # remain distinguishable after concatenation.
            coors.append(
                F.pad(res_coors, (1, 0), mode='constant', value=batch_idx))
            num_points.append(res_num_points)
        voxels = torch.cat(voxels, dim=0)
        num_points = torch.cat(num_points, dim=0)
        coors_batch = torch.cat(coors, dim=0)
        return voxels, num_points, coors_batch

    def extract_feat(self, batch_inputs_dict: dict) -> Tuple[Tensor]:
        """Extract features from points.

        Args:
            batch_inputs_dict (dict): Batch inputs. Must contain the key
                ``'points'``, a list with one point-cloud tensor per sample.

        Returns:
            tuple[Tensor]: Output of the backbone, passed through the neck
            when one is configured.
        """
        # TODO: Remove voxelization to datapreprocessor
        points = batch_inputs_dict['points']
        voxels, num_points, coors = self.voxelize(points)
        voxel_features = self.voxel_encoder(voxels, num_points, coors)
        # ``voxelize`` tags voxels with sample indices 0..len(points)-1, so
        # the batch size is simply the number of samples. Deriving it from
        # the last voxel's index (``coors[-1, 0] + 1``) undercounts when
        # trailing samples yield no voxels, fails when no sample yields
        # any, and forces a device-to-host sync.
        batch_size = len(points)
        x = self.middle_encoder(voxel_features, coors, batch_size)
        x = self.backbone(x)
        if self.with_neck:
            x = self.neck(x)
        return x