# Copyright (c) OpenMMLab. All rights reserved.
import torch

4
from mmdet3d.registry import MODELS
5
6
7
from .two_stage import TwoStage3DDetector


8
@MODELS.register_module()
class PointRCNN(TwoStage3DDetector):
    r"""PointRCNN detector.

    Please refer to the `PointRCNN <https://arxiv.org/abs/1812.04244>`_
    paper for details.

    Args:
        backbone (dict): Config dict of detector's backbone.
        neck (dict, optional): Config dict of neck. Defaults to None.
        rpn_head (dict, optional): Config of RPN head. Defaults to None.
        roi_head (dict, optional): Config of ROI head. Defaults to None.
        train_cfg (dict, optional): Train configs. Defaults to None.
        test_cfg (dict, optional): Test configs. Defaults to None.
        pretrained (str, optional): Model pretrained path. Defaults to None.
        init_cfg (dict, optional): Config of initialization. Defaults to None.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 rpn_head=None,
                 roi_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)

    def extract_feat(self, points):
        """Directly extract features from the backbone+neck.

        Args:
            points (torch.Tensor): Input points.

        Returns:
            dict: Features from the backbone+neck
        """
        x = self.backbone(points)

        if self.with_neck:
            x = self.neck(x)
        return x

    def _rpn_forward(self, points):
        """Run the stage-1 pipeline shared by training and testing.

        Stacks the per-sample point clouds, extracts backbone(+neck)
        features, snapshots the features needed by the RCNN stage and runs
        the RPN head.

        Args:
            points (list[torch.Tensor]): Points of each sample. They are
                combined with ``torch.stack``, so every tensor must have the
                same shape.

        Returns:
            tuple: ``(stack_points, rcnn_feats, bbox_preds, cls_preds)`` where
            ``stack_points`` is the stacked input, ``rcnn_feats`` is a dict
            with keys ``'features'``, ``'points'`` and ``'points_cls_preds'``
            for the ROI head, and ``bbox_preds`` / ``cls_preds`` are the raw
            RPN head outputs.
        """
        stack_points = torch.stack(points)
        x = self.extract_feat(stack_points)

        # Clone before the RPN head consumes ``x`` so the RCNN stage keeps
        # its own copies (presumably guarding against in-place changes by
        # the head — NOTE(review): confirm whether the clone is required).
        rcnn_feats = {
            'features': x['fp_features'].clone(),
            'points': x['fp_xyz'].clone(),
        }
        bbox_preds, cls_preds = self.rpn_head(x)
        rcnn_feats['points_cls_preds'] = cls_preds
        return stack_points, rcnn_feats, bbox_preds, cls_preds

    @staticmethod
    def _to_proposal_list(bbox_list):
        """Convert RPN ``get_bboxes`` output into per-sample proposal dicts.

        Args:
            bbox_list (list[tuple]): One ``(bboxes, scores, labels,
                cls_preds)`` tuple per sample.

        Returns:
            list[dict]: One dict per sample with keys ``boxes_3d``,
            ``scores_3d``, ``labels_3d`` and ``cls_preds``.
        """
        return [
            dict(
                boxes_3d=bboxes,
                scores_3d=scores,
                labels_3d=labels,
                cls_preds=preds_cls)
            for bboxes, scores, labels, preds_cls in bbox_list
        ]

    def forward_train(self, points, input_metas, gt_bboxes_3d, gt_labels_3d):
        """Forward of training.

        Args:
            points (list[torch.Tensor]): Points of each batch.
            input_metas (list[dict]): Meta information of each sample.
            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
            gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.

        Returns:
            dict: Losses of both the RPN head and the ROI head.
        """
        losses = dict()

        stack_points, rcnn_feats, bbox_preds, cls_preds = self._rpn_forward(
            points)

        # The RPN loss receives the original (unstacked) list of points.
        rpn_loss = self.rpn_head.loss(
            bbox_preds=bbox_preds,
            cls_preds=cls_preds,
            points=points,
            gt_bboxes_3d=gt_bboxes_3d,
            gt_labels_3d=gt_labels_3d,
            input_metas=input_metas)
        losses.update(rpn_loss)

        bbox_list = self.rpn_head.get_bboxes(stack_points, bbox_preds,
                                             cls_preds, input_metas)
        proposal_list = self._to_proposal_list(bbox_list)

        roi_losses = self.roi_head.forward_train(rcnn_feats, input_metas,
                                                 proposal_list, gt_bboxes_3d,
                                                 gt_labels_3d)
        losses.update(roi_losses)

        return losses

    def simple_test(self, points, img_metas, imgs=None, rescale=False):
        """Forward of testing.

        Args:
            points (list[torch.Tensor]): Points of each sample.
            img_metas (list[dict]): Image metas.
            imgs (list[torch.Tensor], optional): Images of each sample.
                Not used by this method. Defaults to None.
            rescale (bool, optional): Whether to rescale results.
                Defaults to False.

        Returns:
            list: Predicted 3d boxes.
        """
        stack_points, rcnn_feats, bbox_preds, cls_preds = self._rpn_forward(
            points)

        bbox_list = self.rpn_head.get_bboxes(
            stack_points, bbox_preds, cls_preds, img_metas, rescale=rescale)
        proposal_list = self._to_proposal_list(bbox_list)

        bbox_results = self.roi_head.simple_test(rcnn_feats, img_metas,
                                                 proposal_list)

        return bbox_results