base.py 4.15 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
from mmdet3d.core import Det3DDataSample
3
4
from mmdet3d.core.utils import (ForwardResults, InstanceList, OptConfigType,
                                OptMultiConfig, OptSampleList, SampleList)
5
from mmdet3d.registry import MODELS
6
from mmdet.models import BaseDetector
zhangwenwei's avatar
zhangwenwei committed
7
8


9
@MODELS.register_module()
class Base3DDetector(BaseDetector):
    """Base class for 3D detectors.

    Args:
        data_processor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`. It usually includes
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
            It is forwarded to the parent class as ``data_preprocessor``.
            Defaults to None.
        init_cfg (dict or ConfigDict, optional): The config to control the
            initialization. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): Alias of
            ``data_processor`` that matches the keyword used by the parent
            class. Only one of the two may be given. Defaults to None.

    Raises:
        ValueError: If both ``data_processor`` and ``data_preprocessor``
            are specified.
    """

    def __init__(self,
                 data_processor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 data_preprocessor: OptConfigType = None) -> None:
        # ``data_processor`` is the historical keyword of this class, while
        # the parent class calls the very same option ``data_preprocessor``.
        # Accept both spellings so that existing callers keep working and
        # callers using the parent's keyword do not hit a ``TypeError``.
        if data_processor is not None and data_preprocessor is not None:
            raise ValueError('`data_processor` and `data_preprocessor` are '
                             'aliases of the same option; please specify '
                             'only one of them.')
        if data_preprocessor is None:
            data_preprocessor = data_processor
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

    def forward(self,
                batch_inputs_dict: dict,
                batch_data_samples: OptSampleList = None,
                mode: str = 'tensor',
                **kwargs) -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`Det3DDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            batch_inputs_dict (dict): The model inputs. The dict is passed
                unchanged to the method selected by ``mode``.
            batch_data_samples (list[:obj:`Det3DDataSample`], optional): The
                annotation data of every sample. Defaults to None.
            mode (str): Return what kind of value. Must be one of 'loss',
                'predict' and 'tensor'. Defaults to 'tensor'.
            **kwargs: Extra keyword arguments forwarded to the method
                selected by ``mode``.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of
              :obj:`Det3DDataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            RuntimeError: If ``mode`` is not one of the supported modes.
        """
        if mode == 'loss':
            return self.loss(batch_inputs_dict, batch_data_samples, **kwargs)
        elif mode == 'predict':
            return self.predict(batch_inputs_dict, batch_data_samples,
                                **kwargs)
        elif mode == 'tensor':
            return self._forward(batch_inputs_dict, batch_data_samples,
                                 **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')

    def convert_to_datasample(self, results_list: InstanceList) -> SampleList:
        """Convert results list to `Det3DDataSample`.

        Subclasses could override it to be compatible for some multi-modality
        3D detectors.

        Args:
            results_list (list[:obj:`InstanceData`]): Detection results of
                each sample.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input. Each Det3DDataSample usually contains
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains following keys.

                - scores_3d (Tensor): Classification scores, has a shape
                    (num_instances, ).
                - labels_3d (Tensor): Labels of 3D bboxes, has a shape
                    (num_instances, ).
                - bboxes_3d (Tensor): Contains a tensor with shape
                    (num_instances, C) where C >= 7.
        """
        data_samples = []
        # Wrap each per-sample result in its own data sample, preserving the
        # order of ``results_list``.
        for pred_instances_3d in results_list:
            data_sample = Det3DDataSample()
            data_sample.pred_instances_3d = pred_instances_3d
            data_samples.append(data_sample)
        return data_samples