base.py 6.27 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
from abc import ABCMeta, abstractmethod
3
from typing import Dict, List, Union
4

5
6
from mmengine.model import BaseModel
from torch import Tensor
7

8
from mmdet3d.structures import PointData
zhangshilong's avatar
zhangshilong committed
9
10
11
from mmdet3d.structures.det3d_data_sample import (ForwardResults,
                                                  OptSampleList, SampleList)
from mmdet3d.utils import OptConfigType, OptMultiConfig
12
13


14
class Base3DSegmentor(BaseModel, metaclass=ABCMeta):
    """Base class for 3D segmentors.

    Subclasses must implement :meth:`extract_feat`, :meth:`encode_decode`,
    :meth:`loss`, :meth:`predict`, :meth:`_forward` and :meth:`aug_test`.

    Args:
        data_preprocessor (dict or ConfigDict, optional): Model preprocessing
            config for processing the input data. It usually includes
            ``to_rgb``, ``pad_size_divisor``, ``pad_val``, ``mean`` and
            ``std``. Defaults to None.
        init_cfg (dict or ConfigDict, optional): The config to control the
            initialization. Defaults to None.
    """

    def __init__(self,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

    @property
    def with_neck(self) -> bool:
        """bool: Whether the segmentor has neck."""
        return getattr(self, 'neck', None) is not None

    @property
    def with_auxiliary_head(self) -> bool:
        """bool: Whether the segmentor has auxiliary head."""
        return getattr(self, 'auxiliary_head', None) is not None

    @property
    def with_decode_head(self) -> bool:
        """bool: Whether the segmentor has decode head."""
        return getattr(self, 'decode_head', None) is not None

    @property
    def with_regularization_loss(self) -> bool:
        """bool: Whether the segmentor has regularization loss for weight."""
        return getattr(self, 'loss_regularization', None) is not None

    @abstractmethod
    def extract_feat(self, batch_inputs: Tensor) -> dict:
        """Placeholder for extracting features from the inputs."""

    @abstractmethod
    def encode_decode(self, batch_inputs: Tensor,
                      batch_data_samples: SampleList) -> Tensor:
        """Placeholder for encoding the inputs with the backbone and decoding
        them into a semantic segmentation map of the same size as the input.
        """

    def forward(self,
                inputs: Union[dict, List[dict]],
                data_samples: OptSampleList = None,
                mode: str = 'tensor') -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`Det3DDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (dict or List[dict]): Input sample dict which includes
                'points' and 'imgs' keys.

                - points (List[Tensor]): Point cloud of each sample.
                - imgs (Tensor): Image tensor has shape (B, C, H, W).
            data_samples (List[:obj:`Det3DDataSample`], optional):
                The annotation data of every sample. Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of :obj:`Det3DDataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            RuntimeError: If ``mode`` is not one of "loss", "predict" or
                "tensor".
        """
        if mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            return self.predict(inputs, data_samples)
        elif mode == 'tensor':
            return self._forward(inputs, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')

    @abstractmethod
    def loss(self, batch_inputs: dict,
             batch_data_samples: SampleList) -> Dict[str, Tensor]:
        """Calculate losses from a batch of inputs and data samples."""

    @abstractmethod
    def predict(self, batch_inputs: dict,
                batch_data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing."""

    @abstractmethod
    def _forward(self,
                 batch_inputs: dict,
                 batch_data_samples: OptSampleList = None) -> Tensor:
        """Network forward process.

        Usually includes backbone, neck and head forward without any post-
        processing.
        """

    @abstractmethod
    def aug_test(self, batch_inputs, batch_data_samples):
        """Placeholder for augmentation test."""

    def postprocess_result(self, seg_pred_list: List[dict],
                           batch_data_samples: SampleList) -> SampleList:
        """Convert results list to `Det3DDataSample`.

        The data samples are updated in place and the same list is returned.

        Args:
            seg_pred_list (List[dict]): List of segmentation results, one
                for each input point cloud sample. The i-th entry is attached
                to the i-th data sample.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
                samples. It usually includes information such as `metainfo` and
                `gt_pts_seg`.

        Returns:
            List[:obj:`Det3DDataSample`]: Segmentation results of the input
            points. Each Det3DDataSample usually contains:

            - ``pred_pts_seg`` (PointData): Prediction of 3D semantic
              segmentation, stored under the ``pts_semantic_mask`` field.
        """
        # Index the samples explicitly (rather than zip) so that a prediction
        # list longer than the sample list still fails loudly.
        for i, seg_pred in enumerate(seg_pred_list):
            batch_data_samples[i].set_data(
                {'pred_pts_seg': PointData(pts_semantic_mask=seg_pred)})
        return batch_data_samples