# h3d_roi_head.py (H3DNet RoI head). Last committed by encore-zhou.
from mmdet3d.core.bbox import bbox3d2result
from mmdet.models import HEADS
from ..builder import build_head
from .base_3droi_head import Base3DRoIHead


@HEADS.register_module()
class H3DRoIHead(Base3DRoIHead):
    """H3D roi head for H3DNet.

    The head runs three primitive heads (z, xy and line, in that order),
    merging each one's outputs into the shared feature dict, and then runs
    ``bbox_head`` on the enriched features.

    Args:
        primitive_list (List): Configs of primitive heads. Must contain
            exactly three entries, used in order for the z, xy and line
            primitive heads.
        bbox_head (ConfigDict): Config of bbox_head.
        train_cfg (ConfigDict): Training config.
        test_cfg (ConfigDict): Testing config.
    """

    # Sampling modes accepted for ``train_cfg.sample_mod`` /
    # ``test_cfg.sample_mod``.
    _VALID_SAMPLE_MODS = ('vote', 'seed', 'random')

    def __init__(self,
                 primitive_list,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__(
            bbox_head=bbox_head, train_cfg=train_cfg, test_cfg=test_cfg)
        # Primitive module: one head each for z, xy and line primitives.
        assert len(primitive_list) == 3, (
            'H3DRoIHead expects exactly 3 primitive head configs, '
            'got {}'.format(len(primitive_list)))
        self.primitive_z = build_head(primitive_list[0])
        self.primitive_xy = build_head(primitive_list[1])
        self.primitive_line = build_head(primitive_list[2])

    def init_weights(self, pretrained):
        """Initialize weights, skip since ``H3DRoIHead`` does not need to
        initialize weights."""
        pass

    def init_mask_head(self):
        """Initialize mask head, skip since ``H3DRoIHead`` does not have
        one."""
        pass

    def init_bbox_head(self, bbox_head):
        """Initialize box head.

        This RoI head's ``train_cfg`` / ``test_cfg`` are injected into the
        box head config before building it.

        Args:
            bbox_head (dict | ConfigDict): Config of the box head. It is
                copied first, so the caller's config object is not mutated.
        """
        bbox_head = bbox_head.copy()
        bbox_head['train_cfg'] = self.train_cfg
        bbox_head['test_cfg'] = self.test_cfg
        self.bbox_head = build_head(bbox_head)

    def init_assigner_sampler(self):
        """Initialize assigner and sampler, skip since ``H3DRoIHead`` does
        not use them."""
        pass

    def _forward_primitives(self, feats_dict, sample_mod):
        """Run the z, xy and line primitive heads in order.

        Each head is called with ``feats_dict`` as updated by the heads run
        before it, and its results are merged into ``feats_dict`` in place.

        Args:
            feats_dict (dict): Features from the first stage; updated in
                place.
            sample_mod (str): Sampling mode, one of ``'vote'``, ``'seed'``
                or ``'random'``.
        """
        assert sample_mod in self._VALID_SAMPLE_MODS, (
            'sample_mod must be one of {}, got {!r}'.format(
                self._VALID_SAMPLE_MODS, sample_mod))
        for primitive_head in (self.primitive_z, self.primitive_xy,
                               self.primitive_line):
            feats_dict.update(primitive_head(feats_dict, sample_mod))

    def forward_train(self,
                      feats_dict,
                      img_metas,
                      points,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      pts_semantic_mask,
                      pts_instance_mask,
                      gt_bboxes_ignore=None):
        """Training forward function of H3DRoIHead.

        Note:
            ``feats_dict`` is modified in place: primitive and bbox head
            outputs are merged into it and its ``'targets'`` entry is
            popped.

        Args:
            feats_dict (dict): Contains features from the first stage.
            img_metas (list[dict]): Contain pcd and img's meta info.
            points (list[torch.Tensor]): Input points.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
                bboxes of each sample.
            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
            pts_semantic_mask (None | list[torch.Tensor]): Point-wise
                semantic mask.
            pts_instance_mask (None | list[torch.Tensor]): Point-wise
                instance mask.
            gt_bboxes_ignore (None | list[torch.Tensor]): Ground truth
                bboxes to be ignored; forwarded unchanged to the primitive
                and bbox head losses.

        Returns:
            dict: losses from each head.
        """
        losses = dict()

        sample_mod = self.train_cfg.sample_mod
        self._forward_primitives(feats_dict, sample_mod)

        primitive_loss_inputs = (feats_dict, points, gt_bboxes_3d,
                                 gt_labels_3d, pts_semantic_mask,
                                 pts_instance_mask, img_metas,
                                 gt_bboxes_ignore)
        # Losses are computed only after all three primitive heads have run,
        # so every head's loss sees the fully merged feats_dict.
        for primitive_head in (self.primitive_z, self.primitive_xy,
                               self.primitive_line):
            losses.update(primitive_head.loss(*primitive_loss_inputs))

        # 'targets' is removed from feats_dict and handed to the bbox head
        # loss explicitly.
        targets = feats_dict.pop('targets')

        bbox_results = self.bbox_head(feats_dict, sample_mod)
        feats_dict.update(bbox_results)
        bbox_loss = self.bbox_head.loss(feats_dict, points, gt_bboxes_3d,
                                        gt_labels_3d, pts_semantic_mask,
                                        pts_instance_mask, img_metas, targets,
                                        gt_bboxes_ignore)
        losses.update(bbox_loss)

        return losses

    def simple_test(self, feats_dict, img_metas, points, rescale=False):
        """Simple testing forward function of H3DRoIHead.

        Note:
            This function assumes that the batch size is 1. ``feats_dict``
            is modified in place with primitive and bbox head outputs.

        Args:
            feats_dict (dict): Contains features from the first stage.
            img_metas (list[dict]): Contain pcd and img's meta info.
            points (torch.Tensor): Input points.
            rescale (bool): Whether to rescale results.

        Returns:
            dict: Bbox results of one frame.
        """
        sample_mod = self.test_cfg.sample_mod
        self._forward_primitives(feats_dict, sample_mod)

        bbox_preds = self.bbox_head(feats_dict, sample_mod)
        feats_dict.update(bbox_preds)
        # Decode the refined ('_optimized') predictions of the bbox head.
        bbox_list = self.bbox_head.get_bboxes(
            points,
            feats_dict,
            img_metas,
            rescale=rescale,
            suffix='_optimized')
        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        # Batch size is assumed to be 1, so only the first result is
        # returned.
        return bbox_results[0]