# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional

import torch
from mmcv.ops import Voxelization
from torch.nn import functional as F

from mmdet3d.registry import MODELS
from .two_stage import TwoStage3DDetector


@MODELS.register_module()
class PartA2(TwoStage3DDetector):
    r"""Part-A2 detector.

    Please refer to the `paper <https://arxiv.org/abs/1907.03670>`_.
    """

    def __init__(self,
                 voxel_layer: dict,
                 voxel_encoder: dict,
                 middle_encoder: dict,
                 backbone: dict,
                 neck: Optional[dict] = None,
                 rpn_head: Optional[dict] = None,
                 roi_head: Optional[dict] = None,
                 train_cfg: Optional[dict] = None,
                 test_cfg: Optional[dict] = None,
                 init_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None) -> None:
        super(PartA2, self).__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg,
            data_preprocessor=data_preprocessor)
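        # Components specific to voxel-based two-stage detectors: raw points
        # are grouped into voxels, encoded per voxel, and turned into spatial
        # features by the middle encoder before reaching the backbone.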
        self.voxel_layer = Voxelization(**voxel_layer)
        self.voxel_encoder = MODELS.build(voxel_encoder)
        self.middle_encoder = MODELS.build(middle_encoder)

    def extract_feat(self, batch_inputs_dict: Dict) -> Dict:
        """Directly extract features from the backbone+neck.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                'points', 'imgs' keys.

                - points (list[torch.Tensor]): Point cloud of each sample.
                - imgs (torch.Tensor, optional): Image of each sample.

        Returns:
            tuple[Tensor] | dict:  For outside 3D object detection, we
                typically obtain a tuple of features from the backbone + neck,
                and for inside 3D object detection, usually a dict containing
                features will be obtained.
        """
        points = batch_inputs_dict['points']
        voxel_dict = self.voxelize(points)
        voxel_features = self.voxel_encoder(voxel_dict['voxels'],
                                            voxel_dict['num_points'],
                                            voxel_dict['coors'])
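        # The first column of ``coors`` holds the batch index (prepended in
        # ``self.voxelize``), so the largest index + 1 gives the batch size.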
        batch_size = voxel_dict['coors'][-1, 0].item() + 1
        feats_dict = self.middle_encoder(voxel_features, voxel_dict['coors'],
                                         batch_size)
        x = self.backbone(feats_dict['spatial_features'])
        if self.with_neck:
            neck_feats = self.neck(x)
            feats_dict.update({'neck_feats': neck_feats})
        feats_dict['voxels_dict'] = voxel_dict
        return feats_dict

    @torch.no_grad()
    def voxelize(self, points: List[torch.Tensor]) -> Dict:
        """Apply hard voxelization to points."""
        voxels, coors, num_points, voxel_centers = [], [], [], []
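        # Voxelize each sample separately: for every point cloud the hard
        # voxelization layer returns the padded per-voxel points, the integer
        # (z, y, x) voxel coordinates and the real point count per voxel.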
        for res in points:
            res_voxels, res_coors, res_num_points = self.voxel_layer(res)
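            # Convert integer voxel indices (stored as (z, y, x)) to metric
            # voxel-center coordinates: reorder to (x, y, z), shift by half a
            # voxel, scale by the voxel size and offset by the minimum corner
            # of the point cloud range.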
            res_voxel_centers = (
                res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor(
                    self.voxel_layer.voxel_size) + res_voxels.new_tensor(
                        self.voxel_layer.point_cloud_range[0:3])
            voxels.append(res_voxels)
            coors.append(res_coors)
            num_points.append(res_num_points)
            voxel_centers.append(res_voxel_centers)

        voxels = torch.cat(voxels, dim=0)
        num_points = torch.cat(num_points, dim=0)
        voxel_centers = torch.cat(voxel_centers, dim=0)
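        # Prepend the batch index to each voxel's (z, y, x) coordinate so the
        # coordinates of all samples can be concatenated into a single tensor.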
        coors_batch = []
        for i, coor in enumerate(coors):
            coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
            coors_batch.append(coor_pad)
        coors_batch = torch.cat(coors_batch, dim=0)

        voxel_dict = dict(
            voxels=voxels,
            num_points=num_points,
            coors=coors_batch,
            voxel_centers=voxel_centers)
        return voxel_dict
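

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the upstream module): a minimal
# example of the hard-voxelization step that ``PartA2.voxelize`` builds on,
# assuming an mmcv build that ships the ``Voxelization`` op on the current
# device. The voxel size, point cloud range and point count below are
# hypothetical KITTI-style values, not settings prescribed by this file.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    voxel_layer = Voxelization(
        voxel_size=[0.05, 0.05, 0.1],
        point_cloud_range=[0, -40, -3, 70.4, 40, 1],
        max_num_points=5,
        max_voxels=16000)
    # Random (x, y, z, intensity) points drawn inside the point cloud range.
    points = (
        torch.rand(1000, 4) * torch.tensor([70.4, 80.0, 4.0, 1.0]) +
        torch.tensor([0.0, -40.0, -3.0, 0.0]))
    voxels, coors, num_points = voxel_layer(points)
    # voxels: (num_voxels, max_num_points, 4) padded point features,
    # coors: (num_voxels, 3) integer (z, y, x) indices,
    # num_points: (num_voxels,) real point count per voxel.
    print(voxels.shape, coors.shape, num_points.shape)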