votenet.py 6.02 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
jshilong's avatar
jshilong committed
2
from typing import Dict, List, Optional, Union
wuyuefeng's avatar
Votenet  
wuyuefeng committed
3

jshilong's avatar
jshilong committed
4
5
6
7
from mmengine import InstanceData
from torch import Tensor

from mmdet3d.core import Det3DDataSample, merge_aug_bboxes_3d
8
from mmdet3d.registry import MODELS
zhangwenwei's avatar
zhangwenwei committed
9
from .single_stage import SingleStage3DDetector
wuyuefeng's avatar
Votenet  
wuyuefeng committed
10
11


12
@MODELS.register_module()
class VoteNet(SingleStage3DDetector):
    r"""`VoteNet <https://arxiv.org/pdf/1904.09664.pdf>`_ for 3D detection.

    Args:
        backbone (dict): Config dict of detector's backbone.
        bbox_head (dict, optional): Config dict of box head. Defaults to None.
        train_cfg (dict, optional): Config dict of training hyper-parameters.
            Defaults to None.
        test_cfg (dict, optional): Config dict of test hyper-parameters.
            Defaults to None.
        init_cfg (dict, optional): The config to control the
            initialization. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
            Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 bbox_head: Optional[dict] = None,
                 train_cfg: Optional[dict] = None,
                 test_cfg: Optional[dict] = None,
                 init_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None,
                 **kwargs) -> None:
        # Every argument is forwarded unchanged to the single-stage base
        # detector; this class adds no construction logic of its own.
        super().__init__(
            backbone=backbone,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg,
            data_preprocessor=data_preprocessor,
            **kwargs)

    def loss(self, batch_inputs_dict: Dict[str, Union[List, Tensor]],
             batch_data_samples: List[Det3DDataSample],
             **kwargs) -> Dict[str, Tensor]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                'points' keys.

                - points (list[torch.Tensor]): Point cloud of each sample.

            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`.
            **kwargs: Extra keyword arguments forwarded to
                ``self.bbox_head.loss``.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        feat_dict = self.extract_feat(batch_inputs_dict)
        # The head receives the raw points in addition to the extracted
        # features.
        points = batch_inputs_dict['points']
        losses = self.bbox_head.loss(points, feat_dict, batch_data_samples,
                                     **kwargs)
        return losses

    def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
                batch_data_samples: List[Det3DDataSample],
                **kwargs) -> List[Det3DDataSample]:
        """Forward of testing.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                'points' keys.

                - points (list[torch.Tensor]): Point cloud of each sample.

            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`.
            **kwargs: Extra keyword arguments forwarded to
                ``self.bbox_head.predict``.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input sample. Each Det3DDataSample usually contain
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                    (num_instances, )
                - labels_3d (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
                    contains a tensor with shape (num_instances, 7).
        """
        feats_dict = self.extract_feat(batch_inputs_dict)
        points = batch_inputs_dict['points']
        results_list = self.bbox_head.predict(points, feats_dict,
                                              batch_data_samples, **kwargs)
        data_3d_samples = self.convert_to_datasample(results_list)
        return data_3d_samples

    def aug_test(self, aug_inputs_list: List[dict],
                 aug_data_samples: List[List[Det3DDataSample]],
                 **kwargs) -> List[Det3DDataSample]:
        """Test with augmentation.

        Batch size always is 1 when do the augtest.

        Args:
            aug_inputs_list (List[dict]): The list indicates the same data
                under different augmentations. Each dict must contain a
                'points' key.
            aug_data_samples (List[List[:obj:`Det3DDataSample`]]): The
                outer list indicates different augmentations, and the
                inner list indicates the batch size.
            **kwargs: Extra keyword arguments forwarded to
                ``self.bbox_head.predict``.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results. With a single
            augmentation this is the output of :meth:`predict`; otherwise
            it is a one-element list holding the merged predictions of
            all augmentations.
        """
        num_augs = len(aug_inputs_list)
        if num_augs == 1:
            # Nothing to merge. ``kwargs`` are forwarded so this shortcut
            # behaves like the multi-augmentation path below.
            return self.predict(aug_inputs_list[0], aug_data_samples[0],
                                **kwargs)

        batch_size = len(aug_data_samples[0])
        assert batch_size == 1, (
            'aug_test only supports batch size 1, '
            f'but got {batch_size}')

        aug_results_list = []
        aug_input_metas_list = []
        for aug_id, batch_inputs_dict in enumerate(aug_inputs_list):
            batch_data_samples = aug_data_samples[aug_id]
            feats_dict = self.extract_feat(batch_inputs_dict)
            points = batch_inputs_dict['points']
            results_list = self.bbox_head.predict(points, feats_dict,
                                                  batch_data_samples, **kwargs)
            # Batch size is 1, so only the first result is relevant.
            aug_results_list.append(results_list[0].to_dict())
            # Pair each prediction with the metainfo of the SAME
            # augmentation. Indexing with a stale loop variable here would
            # hand the last augmentation's metainfo to every prediction.
            aug_input_metas_list.append(batch_data_samples[0].metainfo)

        # after merging, bboxes will be rescaled to the original image size
        merged_results_dict = merge_aug_bboxes_3d(aug_results_list,
                                                  aug_input_metas_list,
                                                  self.bbox_head.test_cfg)

        merged_results = InstanceData(**merged_results_dict)
        data_3d_samples = self.convert_to_datasample([merged_results])
        return data_3d_samples