parta2.py 5.69 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
wuyuefeng's avatar
wuyuefeng committed
2
import torch
3
from mmcv.ops import Voxelization
zhangwenwei's avatar
zhangwenwei committed
4
from torch.nn import functional as F
wuyuefeng's avatar
wuyuefeng committed
5
6

from .. import builder
7
from ..builder import DETECTORS
zhangwenwei's avatar
zhangwenwei committed
8
from .two_stage import TwoStage3DDetector
wuyuefeng's avatar
wuyuefeng committed
9
10


11
@DETECTORS.register_module()
class PartA2(TwoStage3DDetector):
    r"""Part-A2 detector.

    Points are hard-voxelized, encoded by a voxel encoder and a middle
    encoder, passed through a backbone (and optional neck) whose output
    feeds the RPN head, and the resulting proposals are refined by an RoI
    head that also receives the voxel information.

    Please refer to the `paper <https://arxiv.org/abs/1907.03670>`_.
    """

    def __init__(self,
                 voxel_layer,
                 voxel_encoder,
                 middle_encoder,
                 backbone,
                 neck=None,
                 rpn_head=None,
                 roi_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        """Build the detector.

        Args:
            voxel_layer (dict): Keyword arguments for ``Voxelization``.
            voxel_encoder (dict): Config of the voxel encoder.
            middle_encoder (dict): Config of the middle encoder.
            backbone (dict): Config of the backbone.
            neck (dict, optional): Config of the neck. Defaults to None.
            rpn_head (dict, optional): Config of the RPN head.
                Defaults to None.
            roi_head (dict, optional): Config of the RoI head.
                Defaults to None.
            train_cfg (dict, optional): Training config. Defaults to None.
            test_cfg (dict, optional): Testing config. Defaults to None.
            pretrained (str, optional): Pretrained weights, forwarded to the
                base class. Defaults to None.
            init_cfg (dict, optional): Initialization config, forwarded to
                the base class. Defaults to None.
        """
        super().__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)
        self.voxel_layer = Voxelization(**voxel_layer)
        self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
        self.middle_encoder = builder.build_middle_encoder(middle_encoder)

    def extract_feat(self, points, img_metas):
        """Extract features from points.

        Args:
            points (list[torch.Tensor]): Point cloud of each sample.
            img_metas (list[dict]): Meta information of each sample. Not
                used here; kept for interface compatibility.

        Returns:
            tuple[dict, dict]: The feature dict produced by the middle
            encoder (with ``'neck_feats'`` added when a neck is configured)
            and the voxel dict returned by :meth:`voxelize`.
        """
        voxel_dict = self.voxelize(points)
        voxel_features = self.voxel_encoder(voxel_dict['voxels'],
                                            voxel_dict['num_points'],
                                            voxel_dict['coors'])
        # The batch size is the number of input samples. Deriving it from
        # the sample index of the last voxel (``coors[-1, 0] + 1``)
        # undercounts when trailing samples produce no voxels, and the
        # ``.item()`` call it needs forces a host-device synchronization.
        batch_size = len(points)
        feats_dict = self.middle_encoder(voxel_features, voxel_dict['coors'],
                                         batch_size)
        x = self.backbone(feats_dict['spatial_features'])
        if self.with_neck:
            neck_feats = self.neck(x)
            feats_dict.update({'neck_feats': neck_feats})
        return feats_dict, voxel_dict

    @torch.no_grad()
    def voxelize(self, points):
        """Apply hard voxelization to points.

        Args:
            points (list[torch.Tensor]): Point cloud of each sample.

        Returns:
            dict: Voxel information of the whole batch with keys:

                - voxels (torch.Tensor): Voxels of all samples, concatenated
                  along the first dimension.
                - num_points (torch.Tensor): Number of points in each voxel.
                - coors (torch.Tensor): Voxel coordinates with the sample
                  index prepended as the first column.
                - voxel_centers (torch.Tensor): Center of each voxel in
                  point-cloud coordinates.
        """
        voxels, coors, num_points, voxel_centers = [], [], [], []
        for res in points:
            res_voxels, res_coors, res_num_points = self.voxel_layer(res)
            # The [2, 1, 0] reindexing implies the voxel layer returns
            # coordinates in reversed (z, y, x) order relative to
            # voxel_size / point_cloud_range (TODO confirm). Grid index
            # + 0.5 scaled by the voxel size and offset by the range
            # minimum gives the voxel center.
            res_voxel_centers = (
                res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor(
                    self.voxel_layer.voxel_size) + res_voxels.new_tensor(
                        self.voxel_layer.point_cloud_range[0:3])
            voxels.append(res_voxels)
            coors.append(res_coors)
            num_points.append(res_num_points)
            voxel_centers.append(res_voxel_centers)

        voxels = torch.cat(voxels, dim=0)
        num_points = torch.cat(num_points, dim=0)
        voxel_centers = torch.cat(voxel_centers, dim=0)
        # Prepend the sample index as a new first column so voxels from
        # different samples remain distinguishable after concatenation.
        coors_batch = torch.cat([
            F.pad(coor, (1, 0), mode='constant', value=i)
            for i, coor in enumerate(coors)
        ], dim=0)

        voxel_dict = dict(
            voxels=voxels,
            num_points=num_points,
            coors=coors_batch,
            voxel_centers=voxel_centers)
        return voxel_dict

    def forward_train(self,
                      points,
                      img_metas,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      gt_bboxes_ignore=None,
                      proposals=None):
        """Training forward function.

        Args:
            points (list[torch.Tensor]): Point cloud of each sample.
            img_metas (list[dict]): Meta information of each sample.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes for each sample.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels for
                boxes of each sample.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                boxes to be ignored. Defaults to None.
            proposals (list, optional): Precomputed proposals, used only
                when the detector has no RPN head. Defaults to None.

        Returns:
            dict: Losses of each branch.
        """
        feats_dict, voxels_dict = self.extract_feat(points, img_metas)

        losses = dict()

        if self.with_rpn:
            rpn_outs = self.rpn_head(feats_dict['neck_feats'])
            rpn_loss_inputs = rpn_outs + (gt_bboxes_3d, gt_labels_3d,
                                          img_metas)
            rpn_losses = self.rpn_head.loss(
                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
            losses.update(rpn_losses)

            # Fall back to the test-time RPN config when no dedicated
            # training proposal config is provided.
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            proposal_inputs = rpn_outs + (img_metas, proposal_cfg)
            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
        else:
            proposal_list = proposals

        roi_losses = self.roi_head.forward_train(feats_dict, voxels_dict,
                                                 img_metas, proposal_list,
                                                 gt_bboxes_3d, gt_labels_3d)

        losses.update(roi_losses)

        return losses

    def simple_test(self, points, img_metas, proposals=None, rescale=False):
        """Test function without augmentation.

        Args:
            points (list[torch.Tensor]): Point cloud of each sample.
            img_metas (list[dict]): Meta information of each sample.
            proposals (list, optional): Precomputed proposals, used only
                when the detector has no RPN head. Defaults to None.
            rescale (bool, optional): Accepted for interface compatibility;
                not used by this method. Defaults to False.

        Returns:
            The output of ``self.roi_head.simple_test``.
        """
        feats_dict, voxels_dict = self.extract_feat(points, img_metas)

        if self.with_rpn:
            rpn_outs = self.rpn_head(feats_dict['neck_feats'])
            proposal_cfg = self.test_cfg.rpn
            bbox_inputs = rpn_outs + (img_metas, proposal_cfg)
            proposal_list = self.rpn_head.get_bboxes(*bbox_inputs)
        else:
            proposal_list = proposals

        return self.roi_head.simple_test(feats_dict, voxels_dict, img_metas,
                                         proposal_list)