h3dnet.py 5.67 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
jshilong's avatar
jshilong committed
2
3
from typing import Dict, List, Optional, Union

encore-zhou's avatar
encore-zhou committed
4
import torch
jshilong's avatar
jshilong committed
5
from torch import Tensor
encore-zhou's avatar
encore-zhou committed
6

7
from mmdet3d.registry import MODELS
zhangshilong's avatar
zhangshilong committed
8
from mmdet3d.structures import Det3DDataSample
encore-zhou's avatar
encore-zhou committed
9
10
11
from .two_stage import TwoStage3DDetector


12
@MODELS.register_module()
class H3DNet(TwoStage3DDetector):
    r"""H3DNet model.

    Please refer to the `paper <https://arxiv.org/abs/2006.05682>`_

    Args:
        backbone (dict): Config dict of detector's backbone.
        neck (dict, optional): Config dict of neck. Defaults to None.
        rpn_head (dict, optional): Config dict of rpn head. Defaults to None.
        roi_head (dict, optional): Config dict of roi head. Defaults to None.
        train_cfg (dict, optional): Config dict of training hyper-parameters.
            Defaults to None.
        test_cfg (dict, optional): Config dict of test hyper-parameters.
            Defaults to None.
        init_cfg (dict, optional): The config to control the
            initialization. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
            Defaults to None.
        **kwargs: Extra keyword arguments forwarded unchanged to
            :class:`TwoStage3DDetector`.
    """

    def __init__(self,
                 backbone: dict,
                 neck: Optional[dict] = None,
                 rpn_head: Optional[dict] = None,
                 roi_head: Optional[dict] = None,
                 train_cfg: Optional[dict] = None,
                 test_cfg: Optional[dict] = None,
                 init_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None,
                 **kwargs) -> None:
        # This class adds no state of its own; everything is handled by the
        # two-stage base detector.
        super().__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg,
            data_preprocessor=data_preprocessor,
            **kwargs)

    def extract_feat(self, batch_inputs_dict: dict) -> dict:
        """Directly extract features from the backbone+neck.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                'points'.

                - points (list[torch.Tensor]): Point cloud of each sample.

        Returns:
            dict: Dict of feature. :meth:`loss` and :meth:`predict` read the
            keys ``'fp_xyz_net0'``, ``'hd_feature'`` and
            ``'fp_indices_net0'`` from it.
        """
        # ``torch.stack`` requires all per-sample tensors to have the same
        # shape, so the point clouds must already be equally sized here.
        stack_points = torch.stack(batch_inputs_dict['points'])
        x = self.backbone(stack_points)
        if self.with_neck:
            x = self.neck(x)
        return x

    @staticmethod
    def _add_rpn_feat_aliases(feats_dict: dict) -> dict:
        """Add the ``fp_*`` alias entries passed on to the RPN/RoI heads.

        The aliases are single-element lists built from the last entry of
        ``'fp_xyz_net0'`` / ``'fp_indices_net0'`` and from ``'hd_feature'``.

        Args:
            feats_dict (dict): Feature dict returned by :meth:`extract_feat`.
                It is modified in place.

        Returns:
            dict: The same ``feats_dict`` object, with ``'fp_xyz'``,
            ``'fp_features'`` and ``'fp_indices'`` set.
        """
        feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]]
        feats_dict['fp_features'] = [feats_dict['hd_feature']]
        feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]]
        return feats_dict

    def loss(self, batch_inputs_dict: Dict[str, Union[List, Tensor]],
             batch_data_samples: List[Det3DDataSample], **kwargs) -> dict:
        """Compute the RPN and RoI losses for a batch.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                'points' keys.

                - points (list[torch.Tensor]): Point cloud of each sample.

            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`.
            **kwargs: Extra keyword arguments; forwarded to
                ``roi_head.loss`` only.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.

        Raises:
            NotImplementedError: If the detector has no RPN head.
        """
        feats_dict = self._add_rpn_feat_aliases(
            self.extract_feat(batch_inputs_dict))

        if not self.with_rpn:
            raise NotImplementedError(
                'H3DNet cannot compute losses without an rpn_head.')

        losses = dict()
        # Fall back to the test-time RPN settings when the training config
        # does not define 'rpn_proposal'.
        proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn)
        # note, the feats_dict would be added new key & value in rpn_head
        rpn_losses, rpn_proposals = self.rpn_head.loss_and_predict(
            batch_inputs_dict['points'],
            feats_dict,
            batch_data_samples,
            ret_target=True,
            proposal_cfg=proposal_cfg)
        # 'targets' is handed to the RoI head through feats_dict and must
        # not be reported as a loss term, so pop it before merging.
        feats_dict['targets'] = rpn_losses.pop('targets')
        losses.update(rpn_losses)
        feats_dict['rpn_proposals'] = rpn_proposals

        roi_losses = self.roi_head.loss(batch_inputs_dict['points'],
                                        feats_dict, batch_data_samples,
                                        **kwargs)
        losses.update(roi_losses)

        return losses

    def predict(
            self, batch_input_dict: Dict,
            batch_data_samples: List[Det3DDataSample]
    ) -> List[Det3DDataSample]:
        """Get model predictions.

        Args:
            batch_input_dict (dict): The model input dict which include
                'points'.

                - points (list[torch.Tensor]): Points of each sample.

            batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
                contains the meta information of each sample and
                corresponding annotations.

        Returns:
            list[:obj:`Det3DDataSample`]: The RoI head predictions after
            being passed through ``convert_to_datasample``.

        Raises:
            NotImplementedError: If the detector has no RPN head.
        """
        feats_dict = self._add_rpn_feat_aliases(
            self.extract_feat(batch_input_dict))

        if not self.with_rpn:
            raise NotImplementedError(
                'H3DNet cannot predict without an rpn_head.')

        proposal_cfg = self.test_cfg.rpn
        rpn_proposals = self.rpn_head.predict(
            batch_input_dict['points'],
            feats_dict,
            batch_data_samples,
            use_nms=proposal_cfg.use_nms)
        feats_dict['rpn_proposals'] = rpn_proposals

        results_list = self.roi_head.predict(
            batch_input_dict['points'],
            feats_dict,
            batch_data_samples,
            suffix='_optimized')
        return self.convert_to_datasample(results_list)