s3dis_dataset.py 15.2 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
from os import path as osp
3
from typing import Any, Callable, List, Optional, Tuple, Union
4

5
6
import numpy as np

7
from mmdet3d.registry import DATASETS
zhangshilong's avatar
zhangshilong committed
8
from mmdet3d.structures import DepthInstance3DBoxes
jshilong's avatar
jshilong committed
9
from .det3d_dataset import Det3DDataset
ZCMax's avatar
ZCMax committed
10
from .seg3d_dataset import Seg3DDataset
11
12
13


@DATASETS.register_module()
class S3DISDataset(Det3DDataset):
    r"""S3DIS Dataset for Detection Task.

    This class is the inner dataset for S3DIS. Since S3DIS has 6 areas, we
    often train on 5 of them and test on the remaining one. The one for
    test is Area_5 as suggested in `GSDN <https://arxiv.org/abs/2006.12356>`_.
    To concatenate 5 areas during training
    `mmengine.datasets.dataset_wrappers.ConcatDataset` should be used.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_prefix (dict): Prefix for data. Defaults to
            dict(pts='points',
                 pts_instance_mask='instance_mask',
                 pts_semantic_mask='semantic_mask').
        pipeline (List[dict]): Pipeline used for data processing.
            Defaults to [].
        modality (dict): Modality to specify the sensor data used as input.
            Defaults to dict(use_camera=False, use_lidar=True).
        box_type_3d (str): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            to its original format then converted them to `box_type_3d`.
            Defaults to 'Depth' in this dataset. Available options includes:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
            - 'Camera': Box in camera coordinates.
        filter_empty_gt (bool): Whether to filter the data with empty GT.
            If it's set to be True, the example with empty annotations after
            data pipeline will be dropped and a random example will be chosen
            in `__getitem__`. Defaults to True.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
    """
    METAINFO = {
        'classes': ('table', 'chair', 'sofa', 'bookcase', 'board'),
        # the valid ids of segmentation annotations
        'seg_valid_class_ids': (7, 8, 9, 10, 11),
        'seg_all_class_ids':
        tuple(range(1, 14)),  # possibly with 'stair' class
        'palette': [(170, 120, 200), (255, 0, 0), (200, 100, 100),
                    (10, 200, 100), (200, 200, 200)]
    }

    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 data_prefix: dict = dict(
                     pts='points',
                     pts_instance_mask='instance_mask',
                     pts_semantic_mask='semantic_mask'),
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_camera=False, use_lidar=True),
                 box_type_3d: str = 'Depth',
                 filter_empty_gt: bool = True,
                 test_mode: bool = False,
                 **kwargs) -> None:

        # Construct seg_label_mapping for semantic masks: a lookup table
        # indexed by the raw category id that yields the train label.
        # Size it by the *maximum* raw id rather than the id count, so the
        # table covers every id even if `seg_all_class_ids` is ever made
        # non-contiguous (with the current contiguous ids both are equal).
        seg_max_cat_id = max(self.METAINFO['seg_all_class_ids'])
        seg_valid_cat_ids = self.METAINFO['seg_valid_class_ids']
        # ids absent from `seg_valid_class_ids` all map to this label
        neg_label = len(seg_valid_cat_ids)
        seg_label_mapping = np.ones(
            seg_max_cat_id + 1, dtype=np.int64) * neg_label
        for cls_idx, cat_id in enumerate(seg_valid_cat_ids):
            seg_label_mapping[cat_id] = cls_idx
        self.seg_label_mapping = seg_label_mapping

        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            metainfo=metainfo,
            data_prefix=data_prefix,
            pipeline=pipeline,
            modality=modality,
            box_type_3d=box_type_3d,
            filter_empty_gt=filter_empty_gt,
            test_mode=test_mode,
            **kwargs)

        # expose the mapping through metainfo so downstream consumers
        # (e.g. pipeline transforms) can read it
        self.metainfo['seg_label_mapping'] = self.seg_label_mapping
        assert 'use_camera' in self.modality and \
               'use_lidar' in self.modality
        assert self.modality['use_camera'] or self.modality['use_lidar']

    def parse_data_info(self, info: dict) -> dict:
        """Process the raw data info.

        Prepends the configured data prefixes to the instance/semantic mask
        paths before delegating the remaining processing to the base class.

        Args:
            info (dict): Raw info dict.

        Returns:
            dict: Has `ann_info` in training stage. And
            all path has been converted to absolute path.
        """
        info['pts_instance_mask_path'] = osp.join(
            self.data_prefix.get('pts_instance_mask', ''),
            info['pts_instance_mask_path'])
        info['pts_semantic_mask_path'] = osp.join(
            self.data_prefix.get('pts_semantic_mask', ''),
            info['pts_semantic_mask_path'])

        info = super().parse_data_info(info)
        # only be used in `PointSegClassMapping` in pipeline
        # to map original semantic class to valid category ids.
        info['seg_label_mapping'] = self.seg_label_mapping
        return info

    def parse_ann_info(self, info: dict) -> dict:
        """Process the `instances` in data info to `ann_info`.

        Args:
            info (dict): Info dict.

        Returns:
            dict: Processed `ann_info`, where `gt_bboxes_3d` is wrapped as a
            box structure of type `self.box_mode_3d`.
        """
        ann_info = super().parse_ann_info(info)
        # empty gt: fall back to zero-sized arrays so the box wrapper below
        # and downstream code always see consistent types
        if ann_info is None:
            ann_info = dict()
            ann_info['gt_bboxes_3d'] = np.zeros((0, 6), dtype=np.float32)
            ann_info['gt_labels_3d'] = np.zeros((0, ), dtype=np.int64)
        # to target box structure; S3DIS boxes are axis-aligned (no yaw)
        # with gravity-centered origin
        ann_info['gt_bboxes_3d'] = DepthInstance3DBoxes(
            ann_info['gt_bboxes_3d'],
            box_dim=ann_info['gt_bboxes_3d'].shape[-1],
            with_yaw=False,
            origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)

        return ann_info
ZCMax's avatar
ZCMax committed
152
class _S3DISSegDataset(Seg3DDataset):
    r"""S3DIS Dataset for Semantic Segmentation Task.

    Inner dataset holding a single S3DIS area. S3DIS ships 6 areas and has
    no fixed train-test split: many works test on Area_5 as suggested by
    `SEGCloud <https://arxiv.org/abs/1710.07563>`_, while others report
    6-fold cross validation over all areas (e.g.
    `DGCNN <https://arxiv.org/abs/1801.07829>`_). This class therefore
    models exactly one area; a wrapper dataset concatenates the areas that
    an experiment actually uses.

    Args:
        data_root (str, optional): Path of dataset root, Defaults to None.
        ann_file (str): Path of annotation file. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_prefix (dict): Prefix for training data. Defaults to
            dict(pts='points', pts_instance_mask='', pts_semantic_mask='').
        pipeline (List[dict]): Pipeline used for data processing.
            Defaults to [].
        modality (dict): Modality to specify the sensor data used as input.
            Defaults to dict(use_lidar=True, use_camera=False).
        ignore_index (int, optional): The label index to be ignored, e.g.
            unannotated points. If None is given, set to len(self.classes) to
            be consistent with PointSegClassMapping function in pipeline.
            Defaults to None.
        scene_idxs (np.ndarray or str, optional): Precomputed index to load
            data. For scenes with many points, we may sample it several times.
            Defaults to None.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
    """
    METAINFO = {
        'classes': ('ceiling', 'floor', 'wall', 'beam', 'column', 'window',
                    'door', 'table', 'chair', 'sofa', 'bookcase', 'board',
                    'clutter'),
        'palette': [[0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 255, 0],
                    [255, 0, 255], [100, 100, 255], [200, 200, 100],
                    [170, 120, 200], [255, 0, 0], [200, 100, 100],
                    [10, 200, 100], [200, 200, 200], [50, 50, 50]],
        'seg_valid_class_ids': tuple(range(13)),
        'seg_all_class_ids': tuple(range(14))  # possibly with 'stair' class
    }

    def __init__(self,
                 data_root: Optional[str] = None,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_prefix: dict = dict(
                     pts='points', pts_instance_mask='', pts_semantic_mask=''),
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_lidar=True, use_camera=False),
                 ignore_index: Optional[int] = None,
                 scene_idxs: Optional[Union[np.ndarray, str]] = None,
                 test_mode: bool = False,
                 **kwargs) -> None:
        # pure pass-through: all configuration is handled by Seg3DDataset
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            metainfo=metainfo,
            data_prefix=data_prefix,
            pipeline=pipeline,
            modality=modality,
            ignore_index=ignore_index,
            scene_idxs=scene_idxs,
            test_mode=test_mode,
            **kwargs)

    def get_scene_idxs(self, scene_idxs: Union[np.ndarray, str,
                                               None]) -> np.ndarray:
        """Compute scene_idxs for data sampling.

        We sample more times for scenes with more points. Training requires
        precomputed (re-sampled) indexes; testing loads one whole scene at a
        time, so `None` is acceptable there.
        """
        if self.test_mode or scene_idxs is not None:
            return super().get_scene_idxs(scene_idxs)
        raise NotImplementedError(
            'please provide re-sampled scene indexes for training')
@DATASETS.register_module()
class S3DISSegDataset(_S3DISSegDataset):
    r"""S3DIS Dataset for Semantic Segmentation Task.

    This class serves as the API for experiments on the S3DIS Dataset.
    It wraps the provided datasets of different areas.
    We don't use `mmdet.datasets.dataset_wrappers.ConcatDataset` because we
    need to concat the `scene_idxs` of different areas.

    Please refer to the `google form <https://docs.google.com/forms/d/e/1FAIpQL
    ScDimvNMCGhy_rmBA2gHfDu3naktRm6A8BPwAWWDv-Uhm6Shw/viewform?c=0&w=1>`_ for
    data downloading.

    Args:
        data_root (str, optional): Path of dataset root. Defaults to None.
        ann_files (str or List[str]): Path(s) of the annotation files, one
            per area. A single path is accepted and treated as a one-area
            dataset. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_prefix (dict): Prefix for training data. Defaults to
            dict(pts='points', pts_instance_mask='', pts_semantic_mask='').
        pipeline (List[dict]): Pipeline used for data processing.
            Defaults to [].
        modality (dict): Modality to specify the sensor data used as input.
            Defaults to dict(use_lidar=True, use_camera=False).
        ignore_index (int, optional): The label index to be ignored, e.g.
            unannotated points. If None is given, set to len(self.classes) to
            be consistent with PointSegClassMapping function in pipeline.
            Defaults to None.
        scene_idxs (List[np.ndarray] | List[str], optional): Precomputed index
            to load data. For scenes with many points, we may sample it
            several times. Defaults to None.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
    """

    def __init__(self,
                 data_root: Optional[str] = None,
                 ann_files: Union[str, List[str]] = '',
                 metainfo: Optional[dict] = None,
                 data_prefix: dict = dict(
                     pts='points', pts_instance_mask='', pts_semantic_mask=''),
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_lidar=True, use_camera=False),
                 ignore_index: Optional[int] = None,
                 scene_idxs: Optional[Union[List[np.ndarray],
                                            List[str]]] = None,
                 test_mode: bool = False,
                 **kwargs) -> None:

        # normalize ann_files/scene_idxs to lists of the same length
        # (one entry per area)
        ann_files = self._check_ann_files(ann_files)
        scene_idxs = self._check_scene_idxs(scene_idxs, len(ann_files))

        # initialize some attributes as datasets[0]
        super().__init__(
            data_root=data_root,
            ann_file=ann_files[0],
            metainfo=metainfo,
            data_prefix=data_prefix,
            pipeline=pipeline,
            modality=modality,
            ignore_index=ignore_index,
            scene_idxs=scene_idxs[0],
            test_mode=test_mode,
            **kwargs)

        # build one inner dataset per area (area 0 is loaded again here;
        # its data_list below supersedes the one set by super().__init__)
        datasets = [
            _S3DISSegDataset(
                data_root=data_root,
                ann_file=ann_files[i],
                metainfo=metainfo,
                data_prefix=data_prefix,
                pipeline=pipeline,
                modality=modality,
                ignore_index=ignore_index,
                scene_idxs=scene_idxs[i],
                test_mode=test_mode,
                **kwargs) for i in range(len(ann_files))
        ]

        # concatenate the per-area data lists into this wrapper dataset
        self.concat_data_list([dst.data_list for dst in datasets])

        # set group flag for the sampler
        if not self.test_mode:
            self._set_group_flag()

    def concat_data_list(self, data_lists: List[List[dict]]) -> None:
        """Concat data_list from several datasets to form self.data_list.

        Args:
            data_lists (List[List[dict]]): List of dict containing
                annotation information.
        """
        self.data_list = [
            data for data_list in data_lists for data in data_list
        ]

    @staticmethod
    def _duplicate_to_list(x: Any, num: int) -> list:
        """Repeat x `num` times to form a list."""
        return [x for _ in range(num)]

    def _check_ann_files(
            self, ann_file: Union[List[str], Tuple[str], str]) -> List[str]:
        """Make ann_files a list, wrapping a single path if necessary."""
        # ann_file could be str
        if not isinstance(ann_file, (list, tuple)):
            ann_file = self._duplicate_to_list(ann_file, 1)
        # normalize tuples to lists so the return type matches the annotation
        return list(ann_file)

    def _check_scene_idxs(self, scene_idx: Union[str, List[Union[list, tuple,
                                                                 np.ndarray]],
                                                 List[str], None],
                          num: int) -> List[np.ndarray]:
        """Make scene_idxs a list with one entry per annotation file.

        A single value (None, a path str, or one index array) is duplicated
        `num` times; a per-area list is returned unchanged.
        """
        if scene_idx is None:
            return self._duplicate_to_list(scene_idx, num)
        # scene_idx could be str, np.ndarray, list or tuple
        if isinstance(scene_idx, str):  # str
            return self._duplicate_to_list(scene_idx, num)
        if isinstance(scene_idx[0], str):  # list of str
            return scene_idx
        if isinstance(scene_idx[0], (list, tuple, np.ndarray)):  # list of idx
            return scene_idx
        # single idx
        return self._duplicate_to_list(scene_idx, num)