parta2.py 4.86 KB
Newer Older
wuyuefeng's avatar
wuyuefeng committed
1
import torch
zhangwenwei's avatar
zhangwenwei committed
2
from torch.nn import functional as F
wuyuefeng's avatar
wuyuefeng committed
3
4

from mmdet3d.ops import Voxelization
zhangwenwei's avatar
zhangwenwei committed
5
from mmdet.models import DETECTORS
wuyuefeng's avatar
wuyuefeng committed
6
from .. import builder
zhangwenwei's avatar
zhangwenwei committed
7
from .two_stage import TwoStage3DDetector
wuyuefeng's avatar
wuyuefeng committed
8
9


10
@DETECTORS.register_module()
class PartA2(TwoStage3DDetector):
    """Part-A2 detector.

    Please refer to the `paper <https://arxiv.org/abs/1907.03670>`_

    Data flow implemented below: points -> ``voxel_layer`` ->
    ``voxel_encoder`` -> ``middle_encoder`` -> ``backbone`` -> optional
    ``neck`` -> ``rpn_head`` (proposals) -> ``roi_head``.

    Args:
        voxel_layer (dict): Keyword arguments forwarded to
            :class:`Voxelization`.
        voxel_encoder (dict): Config passed to
            ``builder.build_voxel_encoder``.
        middle_encoder (dict): Config passed to
            ``builder.build_middle_encoder``.
        backbone (dict): Backbone config, forwarded to the parent class.
        neck (dict, optional): Neck config, forwarded to the parent class.
        rpn_head (dict, optional): RPN head config, forwarded to the parent
            class.
        roi_head (dict, optional): RoI head config, forwarded to the parent
            class.
        train_cfg (dict, optional): Training config. Its ``'rpn_proposal'``
            entry (if present) configures proposal generation in training.
        test_cfg (dict, optional): Testing config. Its ``rpn`` attribute
            configures proposal generation at test time.
        pretrained (optional): Forwarded unchanged to the parent class.
    """

    def __init__(self,
                 voxel_layer,
                 voxel_encoder,
                 middle_encoder,
                 backbone,
                 neck=None,
                 rpn_head=None,
                 roi_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
        )
        self.voxel_layer = Voxelization(**voxel_layer)
        self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
        self.middle_encoder = builder.build_middle_encoder(middle_encoder)

    def extract_feat(self, points, img_metas):
        """Extract features from a batch of point clouds.

        Args:
            points: Sequence of per-sample point tensors (iterated by
                :meth:`voxelize`; its length is taken as the batch size).
            img_metas: Per-sample meta information; not used by this method.

        Returns:
            tuple[dict, dict]: ``feats_dict`` (the middle encoder's output
            dict, extended with ``'neck_feats'``) and ``voxel_dict`` as
            returned by :meth:`voxelize`.
        """
        voxel_dict = self.voxelize(points)
        voxel_features = self.voxel_encoder(voxel_dict['voxels'],
                                            voxel_dict['num_points'],
                                            voxel_dict['coors'])
        # Use the number of input samples, not ``coors[-1, 0] + 1``: a
        # trailing sample that yields no voxels would otherwise shrink the
        # batch, and a batch with no voxels at all would fail on ``[-1, 0]``.
        batch_size = len(points)
        feats_dict = self.middle_encoder(voxel_features, voxel_dict['coors'],
                                         batch_size)
        x = self.backbone(feats_dict['spatial_features'])
        # The heads always read 'neck_feats'; without a neck, fall back to
        # the raw backbone output (as mmdet's TwoStageDetector does) instead
        # of leaving the key missing.
        neck_feats = self.neck(x) if self.with_neck else x
        feats_dict.update({'neck_feats': neck_feats})
        return feats_dict, voxel_dict

    @torch.no_grad()
    def voxelize(self, points):
        """Voxelize every point cloud and merge the results into one batch.

        Runs under ``torch.no_grad()``, so no gradients flow through it.

        Args:
            points: Sequence of per-sample point tensors, each passed to
                ``self.voxel_layer``.

        Returns:
            dict: With the following keys, each a concatenation over samples
            along dim 0:

            - ``voxels``: voxels returned by the voxel layer.
            - ``num_points``: per-voxel point counts from the voxel layer.
            - ``coors``: voxel coordinates with the sample index prepended
              as an extra leading column.
            - ``voxel_centers``: voxel centers in point-cloud coordinates,
              ``(coor + 0.5) * voxel_size + point_cloud_range[0:3]``.
        """
        voxels, coors, num_points, voxel_centers = [], [], [], []
        for batch_idx, res in enumerate(points):
            res_voxels, res_coors, res_num_points = self.voxel_layer(res)

            # Coordinate columns are reversed ([2, 1, 0]) before scaling --
            # presumably the voxel layer emits (z, y, x) while voxel_size and
            # point_cloud_range are (x, y, z); confirm against Voxelization.
            voxel_size = res_voxels.new_tensor(self.voxel_layer.voxel_size)
            range_min = res_voxels.new_tensor(
                self.voxel_layer.point_cloud_range[0:3])
            res_voxel_centers = (
                res_coors[:, [2, 1, 0]] + 0.5) * voxel_size + range_min

            voxels.append(res_voxels)
            num_points.append(res_num_points)
            voxel_centers.append(res_voxel_centers)
            # Prepend the sample index as column 0 so coordinates from
            # different samples remain distinguishable after concatenation.
            coors.append(
                F.pad(res_coors, (1, 0), mode='constant', value=batch_idx))

        voxel_dict = dict(
            voxels=torch.cat(voxels, dim=0),
            num_points=torch.cat(num_points, dim=0),
            coors=torch.cat(coors, dim=0),
            voxel_centers=torch.cat(voxel_centers, dim=0))
        return voxel_dict

    def forward_train(self,
                      points,
                      img_metas,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      gt_bboxes_ignore=None,
                      proposals=None):
        """Compute training losses.

        Args:
            points: Per-sample point clouds (see :meth:`extract_feat`).
            img_metas: Per-sample meta information.
            gt_bboxes_3d: Ground-truth 3D boxes, passed to the RPN loss and
                the RoI head.
            gt_labels_3d: Ground-truth labels, passed to the RPN loss and
                the RoI head.
            gt_bboxes_ignore (optional): Forwarded to ``rpn_head.loss``.
            proposals (optional): Precomputed proposals; used only when the
                detector has no RPN.

        Returns:
            dict: RPN losses (when an RPN is present) merged with the losses
            returned by ``roi_head.forward_train``.
        """
        feats_dict, voxels_dict = self.extract_feat(points, img_metas)

        losses = dict()

        if self.with_rpn:
            rpn_outs = self.rpn_head(feats_dict['neck_feats'])
            rpn_loss_inputs = rpn_outs + (gt_bboxes_3d, gt_labels_3d,
                                          img_metas)
            rpn_losses = self.rpn_head.loss(
                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
            losses.update(rpn_losses)

            # Resolve the fallback lazily so test_cfg is only touched when
            # train_cfg provides no 'rpn_proposal' entry.
            proposal_cfg = self.train_cfg.get('rpn_proposal')
            if proposal_cfg is None:
                proposal_cfg = self.test_cfg.rpn
            proposal_inputs = rpn_outs + (img_metas, proposal_cfg)
            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
        else:
            proposal_list = proposals

        roi_losses = self.roi_head.forward_train(feats_dict, voxels_dict,
                                                 img_metas, proposal_list,
                                                 gt_bboxes_3d, gt_labels_3d)
        losses.update(roi_losses)
        return losses

    def simple_test(self, points, img_metas, proposals=None, rescale=False):
        """Run inference without test-time augmentation.

        Args:
            points: Per-sample point clouds (see :meth:`extract_feat`).
            img_metas: Per-sample meta information.
            proposals (optional): Precomputed proposals; used only when the
                detector has no RPN.
            rescale (bool): Accepted but not used by this method.

        Returns:
            Whatever ``roi_head.simple_test`` returns.
        """
        feats_dict, voxels_dict = self.extract_feat(points, img_metas)

        if self.with_rpn:
            rpn_outs = self.rpn_head(feats_dict['neck_feats'])
            proposal_cfg = self.test_cfg.rpn
            bbox_inputs = rpn_outs + (img_metas, proposal_cfg)
            proposal_list = self.rpn_head.get_bboxes(*bbox_inputs)
        else:
            proposal_list = proposals

        return self.roi_head.simple_test(feats_dict, voxels_dict, img_metas,
                                         proposal_list)