part_aggregation_roi_head.py 13.2 KB
Newer Older
zhangwenwei's avatar
zhangwenwei committed
1
from torch.nn import functional as F
wuyuefeng's avatar
wuyuefeng committed
2
3

from mmdet3d.core import AssignResult
zhangwenwei's avatar
zhangwenwei committed
4
from mmdet3d.core.bbox import bbox3d2result, bbox3d2roi
wuyuefeng's avatar
wuyuefeng committed
5
6
7
8
9
10
from mmdet.core import build_assigner, build_sampler
from mmdet.models import HEADS
from ..builder import build_head, build_roi_extractor
from .base_3droi_head import Base3DRoIHead


zhangwenwei's avatar
zhangwenwei committed
11
@HEADS.register_module()
wuyuefeng's avatar
wuyuefeng committed
12
class PartAggregationROIHead(Base3DRoIHead):
zhangwenwei's avatar
zhangwenwei committed
13
    """Part aggregation roi head for PartA2.

    Args:
        semantic_head (ConfigDict): Config of semantic head.
        num_classes (int): The number of classes. Defaults to 3.
        seg_roi_extractor (ConfigDict): Config of seg_roi_extractor.
        part_roi_extractor (ConfigDict): Config of part_roi_extractor.
        bbox_head (ConfigDict): Config of bbox_head.
        train_cfg (ConfigDict): Training config.
        test_cfg (ConfigDict): Testing config.
    """
wuyuefeng's avatar
wuyuefeng committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

    def __init__(self,
                 semantic_head,
                 num_classes=3,
                 seg_roi_extractor=None,
                 part_roi_extractor=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None):
        """Build the sub-modules of the RoI head.

        Args:
            semantic_head (ConfigDict): Config of semantic head. Must not
                be None.
            num_classes (int): The number of classes. Defaults to 3.
            seg_roi_extractor (ConfigDict, optional): Config of
                seg_roi_extractor. The attribute is only created when a
                config is given.
            part_roi_extractor (ConfigDict, optional): Config of
                part_roi_extractor. The attribute is only created when a
                config is given.
            bbox_head (ConfigDict, optional): Config of bbox_head,
                forwarded to the parent constructor.
            train_cfg (ConfigDict, optional): Training config, forwarded
                to the parent constructor.
            test_cfg (ConfigDict, optional): Testing config, forwarded to
                the parent constructor.
        """
        super().__init__(
            bbox_head=bbox_head, train_cfg=train_cfg, test_cfg=test_cfg)
        self.num_classes = num_classes
        assert semantic_head is not None
        self.semantic_head = build_head(semantic_head)

        # Optional extractors: an attribute is set only for configs that
        # were actually provided.
        optional_extractors = (
            ('seg_roi_extractor', seg_roi_extractor),
            ('part_roi_extractor', part_roi_extractor),
        )
        for attr_name, extractor_cfg in optional_extractors:
            if extractor_cfg is not None:
                setattr(self, attr_name, build_roi_extractor(extractor_cfg))

        self.init_assigner_sampler()

    def init_weights(self, pretrained):
        """Do nothing; ``PartAggregationROIHead`` has no weights of its own
        to initialize here.

        Args:
            pretrained: Accepted for interface compatibility; not used.
        """
        return None

    def init_mask_head(self):
        """Do nothing; ``PartAggregationROIHead`` does not have a mask
        head."""
        return None

    def init_bbox_head(self, bbox_head):
        """Build the box head and store it as ``self.bbox_head``.

        Args:
            bbox_head (ConfigDict): Config passed to ``build_head``.
        """
        built_head = build_head(bbox_head)
        self.bbox_head = built_head

    def init_assigner_sampler(self):
        """Create the bbox assigner(s) and the bbox sampler.

        Both attributes are reset to None first and are only built when
        ``self.train_cfg`` is truthy. ``train_cfg.assigner`` may be a single
        config (dict) or a list of configs; a list yields a list of
        assigners. Any other type leaves ``self.bbox_assigner`` as None.
        """
        self.bbox_assigner = None
        self.bbox_sampler = None
        if not self.train_cfg:
            return

        assigner_cfg = self.train_cfg.assigner
        if isinstance(assigner_cfg, dict):
            self.bbox_assigner = build_assigner(assigner_cfg)
        elif isinstance(assigner_cfg, list):
            self.bbox_assigner = [build_assigner(cfg) for cfg in assigner_cfg]

        self.bbox_sampler = build_sampler(self.train_cfg.sampler)

    @property
    def with_semantic(self):
        """bool: True when a ``semantic_head`` attribute exists and is not
        None."""
        return getattr(self, 'semantic_head', None) is not None

zhangwenwei's avatar
zhangwenwei committed
79
    def forward_train(self, feats_dict, voxels_dict, img_metas, proposal_list,
                      gt_bboxes_3d, gt_labels_3d):
        """Run the training forward pass of PartAggregationROIHead.

        Args:
            feats_dict (dict): Contains features from the first stage. The
                ``seg_features`` entry is read here.
            voxels_dict (dict): Contains information of voxels.
            img_metas (list[dict]): Meta info of each image. Not used by
                this method.
            proposal_list (list[dict]): Proposal information from rpn.
                The dictionary should contain the following keys:

                - boxes_3d (:obj:`BaseInstance3DBoxes`): Proposal bboxes
                - labels_3d (torch.Tensor): Labels of proposals
                - cls_preds (torch.Tensor): Original scores of proposals
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]):
                GT bboxes of each sample. The bboxes are encapsulated
                by 3D box structures.
            gt_labels_3d (list[LongTensor]): GT labels of each sample.

        Returns:
            dict: Losses from each head. The dict is built by merging the
                ``loss_semantic`` result of the semantic branch and the
                ``loss_bbox`` result of the bbox branch.

        Note:
            The bbox branch reads ``semantic_results``, which is only bound
            when the semantic branch has run.
        """
        losses = {}

        # Semantic branch: segmentation / part prediction losses.
        if self.with_semantic:
            semantic_results = self._semantic_forward_train(
                feats_dict['seg_features'], voxels_dict, gt_bboxes_3d,
                gt_labels_3d)
            semantic_losses = semantic_results['loss_semantic']
            losses.update(semantic_losses)

        # Proposals are assigned and sampled regardless of the bbox branch.
        sample_results = self._assign_and_sample(proposal_list, gt_bboxes_3d,
                                                 gt_labels_3d)

        # Bbox branch: refine the sampled proposals.
        if self.with_bbox:
            part_feats = semantic_results['part_feats']
            bbox_results = self._bbox_forward_train(
                feats_dict['seg_features'], part_feats, voxels_dict,
                sample_results)
            bbox_losses = bbox_results['loss_bbox']
            losses.update(bbox_losses)

        return losses

zhangwenwei's avatar
zhangwenwei committed
121
    def simple_test(self, feats_dict, voxels_dict, img_metas, proposal_list,
                    **kwargs):
        """Run the test-time forward pass of PartAggregationROIHead.

        Note:
            This function assumes that the batch size is 1; only the result
            of the first sample is returned.

        Args:
            feats_dict (dict): Contains features from the first stage. The
                ``seg_features`` entry is read here.
            voxels_dict (dict): Contains information of voxels.
            img_metas (list[dict]): Meta info of each image.
            proposal_list (list[dict]): Proposal information from rpn. The
                keys ``boxes_3d``, ``labels_3d`` and ``cls_preds`` are read
                from each entry.
            **kwargs: Accepted but not used.

        Returns:
            dict: Bbox results of one frame.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        assert self.with_semantic

        seg_feats = feats_dict['seg_features']
        semantic_results = self.semantic_head(seg_feats)

        # Gather per-sample proposal data produced by the RPN.
        proposal_boxes = [res['boxes_3d'].tensor for res in proposal_list]
        rois = bbox3d2roi(proposal_boxes)
        labels_3d = [res['labels_3d'] for res in proposal_list]
        cls_preds = [res['cls_preds'] for res in proposal_list]

        bbox_results = self._bbox_forward(seg_feats,
                                          semantic_results['part_feats'],
                                          voxels_dict, rois)

        bbox_list = self.bbox_head.get_bboxes(
            rois,
            bbox_results['cls_score'],
            bbox_results['bbox_pred'],
            labels_3d,
            cls_preds,
            img_metas,
            cfg=self.test_cfg)

        formatted = []
        for bboxes, scores, labels in bbox_list:
            formatted.append(bbox3d2result(bboxes, scores, labels))
        return formatted[0]
wuyuefeng's avatar
wuyuefeng committed
163
164
165

    def _bbox_forward_train(self, seg_feats, part_feats, voxels_dict,
                            sampling_results):
        """Run roi_extractor and bbox_head on sampled RoIs and compute the
        bbox loss.

        Args:
            seg_feats (torch.Tensor): Point-wise semantic features.
            part_feats (torch.Tensor): Point-wise part prediction features.
            voxels_dict (dict): Contains information of voxels.
            sampling_results (list[:obj:`SamplingResult`]): Sampled results
                used for training; the ``bboxes`` of each entry are turned
                into RoIs.

        Returns:
            dict: Forward results including losses and predictions. It is
                the output of ``_bbox_forward`` with an extra ``loss_bbox``
                entry.
        """
        sampled_boxes = [res.bboxes for res in sampling_results]
        rois = bbox3d2roi(sampled_boxes)
        bbox_results = self._bbox_forward(seg_feats, part_feats, voxels_dict,
                                          rois)

        targets = self.bbox_head.get_targets(sampling_results, self.train_cfg)
        cls_score = bbox_results['cls_score']
        bbox_pred = bbox_results['bbox_pred']
        bbox_results['loss_bbox'] = self.bbox_head.loss(
            cls_score, bbox_pred, rois, *targets)
        return bbox_results

    def _bbox_forward(self, seg_feats, part_feats, voxels_dict, rois):
        """Pool RoI features and run the bbox head; shared by training and
        testing.

        Args:
            seg_feats (torch.Tensor): Point-wise semantic features.
            part_feats (torch.Tensor): Point-wise part prediction features.
            voxels_dict (dict): Contains information of voxels. The
                ``voxel_centers`` and ``coors`` entries are read here.
            rois (Tensor): Roi boxes.

        Returns:
            dict: Contains predictions of bbox_head (``cls_score``,
                ``bbox_pred``) and features of roi_extractor
                (``pooled_seg_feats``, ``pooled_part_feats``).
        """
        # Both extractors receive the same voxel centers and the same
        # index-0 slice of the last ``coors`` dimension.
        voxel_centers = voxels_dict['voxel_centers']
        coors_first = voxels_dict['coors'][..., 0]

        pooled_seg_feats = self.seg_roi_extractor(seg_feats, voxel_centers,
                                                  coors_first, rois)
        pooled_part_feats = self.part_roi_extractor(part_feats, voxel_centers,
                                                    coors_first, rois)
        cls_score, bbox_pred = self.bbox_head(pooled_seg_feats,
                                              pooled_part_feats)

        return {
            'cls_score': cls_score,
            'bbox_pred': bbox_pred,
            'pooled_seg_feats': pooled_seg_feats,
            'pooled_part_feats': pooled_part_feats,
        }

    def _assign_and_sample(self, proposal_list, gt_bboxes_3d, gt_labels_3d):
        """Assign proposals to ground truths and sample them for training.

        When ``self.bbox_assigner`` is a list, the i-th assigner is applied
        to the proposals and ground truths whose label equals ``i`` and the
        per-class results are merged into one :obj:`AssignResult`.
        Otherwise the single assigner handles all proposals at once.

        Args:
            proposal_list (list[dict]): Proposals produced by RPN. The keys
                ``boxes_3d`` and ``labels_3d`` are read from each entry.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels

        Returns:
            list[:obj:`SamplingResult`]: Sampled results of each training
                sample.
        """
        results = []
        for batch_idx, proposals in enumerate(proposal_list):
            boxes = proposals['boxes_3d']
            pred_labels = proposals['labels_3d']
            gt_boxes = gt_bboxes_3d[batch_idx].to(boxes.device)
            gt_labels = gt_labels_3d[batch_idx]
            num_boxes = len(boxes)

            # Accumulators for merging the per-class assignments.
            total_gts = 0
            # gt indices start as 0 (background)
            merged_gt_inds = gt_labels.new_full((num_boxes, ), 0)
            merged_overlaps = boxes.tensor.new_zeros(num_boxes)
            # labels start as -1 (background)
            merged_labels = gt_labels.new_full((num_boxes, ), -1)

            if isinstance(self.bbox_assigner, list):
                # one assigner per class
                for cls_id, assigner in enumerate(self.bbox_assigner):
                    gt_mask = (gt_labels == cls_id)
                    pred_mask = (pred_labels == cls_id)
                    cls_assign = assigner.assign(
                        boxes.tensor[pred_mask],
                        gt_boxes.tensor[gt_mask],
                        gt_labels=gt_labels[gt_mask])
                    total_gts += cls_assign.num_gts

                    # Lookup table from class-local assignment indices to
                    # indices over all gts of the sample.
                    # Positions of this class's gts, 1-based.
                    lookup = gt_mask.nonzero().view(-1) + 1
                    # front-pad 0 for the "unassigned" index
                    lookup = F.pad(lookup, (1, 0), mode='constant', value=0)
                    # front-pad -1 for the "ignore" index
                    lookup = F.pad(lookup, (1, 0), mode='constant', value=-1)
                    # Shift the table by one; the shift is removed again
                    # right after the lookup below.
                    lookup += 1
                    # Result: -1 and 0 map to themselves, positive indices
                    # become the 1-based gt position within the sample
                    # (assumes the assigner encodes ignore/unassigned as
                    # -1/0 -- TODO confirm against the assigner).
                    merged_gt_inds[pred_mask] = lookup[cls_assign.gt_inds +
                                                       1] - 1
                    merged_overlaps[pred_mask] = cls_assign.max_overlaps
                    merged_labels[pred_mask] = cls_assign.labels

                assign_result = AssignResult(total_gts, merged_gt_inds,
                                             merged_overlaps, merged_labels)
            else:
                # a single assigner shared by all proposals
                assign_result = self.bbox_assigner.assign(
                    boxes.tensor, gt_boxes.tensor, gt_labels=gt_labels)

            # sample boxes
            sampled = self.bbox_sampler.sample(assign_result, boxes.tensor,
                                               gt_boxes.tensor, gt_labels)
            results.append(sampled)
        return results

    def _semantic_forward_train(self, x, voxels_dict, gt_bboxes_3d,
                                gt_labels_3d):
        """Run the semantic head and compute its loss.

        Args:
            x (torch.Tensor): Point-wise semantic features for segmentation
            voxels_dict (dict): Contains information of voxels.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels

        Returns:
            dict: Segmentation results including losses. It is the output
                of the semantic head with an extra ``loss_semantic`` entry.
        """
        head = self.semantic_head
        semantic_results = head(x)
        targets = head.get_targets(voxels_dict, gt_bboxes_3d, gt_labels_3d)
        semantic_results['loss_semantic'] = head.loss(semantic_results,
                                                      targets)
        return semantic_results