votenet.py 6.21 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
jshilong's avatar
jshilong committed
2
from typing import Dict, List, Optional, Union
wuyuefeng's avatar
Votenet  
wuyuefeng committed
3

4
from mmengine.structures import InstanceData
jshilong's avatar
jshilong committed
5
6
from torch import Tensor

7
from mmdet3d.registry import MODELS
zhangshilong's avatar
zhangshilong committed
8
9
from mmdet3d.structures import Det3DDataSample
from ..test_time_augs import merge_aug_bboxes_3d
zhangwenwei's avatar
zhangwenwei committed
10
from .single_stage import SingleStage3DDetector
wuyuefeng's avatar
Votenet  
wuyuefeng committed
11
12


13
@MODELS.register_module()
class VoteNet(SingleStage3DDetector):
    r"""`VoteNet <https://arxiv.org/pdf/1904.09664.pdf>`_ for 3D detection.

    Args:
        backbone (dict): Config dict of detector's backbone.
        bbox_head (dict, optional): Config dict of box head. Defaults to None.
        train_cfg (dict, optional): Config dict of training hyper-parameters.
            Defaults to None.
        test_cfg (dict, optional): Config dict of test hyper-parameters.
            Defaults to None.
        init_cfg (dict, optional): The config to control the
            initialization. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
            Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 bbox_head: Optional[dict] = None,
                 train_cfg: Optional[dict] = None,
                 test_cfg: Optional[dict] = None,
                 init_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None,
                 **kwargs):
        # All construction is delegated to the single-stage base detector;
        # extra keyword arguments are forwarded untouched.
        super().__init__(
            backbone=backbone,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg,
            data_preprocessor=data_preprocessor,
            **kwargs)

    def loss(self, batch_inputs_dict: Dict[str, Union[List, Tensor]],
             batch_data_samples: List[Det3DDataSample],
             **kwargs) -> Dict[str, Tensor]:
        """Compute the training losses for a batch.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                'points' keys.

                - points (list[torch.Tensor]): Point cloud of each sample.

            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        feat_dict = self.extract_feat(batch_inputs_dict)
        points = batch_inputs_dict['points']
        # The head receives the raw points in addition to the extracted
        # features.
        losses = self.bbox_head.loss(points, feat_dict, batch_data_samples,
                                     **kwargs)
        return losses

    def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
                batch_data_samples: List[Det3DDataSample],
                **kwargs) -> List[Det3DDataSample]:
        """Forward of testing.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                'points' keys.

                - points (list[torch.Tensor]): Point cloud of each sample.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input sample. Each Det3DDataSample usually contain
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                    (num_instances, )
                - labels_3d (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
                    contains a tensor with shape (num_instances, 7).
        """
        feats_dict = self.extract_feat(batch_inputs_dict)
        points = batch_inputs_dict['points']
        results_list = self.bbox_head.predict(points, feats_dict,
                                              batch_data_samples, **kwargs)
        # Attach the head's predictions to the incoming data samples.
        data_3d_samples = self.add_pred_to_datasample(batch_data_samples,
                                                      results_list)
        return data_3d_samples

    def aug_test(self, aug_inputs_list: List[dict],
                 aug_data_samples: List[List[Det3DDataSample]],
                 **kwargs) -> List[Det3DDataSample]:
        """Test with augmentation.

        Batch size always is 1 when do the augtest.

        Args:
            aug_inputs_list (List[dict]): The list indicate same data
                under different augmentation.
            aug_data_samples (List[List[:obj:`Det3DDataSample`]]): The outer
                list indicate different augmentation, and the inner
                list indicate the batch size.

        Returns:
            list[:obj:`Det3DDataSample`]: The data samples returned by
            ``add_pred_to_datasample``. With a single augmentation this is
            the plain :meth:`predict` output; otherwise it carries the
            predictions merged over all augmentations.
        """
        num_augs = len(aug_inputs_list)
        if num_augs == 1:
            # Nothing to merge: fall back to the regular prediction path.
            return self.predict(aug_inputs_list[0], aug_data_samples[0])

        batch_size = len(aug_data_samples[0])
        assert batch_size == 1
        multi_aug_results = []
        for aug_id in range(num_augs):
            batch_inputs_dict = aug_inputs_list[aug_id]
            batch_data_samples = aug_data_samples[aug_id]
            feats_dict = self.extract_feat(batch_inputs_dict)
            points = batch_inputs_dict['points']
            results_list = self.bbox_head.predict(points, feats_dict,
                                                  batch_data_samples, **kwargs)
            # Batch size is 1, so keep the only result of this augmentation.
            multi_aug_results.append(results_list[0])

        # Collect the metainfo of EACH augmentation. The index must be the
        # loop variable of this comprehension; reusing the index left over
        # from the loop above would repeat the last augmentation's metainfo
        # for every entry and make the merge undo the wrong transforms.
        aug_input_metas_list = [
            aug_data_samples[aug_index][0].metainfo
            for aug_index in range(num_augs)
        ]

        aug_results_list = [item.to_dict() for item in multi_aug_results]
        # after merging, bboxes will be rescaled to the original image size
        merged_results_dict = merge_aug_bboxes_3d(aug_results_list,
                                                  aug_input_metas_list,
                                                  self.bbox_head.test_cfg)

        merged_results = InstanceData(**merged_results_dict)
        # NOTE(review): ``batch_data_samples`` is the data-sample list of the
        # last augmentation processed in the loop above; the merged
        # prediction is attached to it. Confirm this is the intended target.
        data_3d_samples = self.add_pred_to_datasample(batch_data_samples,
                                                      [merged_results])
        return data_3d_samples