"...training_service/common/clusterJobRestServer.ts" did not exist on "8c07cf4118a55346265a321d90059dff76436847"
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet3d.core.bbox import bbox3d2result
from ..builder import HEADS, build_head
from .base_3droi_head import Base3DRoIHead


@HEADS.register_module()
class H3DRoIHead(Base3DRoIHead):
    """H3D roi head for H3DNet.

    Three primitive heads (z, xy and line, built from ``primitive_list`` in
    that order) are run one after another, each adding its outputs to the
    shared feature dict, before the bbox head consumes the result.

    Args:
        primitive_list (List): Configs of primitive heads. Exactly three
            configs are required, used in order for the z, xy and line
            primitive heads.
        bbox_head (ConfigDict): Config of bbox_head.
        train_cfg (ConfigDict): Training config.
        test_cfg (ConfigDict): Testing config.
        pretrained: Forwarded unchanged to ``Base3DRoIHead``.
        init_cfg: Forwarded unchanged to ``Base3DRoIHead``.
    """

    # Sampling modes accepted by ``forward_train`` and ``simple_test``.
    _SAMPLE_MODS = ('vote', 'seed', 'random')

    def __init__(self,
                 primitive_list,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super(H3DRoIHead, self).__init__(
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)
        # Primitive module
        assert len(primitive_list) == 3, \
            'primitive_list must contain exactly three head configs'
        self.primitive_z = build_head(primitive_list[0])
        self.primitive_xy = build_head(primitive_list[1])
        self.primitive_line = build_head(primitive_list[2])

    def init_mask_head(self):
        """Initialize mask head, skip since ``H3DROIHead`` does not have
        one."""
        pass

    def init_bbox_head(self, bbox_head):
        """Initialize box head.

        ``self.train_cfg`` and ``self.test_cfg`` are injected into the config
        before the head is built.

        Args:
            bbox_head (ConfigDict): Config of bbox_head.
        """
        # Work on a shallow copy so the caller's config object is not
        # modified when the train/test configs are injected.
        bbox_head = bbox_head.copy()
        bbox_head['train_cfg'] = self.train_cfg
        bbox_head['test_cfg'] = self.test_cfg
        self.bbox_head = build_head(bbox_head)

    def init_assigner_sampler(self):
        """Initialize assigner and sampler."""
        pass

    def _primitive_heads(self):
        """Return the primitive heads in evaluation order (z, xy, line)."""
        return (self.primitive_z, self.primitive_xy, self.primitive_line)

    def _forward_primitives(self, feats_dict, sample_mod):
        """Run the z, xy and line primitive heads in order.

        The output of each head is merged into ``feats_dict`` in place
        before the next head is called, so later heads see earlier results.

        Args:
            feats_dict (dict): Contains features from the first stage.
                Updated in place.
            sample_mod (str): Sampling mode, one of 'vote', 'seed' or
                'random'.
        """
        assert sample_mod in self._SAMPLE_MODS, \
            'sample_mod must be one of "vote", "seed" or "random"'
        for primitive_head in self._primitive_heads():
            feats_dict.update(primitive_head(feats_dict, sample_mod))

    def forward_train(self,
                      feats_dict,
                      img_metas,
                      points,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      pts_semantic_mask,
                      pts_instance_mask,
                      gt_bboxes_ignore=None):
        """Training forward function of H3DRoIHead.

        Note:
            ``feats_dict`` is modified in place: the outputs of the
            primitive heads and of the bbox head are merged into it and its
            ``'targets'`` entry is removed.

        Args:
            feats_dict (dict): Contains features from the first stage. Must
                contain a ``'targets'`` entry, which is passed on to the
                bbox head loss.
            img_metas (list[dict]): Contain pcd and img's meta info.
            points (list[torch.Tensor]): Input points.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                bboxes of each sample.
            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
            pts_semantic_mask (list[torch.Tensor]): Point-wise
                semantic mask.
            pts_instance_mask (list[torch.Tensor]): Point-wise
                instance mask.
            gt_bboxes_ignore (list[torch.Tensor]): Specify
                which bounding boxes to ignore.

        Returns:
            dict: losses from each head.
        """
        losses = dict()

        sample_mod = self.train_cfg.sample_mod
        self._forward_primitives(feats_dict, sample_mod)

        # All primitive losses are computed only after every primitive head
        # has run, from the same inputs.
        primitive_loss_inputs = (feats_dict, points, gt_bboxes_3d,
                                 gt_labels_3d, pts_semantic_mask,
                                 pts_instance_mask, img_metas,
                                 gt_bboxes_ignore)
        for primitive_head in self._primitive_heads():
            losses.update(primitive_head.loss(*primitive_loss_inputs))

        # 'targets' is taken out of the feature dict before the bbox head
        # runs and is handed to the bbox loss as a separate argument.
        targets = feats_dict.pop('targets')

        bbox_results = self.bbox_head(feats_dict, sample_mod)

        feats_dict.update(bbox_results)
        bbox_loss = self.bbox_head.loss(feats_dict, points, gt_bboxes_3d,
                                        gt_labels_3d, pts_semantic_mask,
                                        pts_instance_mask, img_metas, targets,
                                        gt_bboxes_ignore)
        losses.update(bbox_loss)

        return losses

    def simple_test(self, feats_dict, img_metas, points, rescale=False):
        """Simple testing forward function of H3DRoIHead.

        Note:
            This function assumes that the batch size is 1

        Args:
            feats_dict (dict): Contains features from the first stage.
                Updated in place with the outputs of the primitive heads and
                the bbox head.
            img_metas (list[dict]): Contain pcd and img's meta info.
            points (torch.Tensor): Input points.
            rescale (bool): Whether to rescale results.

        Returns:
            list: One ``bbox3d2result`` output for each
                ``(bboxes, scores, labels)`` entry returned by
                ``bbox_head.get_bboxes``.
        """
        sample_mod = self.test_cfg.sample_mod
        self._forward_primitives(feats_dict, sample_mod)

        bbox_preds = self.bbox_head(feats_dict, sample_mod)
        feats_dict.update(bbox_preds)
        bbox_list = self.bbox_head.get_bboxes(
            points,
            feats_dict,
            img_metas,
            rescale=rescale,
            suffix='_optimized')
        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        return bbox_results