import torch.nn.functional as F

from mmdet3d.core import AssignResult
from mmdet3d.core.bbox import bbox3d2result, bbox3d2roi
from mmdet.core import build_assigner, build_sampler
from mmdet.models import HEADS
from ..builder import build_head, build_roi_extractor
from .base_3droi_head import Base3DRoIHead


@HEADS.register_module()
class PartAggregationROIHead(Base3DRoIHead):
    """Part aggregation roi head for PartA2.

    The semantic head produces part features from the first-stage
    segmentation features. Segmentation and part features are each pooled
    inside the RoIs by their own extractor and fed to the bbox head.

    Args:
        semantic_head (dict): Config of the semantic head. Must not be None.
        num_classes (int): Number of classes. Default: 3.
        seg_roi_extractor (dict, optional): Config of the RoI extractor for
            segmentation features. Needed by ``_bbox_forward``.
        part_roi_extractor (dict, optional): Config of the RoI extractor for
            part features. Needed by ``_bbox_forward``.
        bbox_head (dict, optional): Config of the bbox head.
        train_cfg (dict, optional): Training config. ``assigner`` may be a
            single config or a list of configs (one per class).
        test_cfg (dict, optional): Testing config passed to the bbox head.
    """

    def __init__(self,
                 semantic_head,
                 num_classes=3,
                 seg_roi_extractor=None,
                 part_roi_extractor=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None):
        super(PartAggregationROIHead, self).__init__(
            bbox_head=bbox_head, train_cfg=train_cfg, test_cfg=test_cfg)
        self.num_classes = num_classes
        assert semantic_head is not None
        self.semantic_head = build_head(semantic_head)

        if seg_roi_extractor is not None:
            self.seg_roi_extractor = build_roi_extractor(seg_roi_extractor)
        if part_roi_extractor is not None:
            self.part_roi_extractor = build_roi_extractor(part_roi_extractor)

        self.init_assigner_sampler()

    def init_weights(self, pretrained):
        """Initialize weights; intentionally a no-op for this head."""
        pass

    def init_mask_head(self):
        """This head has no mask branch; intentionally a no-op."""
        pass

    def init_bbox_head(self, bbox_head):
        """Build the bbox head from its config."""
        self.bbox_head = build_head(bbox_head)

    def init_assigner_sampler(self):
        """Build the bbox assigner(s) and sampler from ``train_cfg``.

        ``train_cfg.assigner`` may be a single config (one assigner for all
        classes) or a list of configs (assigner ``i`` handles class ``i``).
        Without a ``train_cfg`` both attributes stay None.
        """
        self.bbox_assigner = None
        self.bbox_sampler = None
        if self.train_cfg:
            if isinstance(self.train_cfg.assigner, dict):
                self.bbox_assigner = build_assigner(self.train_cfg.assigner)
            elif isinstance(self.train_cfg.assigner, list):
                self.bbox_assigner = [
                    build_assigner(res) for res in self.train_cfg.assigner
                ]
            self.bbox_sampler = build_sampler(self.train_cfg.sampler)

    @property
    def with_semantic(self):
        """bool: Whether a semantic head is present."""
        return hasattr(self,
                       'semantic_head') and self.semantic_head is not None

    def forward_train(self, feats_dict, voxels_dict, img_meta, proposal_list,
                      gt_bboxes_3d, gt_labels_3d):
        """Training forward function of PartAggregationROIHead.

        Args:
            feats_dict (dict): Contains features from the first stage.
            voxels_dict (dict): Contains information of voxels.
            img_meta (list[dict]): Meta info of each image (unused here).
            proposal_list (list[dict]): Proposal information from rpn.
            gt_bboxes_3d (list): GT boxes of each sample; each item must
                expose ``.to()`` and ``.tensor``.
            gt_labels_3d (list[torch.Tensor]): GT labels of each sample.

        Returns:
            dict: Losses from each head.
        """
        losses = dict()
        # Initialised so the bbox branch never hits an unbound name.
        semantic_results = None
        if self.with_semantic:
            semantic_results = self._semantic_forward_train(
                feats_dict['seg_features'], voxels_dict, gt_bboxes_3d,
                gt_labels_3d)
            losses.update(semantic_results['loss_semantic'])

        sample_results = self._assign_and_sample(proposal_list, gt_bboxes_3d,
                                                 gt_labels_3d)
        if self.with_bbox:
            # The bbox branch pools the part features of the semantic head.
            assert semantic_results is not None, \
                'Semantic head must be implemented.'
            bbox_results = self._bbox_forward_train(
                feats_dict['seg_features'], semantic_results['part_feats'],
                voxels_dict, sample_results)
            losses.update(bbox_results['loss_bbox'])

        return losses

    def simple_test(self, feats_dict, voxels_dict, img_meta, proposal_list,
                    **kwargs):
        """Simple testing forward function of PartAggregationROIHead.

        Note:
            This function assumes that the batch size is 1.

        Args:
            feats_dict (dict): Contains features from the first stage.
            voxels_dict (dict): Contains information of voxels.
            img_meta (list[dict]): Meta info of each image.
            proposal_list (list[dict]): Proposal information from rpn.

        Returns:
            dict: Bbox results of one frame.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        assert self.with_semantic

        semantic_results = self.semantic_head(feats_dict['seg_features'])

        rois = bbox3d2roi([res['boxes_3d'].tensor for res in proposal_list])
        labels_3d = [res['labels_3d'] for res in proposal_list]
        cls_preds = [res['cls_preds'] for res in proposal_list]
        bbox_results = self._bbox_forward(feats_dict['seg_features'],
                                          semantic_results['part_feats'],
                                          voxels_dict, rois)

        bbox_list = self.bbox_head.get_bboxes(
            rois,
            bbox_results['cls_score'],
            bbox_results['bbox_pred'],
            labels_3d,
            cls_preds,
            img_meta,
            cfg=self.test_cfg)

        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        # Batch size is assumed to be 1 (see Note above).
        return bbox_results[0]

    def _bbox_forward_train(self, seg_feats, part_feats, voxels_dict,
                            sampling_results):
        """Run the bbox branch on sampled RoIs and compute its losses.

        Args:
            seg_feats (torch.Tensor): Segmentation features (first stage).
            part_feats (torch.Tensor): Part features from the semantic head.
            voxels_dict (dict): Contains information of voxels.
            sampling_results (list): Sampling result of each sample.

        Returns:
            dict: Outputs of ``_bbox_forward`` plus ``loss_bbox``.
        """
        rois = bbox3d2roi([res.bboxes for res in sampling_results])
        bbox_results = self._bbox_forward(seg_feats, part_feats, voxels_dict,
                                          rois)

        bbox_targets = self.bbox_head.get_targets(sampling_results,
                                                  self.train_cfg)
        loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
                                        bbox_results['bbox_pred'], rois,
                                        *bbox_targets)

        bbox_results.update(loss_bbox=loss_bbox)
        return bbox_results

    def _bbox_forward(self, seg_feats, part_feats, voxels_dict, rois):
        """Pool segmentation and part features in RoIs and run the bbox head.

        Args:
            seg_feats (torch.Tensor): Segmentation features (first stage).
            part_feats (torch.Tensor): Part features from the semantic head.
            voxels_dict (dict): Uses 'voxel_centers' and 'coors'; the first
                column of 'coors' is passed to the extractors (presumably
                the batch index of each voxel — confirm in the extractor).
            rois (torch.Tensor): RoIs built by ``bbox3d2roi``.

        Returns:
            dict: ``cls_score``, ``bbox_pred``, ``pooled_seg_feats`` and
                ``pooled_part_feats``.
        """
        voxel_centers = voxels_dict['voxel_centers']
        coors_first_col = voxels_dict['coors'][..., 0]
        pooled_seg_feats = self.seg_roi_extractor(seg_feats, voxel_centers,
                                                  coors_first_col, rois)
        pooled_part_feats = self.part_roi_extractor(part_feats, voxel_centers,
                                                    coors_first_col, rois)
        cls_score, bbox_pred = self.bbox_head(pooled_seg_feats,
                                              pooled_part_feats)

        bbox_results = dict(
            cls_score=cls_score,
            bbox_pred=bbox_pred,
            pooled_seg_feats=pooled_seg_feats,
            pooled_part_feats=pooled_part_feats)
        return bbox_results

    def _assign_and_sample(self, proposal_list, gt_bboxes_3d, gt_labels_3d):
        """Assign proposals to ground truths and sample them, per sample.

        Args:
            proposal_list (list[dict]): Proposals of each sample with keys
                'boxes_3d' and 'labels_3d'.
            gt_bboxes_3d (list): GT boxes of each sample; each item must
                expose ``.to()`` and ``.tensor``.
            gt_labels_3d (list[torch.Tensor]): GT labels of each sample.

        Returns:
            list: Sampling result of each sample.
        """
        sampling_results = []
        for batch_idx, cur_proposal_list in enumerate(proposal_list):
            # ``simple_test`` reads proposal boxes via ``.tensor``; the
            # assigner and sampler get raw tensors, like the gt boxes below.
            # ``getattr`` keeps plain-tensor proposals working.
            proposal_boxes = cur_proposal_list['boxes_3d']
            cur_boxes = getattr(proposal_boxes, 'tensor', proposal_boxes)
            cur_labels_3d = cur_proposal_list['labels_3d']
            cur_gt_bboxes = gt_bboxes_3d[batch_idx].to(cur_boxes.device)
            cur_gt_boxes_tensor = cur_gt_bboxes.tensor
            cur_gt_labels = gt_labels_3d[batch_idx]

            if isinstance(self.bbox_assigner, list):  # one per class
                assign_result = self._assign_per_class(
                    cur_boxes, cur_labels_3d, cur_gt_boxes_tensor,
                    cur_gt_labels)
            else:  # a single assigner for all classes
                assign_result = self.bbox_assigner.assign(
                    cur_boxes, cur_gt_boxes_tensor, gt_labels=cur_gt_labels)

            sampling_result = self.bbox_sampler.sample(assign_result,
                                                       cur_boxes,
                                                       cur_gt_boxes_tensor,
                                                       cur_gt_labels)
            sampling_results.append(sampling_result)
        return sampling_results

    def _assign_per_class(self, boxes, labels, gt_bboxes, gt_labels):
        """Run assigner ``i`` on class ``i`` only and merge the results.

        Args:
            boxes (torch.Tensor): Proposal boxes of one sample.
            labels (torch.Tensor): Predicted class of each proposal.
            gt_bboxes (torch.Tensor): GT boxes of the sample.
            gt_labels (torch.Tensor): GT labels of the sample.

        Returns:
            AssignResult: Assignment over all proposals, with gt indices
                referring to the full ``gt_bboxes`` of the sample.
        """
        num_boxes = boxes.shape[0]
        num_gts = 0
        # gt_inds convention: -1 ignore, 0 background, k >= 1 is the
        # 1-based index of the assigned gt.
        gt_inds = gt_labels.new_full((num_boxes, ), 0)
        max_overlaps = boxes.new_zeros(num_boxes)
        assigned_labels = gt_labels.new_full((num_boxes, ), -1)  # -1 is bg
        for i, assigner in enumerate(self.bbox_assigner):
            gt_per_cls = (gt_labels == i)
            pred_per_cls = (labels == i)
            cls_assign_res = assigner.assign(
                boxes[pred_per_cls],
                gt_bboxes[gt_per_cls],
                gt_labels=gt_labels[gt_per_cls])
            num_gts += cls_assign_res.num_gts
            # Lookup table indexed by (per-class gt_inds + 1):
            #   entry 0     -> -1 (ignore)
            #   entry 1     ->  0 (background)
            #   entry k + 1 ->  1-based global index of this class's k-th gt
            # Values are stored shifted by +1 and un-shifted after lookup.
            lookup = gt_per_cls.nonzero().view(-1) + 1
            lookup = F.pad(lookup, (1, 0), mode='constant', value=0)
            lookup = F.pad(lookup, (1, 0), mode='constant', value=-1)
            lookup += 1
            gt_inds[pred_per_cls] = lookup[cls_assign_res.gt_inds + 1] - 1
            max_overlaps[pred_per_cls] = cls_assign_res.max_overlaps
            assigned_labels[pred_per_cls] = cls_assign_res.labels

        return AssignResult(num_gts, gt_inds, max_overlaps, assigned_labels)

    def _semantic_forward_train(self, x, voxels_dict, gt_bboxes_3d,
                                gt_labels_3d):
        """Run the semantic head and compute its loss.

        Args:
            x (torch.Tensor): Segmentation features from the first stage.
            voxels_dict (dict): Contains information of voxels.
            gt_bboxes_3d (list): GT boxes of each sample.
            gt_labels_3d (list[torch.Tensor]): GT labels of each sample.

        Returns:
            dict: Semantic head outputs plus ``loss_semantic``.
        """
        semantic_results = self.semantic_head(x)
        semantic_targets = self.semantic_head.get_targets(
            voxels_dict, gt_bboxes_3d, gt_labels_3d)
        loss_semantic = self.semantic_head.loss(semantic_results,
                                                semantic_targets)
        semantic_results.update(loss_semantic=loss_semantic)
        return semantic_results