part_aggregation_roi_head.py 13.6 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import warnings
3

zhangwenwei's avatar
zhangwenwei committed
4
from torch.nn import functional as F
wuyuefeng's avatar
wuyuefeng committed
5
6

from mmdet3d.core import AssignResult
zhangwenwei's avatar
zhangwenwei committed
7
from mmdet3d.core.bbox import bbox3d2result, bbox3d2roi
wuyuefeng's avatar
wuyuefeng committed
8
from mmdet.core import build_assigner, build_sampler
9
from ..builder import HEADS, build_head, build_roi_extractor
wuyuefeng's avatar
wuyuefeng committed
10
11
12
from .base_3droi_head import Base3DRoIHead


zhangwenwei's avatar
zhangwenwei committed
13
@HEADS.register_module()
wuyuefeng's avatar
wuyuefeng committed
14
class PartAggregationROIHead(Base3DRoIHead):
zhangwenwei's avatar
zhangwenwei committed
15
    """Part aggregation roi head for PartA2.
wuyuefeng's avatar
wuyuefeng committed
16
17
18
19
20
21
22
23
24
25

    Args:
        semantic_head (ConfigDict): Config of semantic head.
        num_classes (int): The number of classes.
        seg_roi_extractor (ConfigDict): Config of seg_roi_extractor.
        part_roi_extractor (ConfigDict): Config of part_roi_extractor.
        bbox_head (ConfigDict): Config of bbox_head.
        train_cfg (ConfigDict): Training config.
        test_cfg (ConfigDict): Testing config.
    """
wuyuefeng's avatar
wuyuefeng committed
26
27
28
29
30
31
32
33

    def __init__(self,
                 semantic_head,
                 num_classes=3,
                 seg_roi_extractor=None,
                 part_roi_extractor=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        """Build the semantic head, the optional RoI extractors and the
        assigner/sampler.

        ``pretrained`` is deprecated: when it is a string, a warning is
        emitted and it is turned into a ``Pretrained`` ``init_cfg``.
        ``init_cfg`` and ``pretrained`` must not both be set.
        """
        # The bbox head config and the train/test/init configs are handled by
        # the parent constructor.
        parent_kwargs = dict(
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)
        super(PartAggregationROIHead, self).__init__(**parent_kwargs)

        self.num_classes = num_classes

        # The semantic head is mandatory for this RoI head.
        assert semantic_head is not None
        self.semantic_head = build_head(semantic_head)

        # Each RoI extractor attribute is only created when a config for it
        # is supplied.
        optional_extractors = (
            ('seg_roi_extractor', seg_roi_extractor),
            ('part_roi_extractor', part_roi_extractor),
        )
        for attr_name, extractor_cfg in optional_extractors:
            if extractor_cfg is not None:
                setattr(self, attr_name, build_roi_extractor(extractor_cfg))

        self.init_assigner_sampler()

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be setting at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                          'please use "init_cfg" instead')
            # Translate the legacy argument into its init_cfg equivalent.
            self.init_cfg = {'type': 'Pretrained', 'checkpoint': pretrained}
wuyuefeng's avatar
wuyuefeng committed
59
60

    def init_mask_head(self):
        """No-op initializer: ``PartAggregationROIHead`` has no mask head, so
        there is nothing to build here."""
        return None

    def init_bbox_head(self, bbox_head):
        """Build the bbox head from its config and store it as
        ``self.bbox_head``.

        Args:
            bbox_head (ConfigDict): Config of bbox_head.
        """
        built_head = build_head(bbox_head)
        self.bbox_head = built_head

    def init_assigner_sampler(self):
        """Create the bbox assigner(s) and the bbox sampler from ``train_cfg``.

        Both attributes are reset to ``None`` first and stay ``None`` when no
        training config is available. ``train_cfg.assigner`` may be a single
        config (one assigner) or a list of configs (a list of assigners).
        """
        self.bbox_assigner = None
        self.bbox_sampler = None

        # Nothing to build without a training config.
        if not self.train_cfg:
            return

        assigner_cfg = self.train_cfg.assigner
        if isinstance(assigner_cfg, dict):
            self.bbox_assigner = build_assigner(assigner_cfg)
        elif isinstance(assigner_cfg, list):
            self.bbox_assigner = [
                build_assigner(single_cfg) for single_cfg in assigner_cfg
            ]

        self.bbox_sampler = build_sampler(self.train_cfg.sampler)

    @property
    def with_semantic(self):
        """bool: whether a semantic head is attached and is not ``None``"""
        return getattr(self, 'semantic_head', None) is not None

zhangwenwei's avatar
zhangwenwei committed
88
    def forward_train(self, feats_dict, voxels_dict, img_metas, proposal_list,
                      gt_bboxes_3d, gt_labels_3d):
        """Training forward function of PartAggregationROIHead.

        Args:
            feats_dict (dict): Contains features from the first stage.
                ``feats_dict['seg_features']`` is read here.
            voxels_dict (dict): Contains information of voxels.
            img_metas (list[dict]): Meta info of each image. Not used by
                this method.
            proposal_list (list[dict]): Proposal information from rpn.
                The dictionary should contain the following keys:

                - boxes_3d (:obj:`BaseInstance3DBoxes`): Proposal bboxes
                - labels_3d (torch.Tensor): Labels of proposals
                - cls_preds (torch.Tensor): Original scores of proposals
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]):
                GT bboxes of each sample. The bboxes are encapsulated
                by 3D box structures.
            gt_labels_3d (list[LongTensor]): GT labels of each sample.

        Returns:
            dict: Losses from each head. The ``loss_semantic`` entries of
                the semantic branch and the ``loss_bbox`` entries of the
                bbox branch are merged into one flat dict.

        Raises:
            RuntimeError: If the bbox branch is enabled while no semantic
                head is available to provide the part features it consumes.
        """
        losses = dict()

        # Bound up front so the bbox branch can detect a missing semantic
        # branch instead of failing on an unbound local variable.
        semantic_results = None
        if self.with_semantic:
            semantic_results = self._semantic_forward_train(
                feats_dict['seg_features'], voxels_dict, gt_bboxes_3d,
                gt_labels_3d)
            losses.update(semantic_results['loss_semantic'])

        sample_results = self._assign_and_sample(proposal_list, gt_bboxes_3d,
                                                 gt_labels_3d)
        if self.with_bbox:
            if semantic_results is None:
                raise RuntimeError(
                    'The bbox branch needs the part features produced by '
                    'the semantic head, but no semantic head is available.')
            bbox_results = self._bbox_forward_train(
                feats_dict['seg_features'], semantic_results['part_feats'],
                voxels_dict, sample_results)
            losses.update(bbox_results['loss_bbox'])

        return losses

zhangwenwei's avatar
zhangwenwei committed
130
    def simple_test(self, feats_dict, voxels_dict, img_metas, proposal_list,
                    **kwargs):
        """Simple testing forward function of PartAggregationROIHead.

        Note:
            This function assumes that the batch size is 1

        Args:
            feats_dict (dict): Contains features from the first stage.
            voxels_dict (dict): Contains information of voxels.
            img_metas (list[dict]): Meta info of each image.
            proposal_list (list[dict]): Proposal information from rpn. The
                keys ``boxes_3d``, ``labels_3d`` and ``cls_preds`` are read
                from every entry.

        Returns:
            list[dict]: One ``bbox3d2result`` dict per
                (bboxes, scores, labels) triple returned by
                ``bbox_head.get_bboxes``.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        assert self.with_semantic

        seg_feats = feats_dict['seg_features']
        semantic_results = self.semantic_head(seg_feats)

        # Gather the per-sample proposal fields consumed by the bbox head.
        proposal_boxes = [res['boxes_3d'].tensor for res in proposal_list]
        rois = bbox3d2roi(proposal_boxes)
        labels_3d = [res['labels_3d'] for res in proposal_list]
        cls_preds = [res['cls_preds'] for res in proposal_list]

        part_feats = semantic_results['part_feats']
        bbox_results = self._bbox_forward(seg_feats, part_feats, voxels_dict,
                                          rois)

        # Decode the refined boxes with the test-time config.
        bbox_list = self.bbox_head.get_bboxes(
            rois,
            bbox_results['cls_score'],
            bbox_results['bbox_pred'],
            labels_3d,
            cls_preds,
            img_metas,
            cfg=self.test_cfg)

        return [
            bbox3d2result(det_bboxes, det_scores, det_labels)
            for det_bboxes, det_scores, det_labels in bbox_list
        ]
wuyuefeng's avatar
wuyuefeng committed
172
173
174

    def _bbox_forward_train(self, seg_feats, part_feats, voxels_dict,
                            sampling_results):
        """Run the RoI extractors and bbox head on sampled RoIs and compute
        the bbox loss.

        Args:
            seg_feats (torch.Tensor): Point-wise semantic features.
            part_feats (torch.Tensor): Point-wise part prediction features.
            voxels_dict (dict): Contains information of voxels.
            sampling_results (:obj:`SamplingResult`): Sampled results used
                for training.

        Returns:
            dict: The output of ``_bbox_forward`` with an extra ``loss_bbox``
                entry holding the bbox head loss.
        """
        sampled_boxes = [res.bboxes for res in sampling_results]
        rois = bbox3d2roi(sampled_boxes)
        bbox_results = self._bbox_forward(seg_feats, part_feats, voxels_dict,
                                          rois)

        # Targets come from the sampling results and the training config.
        bbox_targets = self.bbox_head.get_targets(sampling_results,
                                                  self.train_cfg)
        cls_score = bbox_results['cls_score']
        bbox_pred = bbox_results['bbox_pred']
        bbox_results['loss_bbox'] = self.bbox_head.loss(
            cls_score, bbox_pred, rois, *bbox_targets)
        return bbox_results

    def _bbox_forward(self, seg_feats, part_feats, voxels_dict, rois):
zhangwenwei's avatar
zhangwenwei committed
201
202
        """Forward function of roi_extractor and bbox_head used in both
        training and testing.
wuyuefeng's avatar
wuyuefeng committed
203
204
205
206
207
208
209
210
211
212
213

        Args:
            seg_feats (torch.Tensor): Point-wise semantic features.
            part_feats (torch.Tensor): Point-wise part prediction features.
            voxels_dict (dict): Contains information of voxels.
            rois (Tensor): Roi boxes.

        Returns:
            dict: Contains predictions of bbox_head and
                features of roi_extractor.
        """
wuyuefeng's avatar
wuyuefeng committed
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
        pooled_seg_feats = self.seg_roi_extractor(seg_feats,
                                                  voxels_dict['voxel_centers'],
                                                  voxels_dict['coors'][..., 0],
                                                  rois)
        pooled_part_feats = self.part_roi_extractor(
            part_feats, voxels_dict['voxel_centers'],
            voxels_dict['coors'][..., 0], rois)
        cls_score, bbox_pred = self.bbox_head(pooled_seg_feats,
                                              pooled_part_feats)

        bbox_results = dict(
            cls_score=cls_score,
            bbox_pred=bbox_pred,
            pooled_seg_feats=pooled_seg_feats,
            pooled_part_feats=pooled_part_feats)
        return bbox_results

    def _assign_and_sample(self, proposal_list, gt_bboxes_3d, gt_labels_3d):
zhangwenwei's avatar
zhangwenwei committed
232
        """Assign and sample proposals for training.
zhangwenwei's avatar
zhangwenwei committed
233
234
235
236
237
238
239
240
241
242
243

        Args:
            proposal_list (list[dict]): Proposals produced by RPN.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels

        Returns:
            list[:obj:`SamplingResult`]: Sampled results of each training
                sample.
        """
wuyuefeng's avatar
wuyuefeng committed
244
245
246
247
        sampling_results = []
        # bbox assign
        for batch_idx in range(len(proposal_list)):
            cur_proposal_list = proposal_list[batch_idx]
zhangwenwei's avatar
zhangwenwei committed
248
249
            cur_boxes = cur_proposal_list['boxes_3d']
            cur_labels_3d = cur_proposal_list['labels_3d']
zhangwenwei's avatar
zhangwenwei committed
250
            cur_gt_bboxes = gt_bboxes_3d[batch_idx].to(cur_boxes.device)
wuyuefeng's avatar
wuyuefeng committed
251
252
253
            cur_gt_labels = gt_labels_3d[batch_idx]

            batch_num_gts = 0
254
255
256
257
258
259
260
261
            # 0 is bg
            batch_gt_indis = cur_gt_labels.new_full((len(cur_boxes), ), 0)
            batch_max_overlaps = cur_boxes.tensor.new_zeros(len(cur_boxes))
            # -1 is bg
            batch_gt_labels = cur_gt_labels.new_full((len(cur_boxes), ), -1)

            # each class may have its own assigner
            if isinstance(self.bbox_assigner, list):
wuyuefeng's avatar
wuyuefeng committed
262
263
                for i, assigner in enumerate(self.bbox_assigner):
                    gt_per_cls = (cur_gt_labels == i)
zhangwenwei's avatar
zhangwenwei committed
264
                    pred_per_cls = (cur_labels_3d == i)
wuyuefeng's avatar
wuyuefeng committed
265
                    cur_assign_res = assigner.assign(
266
                        cur_boxes.tensor[pred_per_cls],
zhangwenwei's avatar
zhangwenwei committed
267
                        cur_gt_bboxes.tensor[gt_per_cls],
wuyuefeng's avatar
wuyuefeng committed
268
269
270
271
                        gt_labels=cur_gt_labels[gt_per_cls])
                    # gather assign_results in different class into one result
                    batch_num_gts += cur_assign_res.num_gts
                    # gt inds (1-based)
Wenwei Zhang's avatar
Wenwei Zhang committed
272
273
                    gt_inds_arange_pad = gt_per_cls.nonzero(
                        as_tuple=False).view(-1) + 1
wuyuefeng's avatar
wuyuefeng committed
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
                    # pad 0 for indice unassigned
                    gt_inds_arange_pad = F.pad(
                        gt_inds_arange_pad, (1, 0), mode='constant', value=0)
                    # pad -1 for indice ignore
                    gt_inds_arange_pad = F.pad(
                        gt_inds_arange_pad, (1, 0), mode='constant', value=-1)
                    # convert to 0~gt_num+2 for indices
                    gt_inds_arange_pad += 1
                    # now 0 is bg, >1 is fg in batch_gt_indis
                    batch_gt_indis[pred_per_cls] = gt_inds_arange_pad[
                        cur_assign_res.gt_inds + 1] - 1
                    batch_max_overlaps[
                        pred_per_cls] = cur_assign_res.max_overlaps
                    batch_gt_labels[pred_per_cls] = cur_assign_res.labels

                assign_result = AssignResult(batch_num_gts, batch_gt_indis,
                                             batch_max_overlaps,
                                             batch_gt_labels)
            else:  # for single class
                assign_result = self.bbox_assigner.assign(
294
295
296
                    cur_boxes.tensor,
                    cur_gt_bboxes.tensor,
                    gt_labels=cur_gt_labels)
wuyuefeng's avatar
wuyuefeng committed
297
298
            # sample boxes
            sampling_result = self.bbox_sampler.sample(assign_result,
299
                                                       cur_boxes.tensor,
zhangwenwei's avatar
zhangwenwei committed
300
                                                       cur_gt_bboxes.tensor,
wuyuefeng's avatar
wuyuefeng committed
301
302
303
304
305
306
                                                       cur_gt_labels)
            sampling_results.append(sampling_result)
        return sampling_results

    def _semantic_forward_train(self, x, voxels_dict, gt_bboxes_3d,
                                gt_labels_3d):
zhangwenwei's avatar
zhangwenwei committed
307
        """Train semantic head.
zhangwenwei's avatar
zhangwenwei committed
308
309
310
311
312
313
314
315
316
317
318

        Args:
            x (torch.Tensor): Point-wise semantic features for segmentation
            voxels_dict (dict): Contains information of voxels.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels

        Returns:
            dict: Segmentation results including losses
        """
wuyuefeng's avatar
wuyuefeng committed
319
320
321
322
323
324
325
        semantic_results = self.semantic_head(x)
        semantic_targets = self.semantic_head.get_targets(
            voxels_dict, gt_bboxes_3d, gt_labels_3d)
        loss_semantic = self.semantic_head.loss(semantic_results,
                                                semantic_targets)
        semantic_results.update(loss_semantic=loss_semantic)
        return semantic_results