part_aggregation_roi_head.py 11.3 KB
Newer Older
wuyuefeng's avatar
wuyuefeng committed
1
2
3
import torch.nn.functional as F

from mmdet3d.core import AssignResult
zhangwenwei's avatar
zhangwenwei committed
4
from mmdet3d.core.bbox import bbox3d2result, bbox3d2roi
wuyuefeng's avatar
wuyuefeng committed
5
6
7
8
9
10
from mmdet.core import build_assigner, build_sampler
from mmdet.models import HEADS
from ..builder import build_head, build_roi_extractor
from .base_3droi_head import Base3DRoIHead


zhangwenwei's avatar
zhangwenwei committed
11
@HEADS.register_module()
class PartAggregationROIHead(Base3DRoIHead):
    """Part aggregation RoI head for PartA2.

    Args:
        semantic_head (ConfigDict): Config of the semantic head. Required;
            its ``part_feats`` output is pooled by the bbox branch.
        num_classes (int): The number of classes. Default: 3.
        seg_roi_extractor (ConfigDict, optional): Config of the RoI
            extractor applied to the segmentation features.
        part_roi_extractor (ConfigDict, optional): Config of the RoI
            extractor applied to the part features.
        bbox_head (ConfigDict, optional): Config of bbox_head.
        train_cfg (ConfigDict, optional): Training config.
        test_cfg (ConfigDict, optional): Testing config.
    """

    def __init__(self,
                 semantic_head,
                 num_classes=3,
                 seg_roi_extractor=None,
                 part_roi_extractor=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__(
            bbox_head=bbox_head, train_cfg=train_cfg, test_cfg=test_cfg)
        self.num_classes = num_classes
        assert semantic_head is not None
        self.semantic_head = build_head(semantic_head)

        if seg_roi_extractor is not None:
            self.seg_roi_extractor = build_roi_extractor(seg_roi_extractor)
        if part_roi_extractor is not None:
            self.part_roi_extractor = build_roi_extractor(part_roi_extractor)

        self.init_assigner_sampler()

    def init_weights(self, pretrained):
        """No-op: this head performs no weight initialization of its own."""
        pass

    def init_mask_head(self):
        """No-op: this head has no mask branch."""
        pass

    def init_bbox_head(self, bbox_head):
        """Build the bbox head from its config.

        Args:
            bbox_head (ConfigDict): Config of bbox_head.
        """
        self.bbox_head = build_head(bbox_head)

    def init_assigner_sampler(self):
        """Build the bbox assigner(s) and sampler from ``train_cfg``.

        ``train_cfg.assigner`` may be a single config (one assigner shared by
        all classes) or a list of configs, in which case the ``i``-th
        assigner handles class ``i`` (see ``_multi_class_assign``). Both
        stay ``None`` when there is no ``train_cfg``.
        """
        self.bbox_assigner = None
        self.bbox_sampler = None
        if self.train_cfg:
            if isinstance(self.train_cfg.assigner, dict):
                self.bbox_assigner = build_assigner(self.train_cfg.assigner)
            elif isinstance(self.train_cfg.assigner, list):
                self.bbox_assigner = [
                    build_assigner(res) for res in self.train_cfg.assigner
                ]
            self.bbox_sampler = build_sampler(self.train_cfg.sampler)

    @property
    def with_semantic(self):
        """bool: Whether a semantic head is attached and not ``None``."""
        return hasattr(self,
                       'semantic_head') and self.semantic_head is not None

    def forward_train(self, feats_dict, voxels_dict, img_metas, proposal_list,
                      gt_bboxes_3d, gt_labels_3d):
        """Training forward function of PartAggregationROIHead.

        Args:
            feats_dict (dict): Contains features from the first stage.
                ``feats_dict['seg_features']`` feeds both the semantic head
                and the segmentation RoI extractor.
            voxels_dict (dict): Contains information of voxels.
            img_metas (list[dict]): Meta info of each image (unused here).
            proposal_list (list[dict]): Proposal information from rpn.
                The dictionary should contain the following keys:

                - boxes_3d (:obj:`BaseInstance3DBoxes`): Proposal bboxes
                - labels_3d (torch.Tensor): Labels of proposals
                - cls_preds (torch.Tensor): Original scores of proposals
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]):
                GT bboxes of each sample. The bboxes are encapsulated
                by 3D box structures.
            gt_labels_3d (list[LongTensor]): GT labels of each sample.

        Returns:
            dict: Losses from each head.

        Raises:
            AssertionError: If the bbox head is enabled but no semantic
                results are available to provide the part features.
        """
        losses = dict()
        semantic_results = None
        if self.with_semantic:
            semantic_results = self._semantic_forward_train(
                feats_dict['seg_features'], voxels_dict, gt_bboxes_3d,
                gt_labels_3d)
            losses.update(semantic_results['loss_semantic'])

        sample_results = self._assign_and_sample(proposal_list, gt_bboxes_3d,
                                                 gt_labels_3d)
        if self.with_bbox:
            # The bbox branch pools the semantic head's part features, so
            # fail loudly instead of hitting an unbound local below.
            assert semantic_results is not None, \
                'Semantic head must be implemented to train the bbox head.'
            bbox_results = self._bbox_forward_train(
                feats_dict['seg_features'], semantic_results['part_feats'],
                voxels_dict, sample_results)
            losses.update(bbox_results['loss_bbox'])

        return losses

    def simple_test(self, feats_dict, voxels_dict, img_metas, proposal_list,
                    **kwargs):
        """Simple testing forward function of PartAggregationROIHead.

        Note:
            This function assumes that the batch size is 1: only the first
            element of the per-sample results is returned.

        Args:
            feats_dict (dict): Contains features from the first stage.
            voxels_dict (dict): Contains information of voxels.
            img_metas (list[dict]): Meta info of each image.
            proposal_list (list[dict]): Proposal information from rpn.
                Each dict must provide ``boxes_3d``, ``labels_3d`` and
                ``cls_preds``.

        Returns:
            dict: Bbox results of one frame.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        assert self.with_semantic

        semantic_results = self.semantic_head(feats_dict['seg_features'])

        rois = bbox3d2roi([res['boxes_3d'].tensor for res in proposal_list])
        labels_3d = [res['labels_3d'] for res in proposal_list]
        cls_preds = [res['cls_preds'] for res in proposal_list]
        bbox_results = self._bbox_forward(feats_dict['seg_features'],
                                          semantic_results['part_feats'],
                                          voxels_dict, rois)

        bbox_list = self.bbox_head.get_bboxes(
            rois,
            bbox_results['cls_score'],
            bbox_results['bbox_pred'],
            labels_3d,
            cls_preds,
            img_metas,
            cfg=self.test_cfg)

        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        return bbox_results[0]

    def _bbox_forward_train(self, seg_feats, part_feats, voxels_dict,
                            sampling_results):
        """Run the bbox branch on sampled proposals and compute its loss.

        Args:
            seg_feats (torch.Tensor): Point-wise semantic features.
            part_feats (torch.Tensor): Point-wise part prediction features.
            voxels_dict (dict): Contains information of voxels.
            sampling_results (list): Per-sample results of the bbox
                sampler; their ``bboxes`` are turned into RoIs.

        Returns:
            dict: Output of ``_bbox_forward`` plus ``loss_bbox``.
        """
        rois = bbox3d2roi([res.bboxes for res in sampling_results])
        bbox_results = self._bbox_forward(seg_feats, part_feats, voxels_dict,
                                          rois)

        bbox_targets = self.bbox_head.get_targets(sampling_results,
                                                  self.train_cfg)
        loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
                                        bbox_results['bbox_pred'], rois,
                                        *bbox_targets)

        bbox_results.update(loss_bbox=loss_bbox)
        return bbox_results

    def _bbox_forward(self, seg_feats, part_feats, voxels_dict, rois):
        """Forward function of roi_extractor and bbox_head.

        Args:
            seg_feats (torch.Tensor): Point-wise semantic features.
            part_feats (torch.Tensor): Point-wise part prediction features.
            voxels_dict (dict): Contains information of voxels.
            rois (Tensor): Roi boxes.

        Returns:
            dict: Contains predictions of bbox_head and
                features of roi_extractor.
        """
        voxel_centers = voxels_dict['voxel_centers']
        # First column of ``coors``; presumably the per-voxel batch index
        # (TODO confirm against the voxelization code).
        voxel_batch_inds = voxels_dict['coors'][..., 0]
        pooled_seg_feats = self.seg_roi_extractor(seg_feats, voxel_centers,
                                                  voxel_batch_inds, rois)
        pooled_part_feats = self.part_roi_extractor(part_feats, voxel_centers,
                                                    voxel_batch_inds, rois)
        cls_score, bbox_pred = self.bbox_head(pooled_seg_feats,
                                              pooled_part_feats)

        bbox_results = dict(
            cls_score=cls_score,
            bbox_pred=bbox_pred,
            pooled_seg_feats=pooled_seg_feats,
            pooled_part_feats=pooled_part_feats)
        return bbox_results

    def _assign_and_sample(self, proposal_list, gt_bboxes_3d, gt_labels_3d):
        """Assign proposals to GT boxes and sample them, per sample.

        Args:
            proposal_list (list[dict]): Proposals of each sample, each with
                ``boxes_3d`` and ``labels_3d``.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): GT bboxes of
                each sample.
            gt_labels_3d (list[LongTensor]): GT labels of each sample.

        Returns:
            list: One sampling result per sample, as returned by
                ``self.bbox_sampler.sample``.
        """
        sampling_results = []
        for batch_idx, cur_proposal_list in enumerate(proposal_list):
            cur_boxes = cur_proposal_list['boxes_3d']
            cur_labels_3d = cur_proposal_list['labels_3d']
            cur_gt_bboxes = gt_bboxes_3d[batch_idx].to(cur_boxes.device)
            cur_gt_labels = gt_labels_3d[batch_idx]

            # each class may have its own assigner
            if isinstance(self.bbox_assigner, list):
                assign_result = self._multi_class_assign(
                    cur_boxes, cur_labels_3d, cur_gt_bboxes, cur_gt_labels)
            else:  # a single assigner shared by all classes
                assign_result = self.bbox_assigner.assign(
                    cur_boxes.tensor,
                    cur_gt_bboxes.tensor,
                    gt_labels=cur_gt_labels)

            sampling_result = self.bbox_sampler.sample(assign_result,
                                                       cur_boxes.tensor,
                                                       cur_gt_bboxes.tensor,
                                                       cur_gt_labels)
            sampling_results.append(sampling_result)
        return sampling_results

    def _multi_class_assign(self, boxes, pred_labels, gt_bboxes, gt_labels):
        """Assign with one assigner per class and merge into one result.

        Proposals labelled ``i`` are matched only against GT boxes of class
        ``i`` by ``self.bbox_assigner[i]``. The per-class results are then
        scattered back into a single :obj:`AssignResult` that is indexed
        over all proposals and all GT boxes of the sample.

        Args:
            boxes (:obj:`BaseInstance3DBoxes`): Proposal boxes of a sample.
            pred_labels (torch.Tensor): Predicted label of each proposal.
            gt_bboxes (:obj:`BaseInstance3DBoxes`): GT boxes of the sample.
            gt_labels (torch.Tensor): GT labels of the sample.

        Returns:
            :obj:`AssignResult`: Merged assignment. In ``gt_inds``, 0 is
                background, -1 is ignored and ``k > 0`` is the (1-based)
                ``k``-th GT of the sample.
        """
        num_boxes = len(boxes)
        num_gts = 0
        # 0 is bg
        gt_inds = gt_labels.new_full((num_boxes, ), 0)
        max_overlaps = boxes.tensor.new_zeros(num_boxes)
        # -1 is bg
        assigned_labels = gt_labels.new_full((num_boxes, ), -1)

        for i, assigner in enumerate(self.bbox_assigner):
            gt_per_cls = (gt_labels == i)
            pred_per_cls = (pred_labels == i)
            cls_assign_res = assigner.assign(
                boxes.tensor[pred_per_cls],
                gt_bboxes.tensor[gt_per_cls],
                gt_labels=gt_labels[gt_per_cls])
            num_gts += cls_assign_res.num_gts

            # Build a lookup table indexed by (class-local gt_inds + 1):
            #   [0] -> 0  (ignored, -1 after the final -1 shift)
            #   [1] -> 1  (unassigned, 0 after the shift)
            #   [k + 1] -> (sample-level 1-based index of the k-th GT of
            #              this class) + 1
            gt_inds_arange_pad = gt_per_cls.nonzero().view(-1) + 1
            # pad 0 for indice unassigned
            gt_inds_arange_pad = F.pad(
                gt_inds_arange_pad, (1, 0), mode='constant', value=0)
            # pad -1 for indice ignore
            gt_inds_arange_pad = F.pad(
                gt_inds_arange_pad, (1, 0), mode='constant', value=-1)
            # shift so every entry is a valid non-negative index target
            gt_inds_arange_pad += 1
            # now 0 is bg, >0 is fg (1-based GT index), -1 is ignored
            gt_inds[pred_per_cls] = gt_inds_arange_pad[
                cls_assign_res.gt_inds + 1] - 1
            max_overlaps[pred_per_cls] = cls_assign_res.max_overlaps
            assigned_labels[pred_per_cls] = cls_assign_res.labels

        return AssignResult(num_gts, gt_inds, max_overlaps, assigned_labels)

    def _semantic_forward_train(self, x, voxels_dict, gt_bboxes_3d,
                                gt_labels_3d):
        """Run the semantic head and compute its loss.

        Args:
            x (torch.Tensor): Features fed to the semantic head.
            voxels_dict (dict): Contains information of voxels, used to
                build the semantic targets.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): GT bboxes of
                each sample.
            gt_labels_3d (list[LongTensor]): GT labels of each sample.

        Returns:
            dict: Outputs of the semantic head plus ``loss_semantic``.
        """
        semantic_results = self.semantic_head(x)
        semantic_targets = self.semantic_head.get_targets(
            voxels_dict, gt_bboxes_3d, gt_labels_3d)
        loss_semantic = self.semantic_head.loss(semantic_results,
                                                semantic_targets)
        semantic_results.update(loss_semantic=loss_semantic)
        return semantic_results