import torch
import torch.nn.functional as F

from mmdet3d.ops import Voxelization
from mmdet.models import DETECTORS, TwoStageDetector
from .. import builder


@DETECTORS.register_module()
class PartA2(TwoStageDetector):
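    """Part-A2 two-stage detector.

    The raw point cloud is voxelized and encoded; the resulting spatial
    features are processed by a 2D backbone (and neck), after which an RPN
    head generates proposals that the RoI head refines.
    """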

    def __init__(self,
                 voxel_layer,
                 voxel_encoder,
                 middle_encoder,
                 backbone,
                 neck=None,
                 rpn_head=None,
                 roi_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(PartA2, self).__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
        )
        self.voxel_layer = Voxelization(**voxel_layer)
        self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
        self.middle_encoder = builder.build_middle_encoder(middle_encoder)

    def extract_feat(self, points, img_metas):
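        """Extract features from raw points.

        The points are voxelized and encoded; the middle encoder's
        ``spatial_features`` are fed to the 2D backbone, and the neck output
        (if a neck is configured) is stored under ``neck_feats``.
        """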
        voxel_dict = self.voxelize(points)
        voxel_features = self.voxel_encoder(voxel_dict['voxels'],
                                            voxel_dict['num_points'],
                                            voxel_dict['coors'])
        batch_size = voxel_dict['coors'][-1, 0].item() + 1
        feats_dict = self.middle_encoder(voxel_features, voxel_dict['coors'],
                                         batch_size)
        x = self.backbone(feats_dict['spatial_features'])
        if self.with_neck:
            neck_feats = self.neck(x)
            feats_dict.update({'neck_feats': neck_feats})
        return feats_dict, voxel_dict

    @torch.no_grad()
    def voxelize(self, points):
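        """Apply hard voxelization to a list of per-sample point clouds.

        Returns a dict with the stacked ``voxels``, per-voxel ``num_points``,
        batch-indexed ``coors`` and ``voxel_centers``.
        """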
        voxels, coors, num_points, voxel_centers = [], [], [], []
        for res in points:
            res_voxels, res_coors, res_num_points = self.voxel_layer(res)
            res_voxel_centers = (
                res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor(
                    self.voxel_layer.voxel_size) + res_voxels.new_tensor(
                        self.voxel_layer.point_cloud_range[0:3])
            voxels.append(res_voxels)
            coors.append(res_coors)
            num_points.append(res_num_points)
            voxel_centers.append(res_voxel_centers)

        voxels = torch.cat(voxels, dim=0)
        num_points = torch.cat(num_points, dim=0)
        voxel_centers = torch.cat(voxel_centers, dim=0)
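
        # prepend the sample index in the batch to each voxel's (z, y, x)
        # coordinate so that voxels from all samples can be concatenated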
        coors_batch = []
        for i, coor in enumerate(coors):
            coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
            coors_batch.append(coor_pad)
        coors_batch = torch.cat(coors_batch, dim=0)

        voxel_dict = dict(
            voxels=voxels,
            num_points=num_points,
            coors=coors_batch,
            voxel_centers=voxel_centers)
        return voxel_dict

    def forward_train(self,
                      points,
                      img_metas,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      gt_bboxes_ignore=None,
                      proposals=None):
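        """Training forward pass.

        Computes RPN losses and proposals when an RPN head is present
        (otherwise ``proposals`` must be supplied) and forwards them,
        together with the extracted features, to the RoI head.
        """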
        feats_dict, voxels_dict = self.extract_feat(points, img_metas)

        losses = dict()

        if self.with_rpn:
            rpn_outs = self.rpn_head(feats_dict['neck_feats'])
            rpn_loss_inputs = rpn_outs + (gt_bboxes_3d, gt_labels_3d,
                                          img_metas)
            rpn_losses = self.rpn_head.loss(
                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
            losses.update(rpn_losses)

            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            proposal_inputs = rpn_outs + (img_metas, proposal_cfg)
            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
        else:
            proposal_list = proposals

        roi_losses = self.roi_head.forward_train(feats_dict, voxels_dict,
                                                 img_metas, proposal_list,
                                                 gt_bboxes_3d, gt_labels_3d)

        losses.update(roi_losses)

        return losses

    def forward_test(self, points, img_metas, imgs=None, **kwargs):
        """
        Args:
            points (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxC,
                which contains all points in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch
        """
        for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(points)
        if num_augs != len(img_metas):
            raise ValueError(
                'number of augmentations ({}) != number of image metas '
                '({})'.format(len(points), len(img_metas)))
        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        samples_per_gpu = len(points[0])
        assert samples_per_gpu == 1

        if num_augs == 1:
            return self.simple_test(points[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(points, img_metas, **kwargs)

    def forward(self, return_loss=True, **kwargs):
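        """Dispatch to :meth:`forward_train` or :meth:`forward_test`
        depending on ``return_loss``."""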
        if return_loss:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)

    def simple_test(self, points, img_metas, proposals=None, rescale=False):
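        """Test without test-time augmentation."""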
        feats_dict, voxels_dict = self.extract_feat(points, img_metas)

        if self.with_rpn:
            rpn_outs = self.rpn_head(feats_dict['neck_feats'])
            proposal_cfg = self.test_cfg.rpn
            bbox_inputs = rpn_outs + (img_metas, proposal_cfg)
            proposal_list = self.rpn_head.get_bboxes(*bbox_inputs)
        else:
            proposal_list = proposals

        return self.roi_head.simple_test(feats_dict, voxels_dict, img_metas,
                                         proposal_list)

    def aug_test(self, **kwargs):
        raise NotImplementedError
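

# Usage sketch (illustrative, not part of the original module). The detector
# is normally built from a config through the registry, roughly:
#
#   from mmdet.models import build_detector
#   model = build_detector(cfg.model, train_cfg=cfg.train_cfg,
#                          test_cfg=cfg.test_cfg)
#   losses = model(return_loss=True, points=points, img_metas=img_metas,
#                  gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels_3d)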