dynamic_voxelnet.py 2.1 KB
Newer Older
zhangwenwei's avatar
zhangwenwei committed
1
import torch
zhangwenwei's avatar
zhangwenwei committed
2
from torch.nn import functional as F
zhangwenwei's avatar
zhangwenwei committed
3
4
5
6
7
8
9

from mmdet.models import DETECTORS
from .voxelnet import VoxelNet


@DETECTORS.register_module()
class DynamicVoxelNet(VoxelNet):
    r"""VoxelNet using `dynamic voxelization <https://arxiv.org/abs/1910.06528>`_.

    The voxel layer here only maps each point to a voxel coordinate; the raw
    points and their coordinates are then handed to ``voxel_encoder``, which
    produces the voxel features.
    """

    def __init__(self,
                 voxel_layer,
                 voxel_encoder,
                 middle_encoder,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__(
            voxel_layer=voxel_layer,
            voxel_encoder=voxel_encoder,
            middle_encoder=middle_encoder,
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
        )

    def extract_feat(self, points, img_metas):
        """Extract features from points.

        Args:
            points (list[torch.Tensor]): Points of each sample.
            img_metas: Not used by this method.

        Returns:
            The backbone output, passed through ``self.neck`` when
            ``self.with_neck`` is true.
        """
        voxels, coors = self.voxelize(points)
        voxel_features, feature_coors = self.voxel_encoder(voxels, coors)
        # Take the batch size from the per-sample input list instead of
        # ``coors[-1, 0] + 1``: the latter under-counts when the trailing
        # sample(s) contain no points and raises on an empty ``coors``.
        batch_size = len(points)
        x = self.middle_encoder(voxel_features, feature_coors, batch_size)
        x = self.backbone(x)
        if self.with_neck:
            x = self.neck(x)
        return x

    @torch.no_grad()
    def voxelize(self, points):
        """Apply dynamic voxelization to points.

        Args:
            points (list[torch.Tensor]): Points of each sample.

        Returns:
            tuple[torch.Tensor]: Concatenated points and coordinates. Each
                coordinate row is prefixed with the index of the sample the
                point came from.
        """
        coors_batch = []
        for batch_idx, res in enumerate(points):
            # Dynamic voxelization only provides a point-to-voxel coordinate
            # mapping; the points themselves are returned unchanged.
            res_coors = self.voxel_layer(res)
            # Prepend the sample index as a new first coordinate column.
            coors_batch.append(
                F.pad(res_coors, (1, 0), mode='constant', value=batch_idx))
        points = torch.cat(points, dim=0)
        coors_batch = torch.cat(coors_batch, dim=0)
        return points, coors_batch