"README-EN.md" did not exist on "c00fce56ed7f263cdf528f9a997ab3c7a6702144"
point_rcnn_roi_head.py 13 KB
Newer Older
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
from typing import Dict, Optional

4
import torch
5
from mmdet.models.task_modules import AssignResult
6
from torch import Tensor
7
8
from torch.nn import functional as F

zhangshilong's avatar
zhangshilong committed
9
from mmdet3d.registry import MODELS, TASK_UTILS
10
11
from mmdet3d.structures import bbox3d2roi
from mmdet3d.utils.typing import InstanceList, SampleList
12
13
14
from .base_3droi_head import Base3DRoIHead


15
@MODELS.register_module()
class PointRCNNRoIHead(Base3DRoIHead):
    """RoI head for PointRCNN.

    Args:
        bbox_head (dict): Config of bbox_head.
        bbox_roi_extractor (dict): Config of RoI extractor.
        train_cfg (dict): Train configs.
        test_cfg (dict): Test configs.
        depth_normalizer (float): Normalize depth feature.
            Defaults to 70.0.
        init_cfg (dict, optional): Config of initialization. Defaults to None.
    """

    def __init__(self,
                 bbox_head: dict,
                 bbox_roi_extractor: dict,
                 train_cfg: dict,
                 test_cfg: dict,
                 depth_normalizer: float = 70.0,
                 init_cfg: Optional[dict] = None) -> None:
        # NOTE: the annotation previously said `dict` while the default and
        # the docstring are a float scalar; fixed to `float`.
        super(PointRCNNRoIHead, self).__init__(
            bbox_head=bbox_head,
            bbox_roi_extractor=bbox_roi_extractor,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)
        # Scale factor used to squash point depths into a small range
        # (depth / normalizer - 0.5) before feeding them to the bbox branch.
        self.depth_normalizer = depth_normalizer

        self.init_assigner_sampler()

    def init_mask_head(self):
        """Initialize mask head. PointRCNN's RoI head has no mask branch, so
        this required override is a no-op."""
        pass

    def init_assigner_sampler(self):
        """Build the bbox assigner(s) and sampler from ``train_cfg``.

        Both stay ``None`` when no train config is given (test-only mode).
        A list of assigner configs builds one assigner per class.
        """
        self.bbox_assigner = None
        self.bbox_sampler = None
        if not self.train_cfg:
            return

        assigner_cfg = self.train_cfg.assigner
        if isinstance(assigner_cfg, dict):
            # single assigner shared by all classes
            self.bbox_assigner = TASK_UTILS.build(assigner_cfg)
        elif isinstance(assigner_cfg, list):
            # one assigner per class
            self.bbox_assigner = [
                TASK_UTILS.build(cfg) for cfg in assigner_cfg
            ]
        self.bbox_sampler = TASK_UTILS.build(self.train_cfg.sampler)
62

63
64
65
66
    def loss(self, feats_dict: Dict, rpn_results_list: InstanceList,
             batch_data_samples: SampleList, **kwargs) -> dict:
        """Perform forward propagation and loss calculation of the detection
        roi on the features of the upstream network.

        Args:
            feats_dict (dict): Contains features from the first stage.
            rpn_results_list (List[:obj:`InstanceData`]): Detection results
                of rpn head.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            dict[str, Tensor]: A dictionary of loss components
        """
        fp_points = feats_dict['fp_points']
        # per-point foreground confidence = max sigmoid score over classes
        sem_scores = feats_dict['points_cls_preds'].sigmoid()
        point_scores = sem_scores.max(-1)[0]

        # collect per-sample gt (and optional ignore) annotations
        batch_gt_instances_3d = [
            data_sample.gt_instances_3d for data_sample in batch_data_samples
        ]
        batch_gt_instances_ignore = [
            data_sample.ignored_instances
            if 'ignored_instances' in data_sample else None
            for data_sample in batch_data_samples
        ]
        sample_results = self._assign_and_sample(rpn_results_list,
                                                 batch_gt_instances_3d,
                                                 batch_gt_instances_ignore)

        # concat the depth, semantic features and backbone features
        backbone_feats = feats_dict['fp_features'].transpose(1,
                                                             2).contiguous()
        point_depths = fp_points.norm(dim=2) / self.depth_normalizer - 0.5
        features = torch.cat([
            point_scores.unsqueeze(2),
            point_depths.unsqueeze(2),
            backbone_feats,
        ],
                             dim=2)

        bbox_results = self._bbox_forward_train(features, fp_points,
                                                sample_results)
        losses = dict()
        losses.update(bbox_results['loss_bbox'])
        return losses

112
113
114
115
116
117
118
119
    def predict(self,
                feats_dict: Dict,
                rpn_results_list: InstanceList,
                batch_data_samples: SampleList,
                rescale: bool = False,
                **kwargs) -> InstanceList:
        """Perform forward propagation of the roi head and predict detection
        results on the features of the upstream network.

        Args:
            feats_dict (dict): Contains features from the first stage.
            rpn_results_list (List[:obj:`InstanceData`]): Detection results
                of rpn head.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each sample
            after the post process.
            Each item usually contains following keys.

            - scores_3d (Tensor): Classification scores, has a shape
              (num_instances, )
            - labels_3d (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
              contains a tensor with shape (num_instances, C), where
              C >= 7.
        """
        # pack every sample's RPN boxes into one roi tensor
        rois = bbox3d2roi(
            [res['bboxes_3d'].tensor for res in rpn_results_list])
        labels_3d = [res['labels_3d'] for res in rpn_results_list]
        batch_input_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        fp_points = feats_dict['fp_points']
        # per-point foreground confidence = max sigmoid score over classes
        sem_scores = feats_dict['points_cls_preds'].sigmoid()
        point_scores = sem_scores.max(-1)[0]

        # concat the semantic score, normalized depth and backbone features
        backbone_feats = feats_dict['fp_features'].transpose(1,
                                                             2).contiguous()
        point_depths = fp_points.norm(dim=2) / self.depth_normalizer - 0.5
        features = torch.cat([
            point_scores.unsqueeze(2),
            point_depths.unsqueeze(2),
            backbone_feats,
        ],
                             dim=2)

        bbox_results = self._bbox_forward(features, fp_points,
                                          features.shape[0], rois)
        object_score = bbox_results['cls_score'].sigmoid()
        return self.bbox_head.get_results(
            rois,
            object_score,
            bbox_results['bbox_pred'],
            labels_3d,
            batch_input_metas,
            cfg=self.test_cfg)
177

178
179
    def _bbox_forward_train(self, features: Tensor, points: Tensor,
                            sampling_results: SampleList) -> dict:
        """Forward training function of roi_extractor and bbox_head.

        Args:
            features (torch.Tensor): Backbone features with depth and
                semantic features.
            points (torch.Tensor): Point cloud.
            sampling_results (:obj:`SamplingResult`): Sampled results used
                for training.

        Returns:
            dict: Forward results including losses and predictions.
        """
        # rois come from the sampled (not raw RPN) proposals
        sampled_rois = bbox3d2roi([res.bboxes for res in sampling_results])
        bbox_results = self._bbox_forward(features, points,
                                          features.shape[0], sampled_rois)

        targets = self.bbox_head.get_targets(sampling_results, self.train_cfg)
        loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
                                        bbox_results['bbox_pred'],
                                        sampled_rois, *targets)
        bbox_results['loss_bbox'] = loss_bbox
        return bbox_results

205
206
    def _bbox_forward(self, features: Tensor, points: Tensor, batch_size: int,
                      rois: Tensor) -> dict:
        """Forward function of roi_extractor and bbox_head used in both
        training and testing.

        Args:
            features (torch.Tensor): Backbone features with depth and
                semantic features.
            points (torch.Tensor): Point cloud.
            batch_size (int): Batch size.
            rois (torch.Tensor): RoI boxes.

        Returns:
            dict: Contains predictions of bbox_head and
                features of roi_extractor.
        """
        # pool point-wise features inside each roi, then run the bbox head
        roi_feats = self.bbox_roi_extractor(features, points, batch_size,
                                            rois)
        cls_score, bbox_pred = self.bbox_head(roi_feats)
        return dict(cls_score=cls_score, bbox_pred=bbox_pred)

228
229
230
231
    def _assign_and_sample(
            self, rpn_results_list: InstanceList,
            batch_gt_instances_3d: InstanceList,
            batch_gt_instances_ignore: InstanceList) -> SampleList:
        """Assign and sample proposals for training.

        Args:
            rpn_results_list (List[:obj:`InstanceData`]): Detection results
                of rpn head.
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instances. It usually includes ``bboxes_3d`` and
                ``labels_3d`` attributes.
            batch_gt_instances_ignore (list[:obj:`InstanceData`]): Ignore
                instances of gt bboxes.

        Returns:
            list[:obj:`SamplingResult`]: Sampled results of each training
                sample.
        """
        sampling_results = []
        # bbox assign: process each sample in the batch independently
        for batch_idx in range(len(rpn_results_list)):
            cur_proposal_list = rpn_results_list[batch_idx]
            cur_boxes = cur_proposal_list['bboxes_3d']
            cur_labels_3d = cur_proposal_list['labels_3d']
            cur_gt_instances_3d = batch_gt_instances_3d[batch_idx]
            # NOTE(review): replaces the boxes-3d object with its raw tensor
            # in-place on the gt instances — later consumers see the tensor.
            cur_gt_instances_3d.bboxes_3d = cur_gt_instances_3d.\
                bboxes_3d.tensor
            cur_gt_instances_ignore = batch_gt_instances_ignore[batch_idx]
            cur_gt_bboxes = cur_gt_instances_3d.bboxes_3d.to(cur_boxes.device)
            cur_gt_labels = cur_gt_instances_3d.labels_3d
            # accumulators used to merge per-class assign results into one
            batch_num_gts = 0
            # 0 is bg
            batch_gt_indis = cur_gt_labels.new_full((len(cur_boxes), ), 0)
            batch_max_overlaps = cur_boxes.tensor.new_zeros(len(cur_boxes))
            # -1 is bg
            batch_gt_labels = cur_gt_labels.new_full((len(cur_boxes), ), -1)

            # each class may have its own assigner
            if isinstance(self.bbox_assigner, list):
                for i, assigner in enumerate(self.bbox_assigner):
                    # restrict gts and proposals to the current class i
                    gt_per_cls = (cur_gt_labels == i)
                    pred_per_cls = (cur_labels_3d == i)
                    cur_assign_res = assigner.assign(
                        cur_proposal_list[pred_per_cls],
                        cur_gt_instances_3d[gt_per_cls],
                        cur_gt_instances_ignore)
                    # gather assign_results in different class into one result
                    batch_num_gts += cur_assign_res.num_gts
                    # gt inds (1-based): positions of class-i gts in the full
                    # (all-class) gt list, shifted to 1-based indexing
                    gt_inds_arange_pad = gt_per_cls.nonzero(
                        as_tuple=False).view(-1) + 1
                    # pad 0 for indice unassigned
                    gt_inds_arange_pad = F.pad(
                        gt_inds_arange_pad, (1, 0), mode='constant', value=0)
                    # pad -1 for indice ignore
                    gt_inds_arange_pad = F.pad(
                        gt_inds_arange_pad, (1, 0), mode='constant', value=-1)
                    # convert to 0~gt_num+2 for indices
                    gt_inds_arange_pad += 1
                    # now 0 is bg, >1 is fg in batch_gt_indis; the per-class
                    # gt_inds (-1 ignore / 0 unassigned / k-th class gt) are
                    # remapped through the padded lookup table built above
                    batch_gt_indis[pred_per_cls] = gt_inds_arange_pad[
                        cur_assign_res.gt_inds + 1] - 1
                    batch_max_overlaps[
                        pred_per_cls] = cur_assign_res.max_overlaps
                    batch_gt_labels[pred_per_cls] = cur_assign_res.labels

                assign_result = AssignResult(batch_num_gts, batch_gt_indis,
                                             batch_max_overlaps,
                                             batch_gt_labels)
            else:  # for single class
                assign_result = self.bbox_assigner.assign(
                    cur_proposal_list, cur_gt_instances_3d,
                    cur_gt_instances_ignore)

            # sample boxes
            sampling_result = self.bbox_sampler.sample(assign_result,
                                                       cur_boxes.tensor,
                                                       cur_gt_bboxes,
                                                       cur_gt_labels)
            sampling_results.append(sampling_result)
        return sampling_results