mvx_faster_rcnn.py 2.61 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
jshilong's avatar
jshilong committed
2
3
from typing import List, Optional, Sequence

zhangwenwei's avatar
zhangwenwei committed
4
import torch
jshilong's avatar
jshilong committed
5
from torch import Tensor
zhangwenwei's avatar
zhangwenwei committed
6
from torch.nn import functional as F
zhangwenwei's avatar
zhangwenwei committed
7

8
from mmdet3d.registry import MODELS
zhangwenwei's avatar
zhangwenwei committed
9
10
11
from .mvx_two_stage import MVXTwoStageDetector


12
@MODELS.register_module()
class MVXFasterRCNN(MVXTwoStageDetector):
    """Multi-modality VoxelNet detector using a Faster R-CNN head.

    Thin registration wrapper: all behavior lives in
    :class:`MVXTwoStageDetector`; every keyword argument is forwarded
    unchanged to the parent constructor.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


20
@MODELS.register_module()
class DynamicMVXFasterRCNN(MVXTwoStageDetector):
    """Multi-modality VoxelNet using Faster R-CNN and dynamic voxelization."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @torch.no_grad()
    def voxelize(self, points):
        """Apply dynamic voxelization to points.

        Args:
            points (list[torch.Tensor]): Points of each sample.

        Returns:
            tuple[torch.Tensor]: Concatenated points and coordinates.
        """
        # Dynamic voxelization only produces a per-point voxel-coordinate
        # mapping; no points are dropped or grouped at this stage.
        per_sample_coors = [self.pts_voxel_layer(res) for res in points]
        # Prepend the sample index as an extra leading column so that
        # coordinates from different samples remain distinguishable once
        # everything is concatenated into a single tensor.
        coors_batch = torch.cat([
            F.pad(coor, (1, 0), mode='constant', value=batch_id)
            for batch_id, coor in enumerate(per_sample_coors)
        ],
                                dim=0)
        return torch.cat(points, dim=0), coors_batch

    def extract_pts_feat(
            self,
            points: List[Tensor],
            img_feats: Optional[Sequence[Tensor]] = None,
            batch_input_metas: Optional[List[dict]] = None
    ) -> Sequence[Tensor]:
        """Extract features of points.

        Args:
            points (List[tensor]):  Point cloud of multiple inputs.
            img_feats (list[Tensor], tuple[tensor], optional): Features from
                image backbone.
            batch_input_metas (list[dict], optional): The meta information
                of multiple samples. Defaults to True.

        Returns:
            Sequence[tensor]: points features of multiple inputs
            from backbone or neck.
        """
        # Guard clause: nothing to do without a point-cloud bbox head.
        if not self.with_pts_bbox:
            return None
        voxels, coors = self.voxelize(points)
        # The voxel encoder may fuse image features into the point features
        # when img_feats is provided.
        voxel_features, feature_coors = self.pts_voxel_encoder(
            voxels, coors, points, img_feats, batch_input_metas)
        # The first column of `coors` holds the sample index, so the last
        # entry + 1 equals the batch size.
        batch_size = coors[-1, 0] + 1
        feats = self.pts_middle_encoder(voxel_features, feature_coors,
                                        batch_size)
        feats = self.pts_backbone(feats)
        if self.with_pts_neck:
            feats = self.pts_neck(feats)
        return feats