dynamic_voxelnet.py 2.67 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
from typing import List, Tuple

zhangwenwei's avatar
zhangwenwei committed
4
import torch
5
from mmcv.runner import force_fp32
6
from torch import Tensor
zhangwenwei's avatar
zhangwenwei committed
7
from torch.nn import functional as F
zhangwenwei's avatar
zhangwenwei committed
8

9
from mmdet3d.registry import MODELS
zhangshilong's avatar
zhangshilong committed
10
from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig
zhangwenwei's avatar
zhangwenwei committed
11
12
13
from .voxelnet import VoxelNet


14
@MODELS.register_module()
class DynamicVoxelNet(VoxelNet):
    r"""VoxelNet using `dynamic voxelization <https://arxiv.org/abs/1910.06528>`_.

    The voxel layer is only used to compute a point-to-voxel coordinate
    mapping (see :meth:`voxelize`); the concatenated raw points and their
    batch-indexed coordinates are then handed to the voxel encoder.
    """

    def __init__(self,
                 voxel_layer: ConfigType,
                 voxel_encoder: ConfigType,
                 middle_encoder: ConfigType,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 bbox_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        """Build the detector; every argument is forwarded unchanged to
        :class:`VoxelNet`.

        Args:
            voxel_layer (ConfigType): Config of the voxelization layer.
            voxel_encoder (ConfigType): Config of the voxel encoder.
            middle_encoder (ConfigType): Config of the middle encoder.
            backbone (ConfigType): Config of the backbone.
            neck (OptConfigType): Config of the neck. Defaults to None.
            bbox_head (OptConfigType): Config of the bbox head.
                Defaults to None.
            train_cfg (OptConfigType): Training config. Defaults to None.
            test_cfg (OptConfigType): Testing config. Defaults to None.
            data_preprocessor (OptConfigType): Config of the data
                preprocessor. Defaults to None.
            init_cfg (OptMultiConfig): Initialization config.
                Defaults to None.
        """
        super().__init__(
            voxel_layer=voxel_layer,
            voxel_encoder=voxel_encoder,
            middle_encoder=middle_encoder,
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

    @torch.no_grad()
    @force_fp32()
    def voxelize(self, points: List[Tensor]) -> Tuple[Tensor, Tensor]:
        """Apply dynamic voxelization to points.

        Args:
            points (list[Tensor]): Points of each sample.

        Returns:
            tuple[Tensor, Tensor]:

                - Points of all samples concatenated along dim 0.
                - Voxel coordinates of every point, with the index of the
                  sample it came from prepended along the last dimension
                  (i.e. as column 0).
        """
        coors_batch = []
        for batch_idx, res in enumerate(points):
            # dynamic voxelization only provides a coors mapping per point
            res_coors = self.voxel_layer(res)
            # prepend the sample index so points can be told apart after
            # concatenation
            coors_batch.append(
                F.pad(res_coors, (1, 0), mode='constant', value=batch_idx))
        points = torch.cat(points, dim=0)
        coors_batch = torch.cat(coors_batch, dim=0)
        return points, coors_batch

    def extract_feat(self, batch_inputs_dict: dict) -> Tuple[Tensor]:
        """Extract features from points.

        Args:
            batch_inputs_dict (dict): Batch inputs. Must contain the key
                ``'points'`` holding one point tensor per sample.

        Returns:
            tuple[Tensor]: Output of the backbone, or of the neck when one
            is configured.
        """
        # TODO: Remove voxelization to datapreprocessor
        points = batch_inputs_dict['points']
        voxels, coors = self.voxelize(points)
        voxel_features, feature_coors = self.voxel_encoder(voxels, coors)
        # Use the number of samples rather than the last batch index found
        # in ``coors``: a trailing sample with no points adds no rows to
        # ``coors`` and would otherwise be dropped from the batch. This also
        # avoids the device sync forced by ``.item()``.
        batch_size = len(points)
        x = self.middle_encoder(voxel_features, feature_coors, batch_size)
        x = self.backbone(x)
        if self.with_neck:
            x = self.neck(x)
        return x