base.py 6.44 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
from abc import ABCMeta, abstractmethod
3
from typing import Dict, List, Union
4

5
6
from mmengine.model import BaseModel
from torch import Tensor
7

8
from mmdet3d.structures import PointData
zhangshilong's avatar
zhangshilong committed
9
10
11
from mmdet3d.structures.det3d_data_sample import (ForwardResults,
                                                  OptSampleList, SampleList)
from mmdet3d.utils import OptConfigType, OptMultiConfig
12
13


14
class Base3DSegmentor(BaseModel, metaclass=ABCMeta):
    """Base class for 3D segmentors.

    Subclasses must implement :meth:`extract_feat`, :meth:`encode_decode`,
    :meth:`loss`, :meth:`predict` and :meth:`_forward`. :meth:`forward`
    dispatches to the last three according to its ``mode`` argument.

    Args:
        data_preprocessor (dict or ConfigDict, optional): Model preprocessing
            config for processing the input data. It usually includes
            ``to_rgb``, ``pad_size_divisor``, ``pad_val``, ``mean`` and
            ``std``. Defaults to None.
        init_cfg (dict or ConfigDict, optional): The config to control the
            initialization. Defaults to None.
    """

    def __init__(self,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

    @property
    def with_neck(self) -> bool:
        """bool: Whether the segmentor has a neck."""
        return hasattr(self, 'neck') and self.neck is not None

    @property
    def with_auxiliary_head(self) -> bool:
        """bool: Whether the segmentor has an auxiliary head."""
        return hasattr(self,
                       'auxiliary_head') and self.auxiliary_head is not None

    @property
    def with_decode_head(self) -> bool:
        """bool: Whether the segmentor has a decode head."""
        return hasattr(self, 'decode_head') and self.decode_head is not None

    @property
    def with_regularization_loss(self) -> bool:
        """bool: Whether the segmentor has a regularization loss for
        weights."""
        return hasattr(self, 'loss_regularization') and \
            self.loss_regularization is not None

    @abstractmethod
    def extract_feat(self, batch_inputs: Tensor) -> dict:
        """Placeholder for extracting features from the batch inputs."""
        pass

    @abstractmethod
    def encode_decode(self, batch_inputs: Tensor,
                      batch_data_samples: SampleList) -> Tensor:
        """Placeholder for encoding the inputs with the backbone and decoding
        them into a semantic segmentation map of the same size as the
        input."""
        pass

    def forward(self,
                inputs: Union[dict, List[dict]],
                data_samples: OptSampleList = None,
                mode: str = 'tensor') -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method accepts three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return a tensor or tuple of
          tensors without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`Det3DDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor optimizer
        updating; those are done in :meth:`train_step`.

        Args:
            inputs (dict or List[dict]): Input sample dict which includes
                'points' and 'imgs' keys.

                - points (List[Tensor]): Point cloud of each sample.
                - imgs (Tensor): Image tensor has shape (B, C, H, W).
            data_samples (List[:obj:`Det3DDataSample`], optional):
                The annotation data of every sample. Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensors.
            - If ``mode="predict"``, return a list of :obj:`Det3DDataSample`.
            - If ``mode="loss"``, return a dict of tensors.

        Raises:
            RuntimeError: If ``mode`` is not one of 'loss', 'predict' or
                'tensor'.
        """
        if mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            return self.predict(inputs, data_samples)
        elif mode == 'tensor':
            return self._forward(inputs, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')

    @abstractmethod
    def loss(self, batch_inputs: dict,
             batch_data_samples: SampleList) -> Dict[str, Tensor]:
        """Calculate losses from a batch of inputs and data samples."""
        pass

    @abstractmethod
    def predict(self, batch_inputs: dict,
                batch_data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing."""
        pass

    @abstractmethod
    def _forward(self,
                 batch_inputs: dict,
                 batch_data_samples: OptSampleList = None) -> Tensor:
        """Network forward process.

        Usually includes backbone, neck and head forward without any post-
        processing.
        """
        pass

    def postprocess_result(self, seg_logits_list: List[Tensor],
                           batch_data_samples: SampleList) -> SampleList:
        """Convert results list to `Det3DDataSample`.

        The data samples are updated in place, and the same list object is
        returned.

        Args:
            seg_logits_list (List[Tensor]): List of segmentation results,
                seg_logits from model of each input point clouds sample.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
                samples. It usually includes information such as `metainfo` and
                `gt_pts_seg`. Must contain at least as many entries as
                ``seg_logits_list``.

        Returns:
            List[:obj:`Det3DDataSample`]: Segmentation results of the input
            points. Each Det3DDataSample usually contains:

            - ``pred_pts_seg`` (PointData): Prediction of 3D semantic
              segmentation.
            - ``pts_seg_logits`` (PointData): Predicted logits of 3D semantic
              segmentation before normalization.
        """
        for i, seg_logits in enumerate(seg_logits_list):
            # Predicted label = argmax over dim 0; this presumes the logits
            # are laid out as (num_classes, num_points) -- TODO confirm
            # against the decode heads.
            seg_pred = seg_logits.argmax(dim=0)
            # Index (rather than zip) so a too-short sample list still raises.
            batch_data_samples[i].set_data({
                'pts_seg_logits':
                PointData(**{'pts_seg_logits': seg_logits}),
                'pred_pts_seg':
                PointData(**{'pts_semantic_mask': seg_pred})
            })
        return batch_data_samples