part_aggregation_roi_head.py 13.1 KB
Newer Older
zhangwenwei's avatar
zhangwenwei committed
1
from torch.nn import functional as F
wuyuefeng's avatar
wuyuefeng committed
2
3

from mmdet3d.core import AssignResult
zhangwenwei's avatar
zhangwenwei committed
4
from mmdet3d.core.bbox import bbox3d2result, bbox3d2roi
wuyuefeng's avatar
wuyuefeng committed
5
6
7
8
9
10
from mmdet.core import build_assigner, build_sampler
from mmdet.models import HEADS
from ..builder import build_head, build_roi_extractor
from .base_3droi_head import Base3DRoIHead


zhangwenwei's avatar
zhangwenwei committed
11
@HEADS.register_module()
class PartAggregationROIHead(Base3DRoIHead):
    """Part aggregation roi head for PartA2.

    Args:
        semantic_head (ConfigDict): Config of semantic head.
        num_classes (int): The number of classes.
        seg_roi_extractor (ConfigDict): Config of seg_roi_extractor.
        part_roi_extractor (ConfigDict): Config of part_roi_extractor.
        bbox_head (ConfigDict): Config of bbox_head.
        train_cfg (ConfigDict): Training config.
        test_cfg (ConfigDict): Testing config.
    """

    def __init__(self,
                 semantic_head,
                 num_classes=3,
                 seg_roi_extractor=None,
                 part_roi_extractor=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None):
        super(PartAggregationROIHead, self).__init__(
            bbox_head=bbox_head, train_cfg=train_cfg, test_cfg=test_cfg)
        self.num_classes = num_classes
        assert semantic_head is not None
        self.semantic_head = build_head(semantic_head)

        # The extractors are only created when a config is given;
        # ``_bbox_forward`` uses both of them unconditionally.
        if seg_roi_extractor is not None:
            self.seg_roi_extractor = build_roi_extractor(seg_roi_extractor)
        if part_roi_extractor is not None:
            self.part_roi_extractor = build_roi_extractor(part_roi_extractor)

        self.init_assigner_sampler()

    def init_weights(self, pretrained):
        """Initialize weights, skip since ``PartAggregationROIHead`` does not
        need to initialize weights."""
        pass

    def init_mask_head(self):
        """Initialize mask head, skip since ``PartAggregationROIHead`` does not
        have one."""
        pass

    def init_bbox_head(self, bbox_head):
        """Initialize box head.

        Args:
            bbox_head (ConfigDict): Config of bbox_head.
        """
        self.bbox_head = build_head(bbox_head)

    def init_assigner_sampler(self):
        """Initialize assigner and sampler.

        Both stay ``None`` when no training config is set. The assigner
        config may be a single dict or a list of dicts; a list yields one
        assigner per entry (used per class in ``_assign_and_sample``).
        """
        self.bbox_assigner = None
        self.bbox_sampler = None
        if self.train_cfg:
            if isinstance(self.train_cfg.assigner, dict):
                self.bbox_assigner = build_assigner(self.train_cfg.assigner)
            elif isinstance(self.train_cfg.assigner, list):
                self.bbox_assigner = [
                    build_assigner(res) for res in self.train_cfg.assigner
                ]
            self.bbox_sampler = build_sampler(self.train_cfg.sampler)

    @property
    def with_semantic(self):
        """bool: whether the head has semantic branch."""
        return hasattr(self,
                       'semantic_head') and self.semantic_head is not None

    def forward_train(self, feats_dict, voxels_dict, img_metas, proposal_list,
                      gt_bboxes_3d, gt_labels_3d):
        """Training forward function of PartAggregationROIHead.

        Args:
            feats_dict (dict): Contains features from the first stage.
            voxels_dict (dict): Contains information of voxels.
            img_metas (list[dict]): Meta info of each image.
            proposal_list (list[dict]): Proposal information from rpn.
                The dictionary should contain the following keys:

                - boxes_3d (:obj:`BaseInstance3DBoxes`): Proposal bboxes
                - labels_3d (torch.Tensor): Labels of proposals
                - cls_preds (torch.Tensor): Original scores of proposals
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]):
                GT bboxes of each sample. The bboxes are encapsulated
                by 3D box structures.
            gt_labels_3d (list[LongTensor]): GT labels of each sample.

        Returns:
            dict: losses from each head.
        """
        losses = dict()
        # Bound up front so that the bbox branch can report a missing
        # semantic branch explicitly instead of hitting an unbound local.
        semantic_results = None
        if self.with_semantic:
            semantic_results = self._semantic_forward_train(
                feats_dict['seg_features'], voxels_dict, gt_bboxes_3d,
                gt_labels_3d)
            losses.update(semantic_results['loss_semantic'])

        sample_results = self._assign_and_sample(proposal_list, gt_bboxes_3d,
                                                 gt_labels_3d)
        if self.with_bbox:
            # The bbox branch pools the part features predicted by the
            # semantic head, so it cannot run without it.
            assert semantic_results is not None, \
                'Semantic head must be implemented.'
            bbox_results = self._bbox_forward_train(
                feats_dict['seg_features'], semantic_results['part_feats'],
                voxels_dict, sample_results)
            losses.update(bbox_results['loss_bbox'])

        return losses

    def simple_test(self, feats_dict, voxels_dict, img_metas, proposal_list,
                    **kwargs):
        """Simple testing forward function of PartAggregationROIHead.

        Note:
            This function assumes that the batch size is 1

        Args:
            feats_dict (dict): Contains features from the first stage.
            voxels_dict (dict): Contains information of voxels.
            img_metas (list[dict]): Meta info of each image.
            proposal_list (list[dict]): Proposal information from rpn.
            **kwargs: Extra keyword arguments; accepted but not used here.

        Returns:
            dict: Bbox results of one frame.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        assert self.with_semantic

        semantic_results = self.semantic_head(feats_dict['seg_features'])

        rois = bbox3d2roi([res['boxes_3d'].tensor for res in proposal_list])
        labels_3d = [res['labels_3d'] for res in proposal_list]
        cls_preds = [res['cls_preds'] for res in proposal_list]
        bbox_results = self._bbox_forward(feats_dict['seg_features'],
                                          semantic_results['part_feats'],
                                          voxels_dict, rois)

        bbox_list = self.bbox_head.get_bboxes(
            rois,
            bbox_results['cls_score'],
            bbox_results['bbox_pred'],
            labels_3d,
            cls_preds,
            img_metas,
            cfg=self.test_cfg)

        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        # Only the first sample is returned (see the batch size note above).
        return bbox_results[0]

    def _bbox_forward_train(self, seg_feats, part_feats, voxels_dict,
                            sampling_results):
        """Forward training function of roi_extractor and bbox_head.

        Args:
            seg_feats (torch.Tensor): Point-wise semantic features.
            part_feats (torch.Tensor): Point-wise part prediction features.
            voxels_dict (dict): Contains information of voxels.
            sampling_results (list[:obj:`SamplingResult`]): Sampled results
                used for training, one per training sample.

        Returns:
            dict: Forward results including losses and predictions.
        """
        rois = bbox3d2roi([res.bboxes for res in sampling_results])
        bbox_results = self._bbox_forward(seg_feats, part_feats, voxels_dict,
                                          rois)

        bbox_targets = self.bbox_head.get_targets(sampling_results,
                                                  self.train_cfg)
        loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
                                        bbox_results['bbox_pred'], rois,
                                        *bbox_targets)

        bbox_results.update(loss_bbox=loss_bbox)
        return bbox_results

    def _bbox_forward(self, seg_feats, part_feats, voxels_dict, rois):
        """Forward function of roi_extractor and bbox_head used in both
        training and testing.

        Args:
            seg_feats (torch.Tensor): Point-wise semantic features.
            part_feats (torch.Tensor): Point-wise part prediction features.
            voxels_dict (dict): Contains information of voxels. The keys
                ``voxel_centers`` and ``coors`` are read here.
            rois (Tensor): Roi boxes.

        Returns:
            dict: Contains predictions of bbox_head and
                features of roi_extractor.
        """
        # NOTE(review): ``coors[..., 0]`` is presumably the batch index of
        # each voxel -- confirm against the voxelization code.
        pooled_seg_feats = self.seg_roi_extractor(seg_feats,
                                                  voxels_dict['voxel_centers'],
                                                  voxels_dict['coors'][..., 0],
                                                  rois)
        pooled_part_feats = self.part_roi_extractor(
            part_feats, voxels_dict['voxel_centers'],
            voxels_dict['coors'][..., 0], rois)
        cls_score, bbox_pred = self.bbox_head(pooled_seg_feats,
                                              pooled_part_feats)

        bbox_results = dict(
            cls_score=cls_score,
            bbox_pred=bbox_pred,
            pooled_seg_feats=pooled_seg_feats,
            pooled_part_feats=pooled_part_feats)
        return bbox_results

    def _assign_and_sample(self, proposal_list, gt_bboxes_3d, gt_labels_3d):
        """Assign and sample proposals for training.

        When ``self.bbox_assigner`` is a list, the i-th assigner is applied
        to the proposals and ground truths labelled ``i`` and the per-class
        results are merged into a single :obj:`AssignResult` per sample.

        Args:
            proposal_list (list[dict]): Proposals produced by RPN.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels

        Returns:
            list[:obj:`SamplingResult`]: Sampled results of each training
                sample.
        """
        sampling_results = []
        # bbox assign
        for batch_idx, cur_proposal_list in enumerate(proposal_list):
            cur_boxes = cur_proposal_list['boxes_3d']
            cur_labels_3d = cur_proposal_list['labels_3d']
            cur_gt_bboxes = gt_bboxes_3d[batch_idx].to(cur_boxes.device)
            cur_gt_labels = gt_labels_3d[batch_idx]

            batch_num_gts = 0
            # 0 is bg
            batch_gt_indis = cur_gt_labels.new_full((len(cur_boxes), ), 0)
            batch_max_overlaps = cur_boxes.tensor.new_zeros(len(cur_boxes))
            # -1 is bg
            batch_gt_labels = cur_gt_labels.new_full((len(cur_boxes), ), -1)

            # each class may have its own assigner
            if isinstance(self.bbox_assigner, list):
                for i, assigner in enumerate(self.bbox_assigner):
                    gt_per_cls = (cur_gt_labels == i)
                    pred_per_cls = (cur_labels_3d == i)
                    cur_assign_res = assigner.assign(
                        cur_boxes.tensor[pred_per_cls],
                        cur_gt_bboxes.tensor[gt_per_cls],
                        gt_labels=cur_gt_labels[gt_per_cls])
                    # gather assign_results in different class into one result
                    batch_num_gts += cur_assign_res.num_gts
                    # Build a lookup table translating the per-class
                    # ``gt_inds`` into indices over all gts of this sample.
                    # gt inds (1-based)
                    gt_inds_arange_pad = gt_per_cls.nonzero().view(-1) + 1
                    # pad 0 for indices unassigned
                    gt_inds_arange_pad = F.pad(
                        gt_inds_arange_pad, (1, 0), mode='constant', value=0)
                    # pad -1 for indices ignored
                    gt_inds_arange_pad = F.pad(
                        gt_inds_arange_pad, (1, 0), mode='constant', value=-1)
                    # shift the table values by one; the shift is undone by
                    # the ``- 1`` applied after the lookup below
                    gt_inds_arange_pad += 1
                    # The lookup is offset by one so that -1 hits the first
                    # entry. Afterwards -1 is ignore, 0 is bg and >0 is fg
                    # (1-based gt index) in batch_gt_indis.
                    batch_gt_indis[pred_per_cls] = gt_inds_arange_pad[
                        cur_assign_res.gt_inds + 1] - 1
                    batch_max_overlaps[
                        pred_per_cls] = cur_assign_res.max_overlaps
                    batch_gt_labels[pred_per_cls] = cur_assign_res.labels

                assign_result = AssignResult(batch_num_gts, batch_gt_indis,
                                             batch_max_overlaps,
                                             batch_gt_labels)
            else:  # for single class
                assign_result = self.bbox_assigner.assign(
                    cur_boxes.tensor,
                    cur_gt_bboxes.tensor,
                    gt_labels=cur_gt_labels)
            # sample boxes
            sampling_result = self.bbox_sampler.sample(assign_result,
                                                       cur_boxes.tensor,
                                                       cur_gt_bboxes.tensor,
                                                       cur_gt_labels)
            sampling_results.append(sampling_result)
        return sampling_results

    def _semantic_forward_train(self, x, voxels_dict, gt_bboxes_3d,
                                gt_labels_3d):
        """Train semantic head.

        Args:
            x (torch.Tensor): Point-wise semantic features for segmentation
            voxels_dict (dict): Contains information of voxels.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels

        Returns:
            dict: Segmentation results including losses. This is the output
                of the semantic head with ``loss_semantic`` added to it.
        """
        semantic_results = self.semantic_head(x)
        semantic_targets = self.semantic_head.get_targets(
            voxels_dict, gt_bboxes_3d, gt_labels_3d)
        loss_semantic = self.semantic_head.loss(semantic_results,
                                                semantic_targets)
        semantic_results.update(loss_semantic=loss_semantic)
        return semantic_results