seg3d_dataset.py 13.5 KB
Newer Older
ZCMax's avatar
ZCMax committed
1
2
# Copyright (c) OpenMMLab. All rights reserved.
from os import path as osp
3
from typing import Callable, List, Optional, Sequence, Union
ZCMax's avatar
ZCMax committed
4
5
6

import numpy as np
from mmengine.dataset import BaseDataset
7
from mmengine.fileio import get_local_path
ZCMax's avatar
ZCMax committed
8
9
10
11
12
13
14
15
16
17
18

from mmdet3d.registry import DATASETS


@DATASETS.register_module()
class Seg3DDataset(BaseDataset):
    """Base Class for 3D semantic segmentation dataset.

    This is the base dataset of ScanNet, S3DIS and SemanticKITTI dataset.

    Args:
        data_root (str, optional): Path of dataset root. Defaults to None.
        ann_file (str): Path of annotation file. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_prefix (dict): Prefix for training data. Defaults to
            dict(pts='points',
                 img='',
                 pts_instance_mask='',
                 pts_semantic_mask='').
        pipeline (List[dict]): Pipeline used for data processing.
            Defaults to [].
        modality (dict): Modality to specify the sensor data used
            as input, it usually has following keys:

                - use_camera: bool
                - use_lidar: bool
            Defaults to dict(use_lidar=True, use_camera=False).
        ignore_index (int, optional): The label index to be ignored, e.g.
            unannotated points. If None is given, set to len(self.classes) to
            be consistent with PointSegClassMapping function in pipeline.
            Defaults to None.
        scene_idxs (np.ndarray or str, optional): Precomputed index to load
            data. For scenes with many points, we may sample it several times.
            Defaults to None.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
        serialize_data (bool): Whether to hold memory using serialized objects,
            when enabled, data loader workers can use shared RAM from master
            process instead of making a copy.
            Defaults to False for 3D Segmentation datasets.
        load_eval_anns (bool): Whether to load annotations in test_mode,
            the annotation will be saved in `eval_ann_info`, which can be used
            in Evaluator. Defaults to True.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
    """
    # Class-level meta information; every entry is left unset (None) here.
    METAINFO = dict(
        classes=None,  # names of all classes data used for the task
        palette=None,  # official color for visualization
        seg_valid_class_ids=None,  # class_ids used for training
        seg_all_class_ids=None,  # all possible class_ids in loaded seg mask
    )

    def __init__(self,
                 data_root: Optional[str] = None,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_prefix: dict = dict(
                     pts='points',
                     img='',
                     pts_instance_mask='',
                     pts_semantic_mask=''),
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_lidar=True, use_camera=False),
                 ignore_index: Optional[int] = None,
                 scene_idxs: Optional[Union[str, np.ndarray]] = None,
                 test_mode: bool = False,
                 serialize_data: bool = False,
                 load_eval_anns: bool = True,
                 backend_args: Optional[dict] = None,
                 **kwargs) -> None:
        """Prepare the label mappings and palette, then build the dataset.

        The label mapping, label-to-category mapping, ignore index, valid
        class ids and palette are computed first and handed to the parent
        constructor inside ``metainfo``. Unless ``lazy_init=True`` is given
        in ``kwargs``, the loaded data list is then re-indexed with the
        result of ``get_scene_idxs`` and, outside test mode, the group flag
        is set.

        ``metainfo`` may be None or may omit ``classes``; in both cases the
        classes in ``METAINFO`` are used. The dict passed by the caller is
        not modified.
        """
        self.backend_args = backend_args
        self.modality = modality
        self.load_eval_anns = load_eval_anns

        # TODO: We maintain the ignore_index attributes,
        # but we may consider to remove it in the future.
        if ignore_index is None:
            ignore_index = len(self.METAINFO['classes'])
        self.ignore_index = ignore_index

        # ``metainfo`` defaults to None, so normalise it before reading from
        # it; the shallow copy keeps the keys added below out of the
        # caller's dict.
        metainfo = {} if metainfo is None else dict(metainfo)

        # Get label mapping for custom classes
        new_classes = metainfo.get('classes', None)
        self.label_mapping, self.label2cat, seg_valid_class_ids = \
            self.get_label_mapping(new_classes)

        metainfo['label_mapping'] = self.label_mapping
        metainfo['label2cat'] = self.label2cat
        metainfo['ignore_index'] = self.ignore_index
        metainfo['seg_valid_class_ids'] = seg_valid_class_ids

        # Use the palette from the dataset config when it is defined,
        # otherwise derive it from the METAINFO palette. Without custom
        # classes the METAINFO class list is the one in effect, so the
        # palette is matched against that list instead of None.
        palette_classes = self.METAINFO['classes'] \
            if new_classes is None else new_classes
        metainfo['palette'] = self._update_palette(
            palette_classes, metainfo.get('palette', None))

        # construct seg_label_mapping for semantic mask
        self.seg_label_mapping = self.get_seg_label_mapping(metainfo)

        super().__init__(
            ann_file=ann_file,
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=data_prefix,
            pipeline=pipeline,
            test_mode=test_mode,
            serialize_data=serialize_data,
            **kwargs)

        # NOTE(review): this only persists if ``self.metainfo`` returns the
        # stored dict rather than a copy -- confirm against the parent class.
        self.metainfo['seg_label_mapping'] = self.seg_label_mapping

        if not kwargs.get('lazy_init', False):
            self.scene_idxs = self.get_scene_idxs(scene_idxs)
            self.data_list = [self.data_list[i] for i in self.scene_idxs]

            # set group flag for the sampler
            if not self.test_mode:
                self._set_group_flag()
ZCMax's avatar
ZCMax committed
129
130

    def get_label_mapping(self,
                          new_classes: Optional[Sequence] = None) -> tuple:
        """Build the mappings between raw label ids and training labels.

        ``new_classes`` is treated as a custom selection only when both it
        and the classes in ``METAINFO`` are given and they differ. In that
        case the valid class ids are picked, in the order of
        ``new_classes``, from ``METAINFO['seg_valid_class_ids']``; otherwise
        the ``METAINFO`` classes and valid class ids are used as they are.

        Args:
            new_classes (list or tuple, optional): The new classes name from
                metainfo. Defaults to None.

        Returns:
            tuple: ``(label_mapping, label2cat, valid_class_ids)``.
            ``label_mapping`` maps every id in
            ``METAINFO['seg_all_class_ids']`` to ``self.ignore_index`` and
            every valid class id to its position in ``valid_class_ids``.
            ``label2cat`` maps each position to the class name at that
            position.

        Raises:
            ValueError: If ``new_classes`` is a custom selection that is not
                a subset of the classes in ``METAINFO``.
        """
        meta = self.METAINFO
        old_classes = meta.get('classes', None)
        use_custom = (
            new_classes is not None and old_classes is not None
            and list(new_classes) != list(old_classes))

        if use_custom:
            if not set(new_classes).issubset(old_classes):
                raise ValueError(
                    f'new classes {new_classes} is not a '
                    f'subset of classes {old_classes} in METAINFO.')
            # obtain true id from valid_class_ids
            valid_class_ids = [
                meta['seg_valid_class_ids'][old_classes.index(name)]
                for name in new_classes
            ]

        # Every known id starts out ignored; valid ids are overwritten below.
        label_mapping = dict.fromkeys(meta['seg_all_class_ids'],
                                      self.ignore_index)

        if not use_custom:
            valid_class_ids = meta['seg_valid_class_ids']
        for train_label, cls_id in enumerate(valid_class_ids):
            label_mapping[cls_id] = train_label

        # map label to category name
        label2cat = dict(
            enumerate(new_classes if use_custom else meta['classes']))

        return label_mapping, label2cat, valid_class_ids

188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
    def get_seg_label_mapping(self, metainfo=None):
        """Get segmentation label mapping.

        The ``seg_label_mapping`` is an array, its indices are the old label
        ids and its values are the new label ids, and is specifically used
        for changing point labels in PointSegClassMapping.

        Each id in ``METAINFO['seg_valid_class_ids']`` is mapped to its
        position in that sequence; every other index is mapped to
        ``len(seg_valid_class_ids)``.

        Args:
            metainfo (dict, optional): Meta information. It is not read by
                this implementation, which only uses ``METAINFO``.
                Defaults to None.

        Returns:
            np.ndarray: The int64 lookup array from old label ids to new
            label ids.
        """
        all_ids = self.METAINFO['seg_all_class_ids']
        valid_ids = self.METAINFO['seg_valid_class_ids']
        neg_label = len(valid_ids)

        # The table has ``len(all_ids) + 1`` entries as before, but is grown
        # when a valid id lies beyond that so the assignment below cannot
        # index out of bounds.
        table_size = max(len(all_ids), max(valid_ids, default=-1)) + 1
        seg_label_mapping = np.full(table_size, neg_label, dtype=np.int64)
        for cls_idx, cat_id in enumerate(valid_ids):
            seg_label_mapping[cat_id] = cls_idx
        return seg_label_mapping

211
212
    def _update_palette(self, new_classes: list, palette: Union[None,
                                                                list]) -> list:
        """Update palette according to metainfo.

        If a palette is given and its length equals the number of classes,
        it is returned as it is. If no palette is given, the colors of the
        selected classes are looked up in ``METAINFO['palette']``.

        Args:
            new_classes (list): Names of the classes in use. If None, the
                classes in ``METAINFO`` are used.
            palette (list, optional): Palette defined in metainfo.

        Returns:
            list: Palette for current dataset, one entry per class.

        Raises:
            ValueError: If a palette is given whose length differs from the
                number of classes.
        """
        old_classes = self.METAINFO.get('classes', None)
        if new_classes is None:
            # No custom selection: the METAINFO classes are the ones in use.
            new_classes = old_classes

        if palette is None:
            # If palette is not defined, build it from the original palette
            # by looking up each class name's index in the original classes.
            default_palette = self.METAINFO['palette']
            return [
                default_palette[old_classes.index(cls_name)]
                for cls_name in new_classes
            ]

        # a user-defined palette has to match the classes
        if len(palette) != len(new_classes):
            raise ValueError('Once palette is set in metainfo, it should '
                             'match classes in metainfo')
        return palette
ZCMax's avatar
ZCMax committed
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258

    def parse_data_info(self, info: dict) -> dict:
        """Process the raw data info.

        Prepend the matching entry of ``self.data_prefix`` to the lidar
        path, the image paths and the instance / semantic mask paths that
        are present, and attach the segmentation label mapping.

        Args:
            info (dict): Raw info dict. It is updated in place.

        Returns:
            dict: The same ``info`` dict with prefixed paths and
            ``seg_label_mapping``. ``eval_ann_info`` is added as an empty
            dict in test mode when ``self.load_eval_anns`` is set.
        """
        if self.modality['use_lidar']:
            pts_prefix = self.data_prefix.get('pts', '')
            lidar_info = info['lidar_points']
            lidar_info['lidar_path'] = osp.join(pts_prefix,
                                                lidar_info['lidar_path'])
            if 'num_pts_feats' in lidar_info:
                info['num_pts_feats'] = lidar_info['num_pts_feats']
            # expose the prefixed path at the top level as well
            info['lidar_path'] = lidar_info['lidar_path']

        if self.modality['use_camera']:
            for img_info in info['images'].values():
                if 'img_path' not in img_info:
                    continue
                img_info['img_path'] = osp.join(
                    self.data_prefix.get('img', ''), img_info['img_path'])

        # (key in ``info``, key in ``data_prefix``) for the optional masks
        mask_keys = (
            ('pts_instance_mask_path', 'pts_instance_mask'),
            ('pts_semantic_mask_path', 'pts_semantic_mask'),
        )
        for path_key, prefix_key in mask_keys:
            if path_key in info:
                info[path_key] = osp.join(
                    self.data_prefix.get(prefix_key, ''), info[path_key])

        # only be used in `PointSegClassMapping` in pipeline
        # to map original semantic class to valid category ids.
        info['seg_label_mapping'] = self.seg_label_mapping

        # 'eval_ann_info' will be updated in loading transforms
        if self.test_mode and self.load_eval_anns:
            info['eval_ann_info'] = {}

        return info

289
290
291
292
293
294
295
296
297
298
299
300
    def prepare_data(self, idx: int) -> dict:
        """Run the sample at ``idx`` through ``self.pipeline``.

        Args:
            idx (int): The index of ``data_info``.

        Returns:
            dict: Results passed through ``self.pipeline``.
        """
        if self.test_mode:
            # testing uses the parent implementation unchanged
            return super().prepare_data(idx)

        sample = self.get_data_info(idx)
        # Pass the dataset to the pipeline during training to support mixed
        # data augmentation, such as polarmix and lasermix.
        sample['dataset'] = self
        return self.pipeline(sample)

307
308
    def get_scene_idxs(self, scene_idxs: Union[None, str,
                                               np.ndarray]) -> np.ndarray:
        """Compute scene_idxs for data sampling.

        We sample more times for scenes with more points.

        Args:
            scene_idxs (np.ndarray or str, optional): The indices to use,
                or the path of a file holding them that is read with
                ``np.load``. A path is joined onto ``self.data_root`` when a
                data root is set and is used as given otherwise. If None,
                every sample is used once.

        Returns:
            np.ndarray: The indices as an int32 array. In test mode this is
            always ``np.arange(len(self))``, whatever ``scene_idxs`` is.
        """
        if self.test_mode:
            # when testing, we load one whole scene every time
            return np.arange(len(self)).astype(np.int32)

        # we may need to re-sample different scenes according to scene_idxs
        # this is necessary for indoor scene segmentation such as ScanNet
        if scene_idxs is None:
            scene_idxs = np.arange(len(self))
        elif isinstance(scene_idxs, str):
            # ``data_root`` is optional; joining onto None would raise a
            # TypeError, so the path is used as given in that case.
            if self.data_root is not None:
                scene_idxs = osp.join(self.data_root, scene_idxs)
            with get_local_path(
                    scene_idxs, backend_args=self.backend_args) as local_path:
                scene_idxs = np.load(local_path)
        else:
            scene_idxs = np.array(scene_idxs)

        return scene_idxs.astype(np.int32)

331
    def _set_group_flag(self) -> None:
        """Put every sample into group 0.

        ``self.flag`` is set to a uint8 array of zeros with one entry per
        sample, so all samples share the same group.
        """
        num_samples = len(self)
        self.flag = np.zeros(num_samples, dtype=np.uint8)