base.py 6.29 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
zhangshilong's avatar
zhangshilong committed
2
from typing import List, Optional, Union
jshilong's avatar
jshilong committed
3

jshilong's avatar
jshilong committed
4
5
from mmengine import InstanceData

6
from mmdet3d.registry import MODELS
zhangshilong's avatar
zhangshilong committed
7
8
9
10
from mmdet3d.structures import Det3DDataSample
from mmdet3d.structures.det3d_data_sample import (ForwardResults,
                                                  OptSampleList, SampleList)
from mmdet3d.utils.typing import InstanceList, OptConfigType, OptMultiConfig
11
from mmdet.models import BaseDetector
zhangwenwei's avatar
zhangwenwei committed
12
13


14
@MODELS.register_module()
class Base3DDetector(BaseDetector):
    """Base class for 3D detectors.

    Args:
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
            Defaults to None.
        init_cfg (dict or ConfigDict, optional): The config to control the
            initialization. Defaults to None.
    """

    def __init__(self,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

    def forward(self,
                inputs: Union[dict, List[dict]],
                data_samples: OptSampleList = None,
                mode: str = 'tensor',
                **kwargs) -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`Det3DDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (dict | list[dict]): When it is a list[dict], the
                outer list indicates the test time augmentation. Each
                dict contains batch inputs which include 'points' and
                'imgs' keys.

                - points (list[torch.Tensor]): Point cloud of each sample.
                - imgs (torch.Tensor): Image tensor has shape (B, C, H, W).
            data_samples (list[:obj:`Det3DDataSample`],
                list[list[:obj:`Det3DDataSample`]], optional): The
                annotation data of every sample. When it is a list[list],
                the outer list indicates the test time augmentation, and
                the inner list indicates the batch. Otherwise, the list
                simply indicates the batch. Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                method selected by ``mode``.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of :obj:`Det3DDataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            RuntimeError: If ``mode`` is not one of 'loss', 'predict' or
                'tensor'.
            AssertionError: If test time augmentation is requested (nested
                ``data_samples``) with a batch size other than 1.
        """
        if mode == 'loss':
            return self.loss(inputs, data_samples, **kwargs)
        elif mode == 'predict':
            # ``data_samples`` is optional, so only inspect its first element
            # when there is one; a missing or empty value is handed to
            # ``predict`` instead of failing here on ``data_samples[0]``.
            if data_samples and isinstance(data_samples[0], list):
                # Nested lists mean test time augmentation.
                assert len(data_samples[0]) == 1, \
                    'Only support batch_size 1 in mmdet3d when do the ' \
                    'test time augmentation.'
                return self.aug_test(inputs, data_samples, **kwargs)
            return self.predict(inputs, data_samples, **kwargs)
        elif mode == 'tensor':
            return self._forward(inputs, data_samples, **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')

    def convert_to_datasample(
        self,
        results_list_3d: Optional[InstanceList] = None,
        results_list_2d: Optional[InstanceList] = None,
    ) -> SampleList:
        """Convert results list to `Det3DDataSample`.

        Subclasses could override it to be compatible for some multi-modality
        3D detectors.

        At least one of the two lists must be given. The one that is omitted
        is filled with empty :obj:`InstanceData`, one per entry of the list
        that was given. One data sample is built per entry of
        ``results_list_3d``; ``results_list_2d`` is read at the same index.

        Args:
            results_list_3d (list[:obj:`InstanceData`], optional): 3D
                detection results of each sample. Defaults to None.
            results_list_2d (list[:obj:`InstanceData`], optional): 2D
                detection results of each sample. Defaults to None.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input. Each Det3DDataSample usually contains
            'pred_instances_3d'. And the ``pred_instances_3d`` normally
            contains following keys.

            - scores_3d (Tensor): Classification scores, has a shape
              (num_instances, )
            - labels_3d (Tensor): Labels of 3D bboxes, has a shape
              (num_instances, ).
            - bbox_3d (Tensor): Contains a tensor with shape
              (num_instances, C) where C >=7.

            When there are image predictions in some models, it should
            contain `pred_instances`. And the ``pred_instances`` normally
            contains following keys.

            - scores (Tensor): Classification scores of image, has a shape
              (num_instances, )
            - labels (Tensor): Predict Labels of 2D bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Contains a tensor with shape
              (num_instances, 4).

        Raises:
            AssertionError: If both ``results_list_3d`` and
                ``results_list_2d`` are None.
        """
        assert (results_list_2d is not None) or \
               (results_list_3d is not None),\
               'please pass at least one type of results_list'

        # Fill in whichever modality is missing with empty placeholders so
        # every data sample carries both prediction fields.
        if results_list_2d is None:
            results_list_2d = [InstanceData() for _ in results_list_3d]
        if results_list_3d is None:
            results_list_3d = [InstanceData() for _ in results_list_2d]

        data_sample_list = []
        # The 3D list drives the iteration; the 2D list is indexed rather
        # than zipped so a shorter 2D list still raises IndexError.
        for i, instances_3d in enumerate(results_list_3d):
            result = Det3DDataSample()
            result.pred_instances_3d = instances_3d
            result.pred_instances = results_list_2d[i]
            data_sample_list.append(result)
        return data_sample_list