s3dis_dataset.py 15 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
from os import path as osp
3
from typing import Any, Callable, List, Optional, Tuple, Union
4

5
6
import numpy as np

7
from mmdet3d.registry import DATASETS
zhangshilong's avatar
zhangshilong committed
8
from mmdet3d.structures import DepthInstance3DBoxes
jshilong's avatar
jshilong committed
9
from .det3d_dataset import Det3DDataset
ZCMax's avatar
ZCMax committed
10
from .seg3d_dataset import Seg3DDataset
11
12
13


@DATASETS.register_module()
class S3DISDataset(Det3DDataset):
    r"""S3DIS Dataset for Detection Task.

    This class is the inner dataset for S3DIS. Since S3DIS has 6 areas, we
    often train on 5 of them and test on the remaining one. The one for
    test is Area_5 as suggested in `GSDN <https://arxiv.org/abs/2006.12356>`_.
    To concatenate 5 areas during training
    `mmengine.datasets.dataset_wrappers.ConcatDataset` should be used.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_prefix (dict): Prefix for data. Defaults to
            dict(pts='points',
                 pts_instance_mask='instance_mask',
                 pts_semantic_mask='semantic_mask').
        pipeline (List[dict]): Pipeline used for data processing.
            Defaults to [].
        modality (dict): Modality to specify the sensor data used as input.
            Defaults to dict(use_camera=False, use_lidar=True).
        box_type_3d (str): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            to its original format and then convert them to `box_type_3d`.
            Defaults to 'Depth' in this dataset. Available options includes:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
            - 'Camera': Box in camera coordinates.
        filter_empty_gt (bool): Whether to filter the data with empty GT.
            If it's set to be True, the example with empty annotations after
            data pipeline will be dropped and a random example will be chosen
            in `__getitem__`. Defaults to True.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
    """
    METAINFO = {
        'classes': ('table', 'chair', 'sofa', 'bookcase', 'board'),
        # the valid ids of segmentation annotations
        'seg_valid_class_ids': (7, 8, 9, 10, 11),
        'seg_all_class_ids': tuple(range(1, 14))  # possibly with 'stair' class
    }

    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 data_prefix: dict = dict(
                     pts='points',
                     pts_instance_mask='instance_mask',
                     pts_semantic_mask='semantic_mask'),
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_camera=False, use_lidar=True),
                 box_type_3d: str = 'Depth',
                 filter_empty_gt: bool = True,
                 test_mode: bool = False,
                 **kwargs) -> None:

        # construct seg_label_mapping for semantic mask
        # seg_all_class_ids is range(1, 14), so its length (13) is also the
        # maximum raw category id; mapping array covers ids 0..13 inclusive.
        seg_max_cat_id = len(self.METAINFO['seg_all_class_ids'])
        seg_valid_cat_ids = self.METAINFO['seg_valid_class_ids']
        # invalid categories all map to the label after the last valid class
        neg_label = len(seg_valid_cat_ids)
        # use np.int64 here: the `np.int` alias was removed in NumPy 1.24,
        # and int64 matches the dtype used for labels in `parse_ann_info`
        seg_label_mapping = np.ones(
            seg_max_cat_id + 1, dtype=np.int64) * neg_label
        for cls_idx, cat_id in enumerate(seg_valid_cat_ids):
            seg_label_mapping[cat_id] = cls_idx
        self.seg_label_mapping = seg_label_mapping

        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            metainfo=metainfo,
            data_prefix=data_prefix,
            pipeline=pipeline,
            modality=modality,
            box_type_3d=box_type_3d,
            filter_empty_gt=filter_empty_gt,
            test_mode=test_mode,
            **kwargs)

        # expose the mapping through metainfo so pipeline transforms can use it
        self.metainfo['seg_label_mapping'] = self.seg_label_mapping
        assert 'use_camera' in self.modality and \
               'use_lidar' in self.modality
        assert self.modality['use_camera'] or self.modality['use_lidar']

    def parse_data_info(self, info: dict) -> dict:
        """Process the raw data info.

        Args:
            info (dict): Raw info dict.

        Returns:
            dict: Has `ann_info` in training stage. And
            all path has been converted to absolute path.
        """
        # prepend the configured prefixes to the mask paths before the base
        # class resolves them
        info['pts_instance_mask_path'] = osp.join(
            self.data_prefix.get('pts_instance_mask', ''),
            info['pts_instance_mask_path'])
        info['pts_semantic_mask_path'] = osp.join(
            self.data_prefix.get('pts_semantic_mask', ''),
            info['pts_semantic_mask_path'])

        info = super().parse_data_info(info)
        # only be used in `PointSegClassMapping` in pipeline
        # to map original semantic class to valid category ids.
        info['seg_label_mapping'] = self.seg_label_mapping
        return info

    def parse_ann_info(self, info: dict) -> dict:
        """Process the `instances` in data info to `ann_info`.

        Args:
            info (dict): Info dict.

        Returns:
            dict: Processed `ann_info`.
        """
        ann_info = super().parse_ann_info(info)
        # empty gt: fall back to zero-sized arrays with the expected layout
        if ann_info is None:
            ann_info = dict()
            ann_info['gt_bboxes_3d'] = np.zeros((0, 6), dtype=np.float32)
            ann_info['gt_labels_3d'] = np.zeros((0, ), dtype=np.int64)
        # to target box structure; S3DIS boxes are axis-aligned (no yaw) and
        # stored with a gravity-centered origin
        ann_info['gt_bboxes_3d'] = DepthInstance3DBoxes(
            ann_info['gt_bboxes_3d'],
            box_dim=ann_info['gt_bboxes_3d'].shape[-1],
            with_yaw=False,
            origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)

        return ann_info
147
148


ZCMax's avatar
ZCMax committed
149
class _S3DISSegDataset(Seg3DDataset):
    r"""S3DIS Dataset for Semantic Segmentation Task.

    This class is the inner dataset for S3DIS. Since S3DIS has 6 areas, we
    often train on 5 of them and test on the remaining one.
    However, there is not a fixed train-test split of S3DIS. People often test
    on Area_5 as suggested by `SEGCloud <https://arxiv.org/abs/1710.07563>`_.
    But many papers also report the average results of 6-fold cross validation
    over the 6 areas (e.g. `DGCNN <https://arxiv.org/abs/1801.07829>`_).
    Therefore, we use an inner dataset for one area, and further use a dataset
    wrapper to concat all the provided data in different areas.

    Args:
        data_root (str, optional): Path of dataset root, Defaults to None.
        ann_file (str): Path of annotation file. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_prefix (dict): Prefix for training data. Defaults to
            dict(pts='points', pts_instance_mask='', pts_semantic_mask='').
        pipeline (List[dict]): Pipeline used for data processing.
            Defaults to [].
        modality (dict): Modality to specify the sensor data used as input.
            Defaults to dict(use_lidar=True, use_camera=False).
        ignore_index (int, optional): The label index to be ignored, e.g.
            unannotated points. If None is given, set to len(self.classes) to
            be consistent with PointSegClassMapping function in pipeline.
            Defaults to None.
        scene_idxs (np.ndarray or str, optional): Precomputed index to load
            data. For scenes with many points, we may sample it several times.
            Defaults to None.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
    """
    METAINFO = {
        'classes':
        ('ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
         'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter'),
        'palette': [[0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 255, 0],
                    [255, 0, 255], [100, 100, 255], [200, 200, 100],
                    [170, 120, 200], [255, 0, 0], [200, 100, 100],
                    [10, 200, 100], [200, 200, 200], [50, 50, 50]],
        'seg_valid_class_ids':
        tuple(range(13)),
        'seg_all_class_ids':
        tuple(range(14))  # possibly with 'stair' class
    }

    def __init__(self,
                 data_root: Optional[str] = None,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_prefix: dict = dict(
                     pts='points', pts_instance_mask='', pts_semantic_mask=''),
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_lidar=True, use_camera=False),
                 ignore_index: Optional[int] = None,
                 scene_idxs: Optional[Union[np.ndarray, str]] = None,
                 test_mode: bool = False,
                 **kwargs) -> None:
        # everything is handled by the segmentation base class; this inner
        # dataset only narrows the defaults and the METAINFO above
        super().__init__(
            ann_file=ann_file,
            data_root=data_root,
            data_prefix=data_prefix,
            metainfo=metainfo,
            modality=modality,
            pipeline=pipeline,
            ignore_index=ignore_index,
            scene_idxs=scene_idxs,
            test_mode=test_mode,
            **kwargs)

    def get_scene_idxs(self, scene_idxs: Union[np.ndarray, str,
                                               None]) -> np.ndarray:
        """Compute scene_idxs for data sampling.

        We sample more times for scenes with more points.
        """
        # training requires a precomputed re-sampling of scenes; only at test
        # time do we load each whole scene exactly once
        if scene_idxs is None and not self.test_mode:
            raise NotImplementedError(
                'please provide re-sampled scene indexes for training')

        return super().get_scene_idxs(scene_idxs)
232
233
234
235
236
237
238
239
240


@DATASETS.register_module()
class S3DISSegDataset(_S3DISSegDataset):
    r"""S3DIS Dataset for Semantic Segmentation Task.

    This class serves as the API for experiments on the S3DIS Dataset.
    It wraps the provided datasets of different areas.
    We don't use `mmdet.datasets.dataset_wrappers.ConcatDataset` because we
    need to concat the `scene_idxs` of different areas.

    Please refer to the `google form <https://docs.google.com/forms/d/e/1FAIpQL
    ScDimvNMCGhy_rmBA2gHfDu3naktRm6A8BPwAWWDv-Uhm6Shw/viewform?c=0&w=1>`_ for
    data downloading.

    Args:
        data_root (str, optional): Path of dataset root. Defaults to None.
        ann_files (str or List[str]): Path of several annotation files.
            A single str will be wrapped into a one-element list.
            Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_prefix (dict): Prefix for training data. Defaults to
            dict(pts='points', pts_instance_mask='', pts_semantic_mask='').
        pipeline (List[dict]): Pipeline used for data processing.
            Defaults to [].
        modality (dict): Modality to specify the sensor data used as input.
            Defaults to dict(use_lidar=True, use_camera=False).
        ignore_index (int, optional): The label index to be ignored, e.g.
            unannotated points. If None is given, set to len(self.classes) to
            be consistent with PointSegClassMapping function in pipeline.
            Defaults to None.
        scene_idxs (List[np.ndarray] | List[str], optional): Precomputed index
            to load data. For scenes with many points, we may sample it
            several times. Defaults to None.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
    """

    def __init__(self,
                 data_root: Optional[str] = None,
                 ann_files: Union[str, List[str]] = '',
                 metainfo: Optional[dict] = None,
                 data_prefix: dict = dict(
                     pts='points', pts_instance_mask='', pts_semantic_mask=''),
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_lidar=True, use_camera=False),
                 ignore_index: Optional[int] = None,
                 scene_idxs: Optional[Union[List[np.ndarray],
                                            List[str]]] = None,
                 test_mode: bool = False,
                 **kwargs) -> None:

        # make sure that ann_files and scene_idxs have same length
        ann_files = self._check_ann_files(ann_files)
        scene_idxs = self._check_scene_idxs(scene_idxs, len(ann_files))

        # initialize some attributes as datasets[0]
        super().__init__(
            data_root=data_root,
            ann_file=ann_files[0],
            metainfo=metainfo,
            data_prefix=data_prefix,
            pipeline=pipeline,
            modality=modality,
            ignore_index=ignore_index,
            scene_idxs=scene_idxs[0],
            test_mode=test_mode,
            **kwargs)

        # build one inner dataset per area so their annotations can be merged
        datasets = [
            _S3DISSegDataset(
                data_root=data_root,
                ann_file=ann_files[i],
                metainfo=metainfo,
                data_prefix=data_prefix,
                pipeline=pipeline,
                modality=modality,
                ignore_index=ignore_index,
                scene_idxs=scene_idxs[i],
                test_mode=test_mode,
                **kwargs) for i in range(len(ann_files))
        ]

        # data_list and scene_idxs need to be concat
        self.concat_data_list([dst.data_list for dst in datasets])

        # set group flag for the sampler
        if not self.test_mode:
            self._set_group_flag()

    def concat_data_list(self, data_lists: List[List[dict]]) -> None:
        """Concat data_list from several datasets to form self.data_list.

        Args:
            data_lists (List[List[dict]]): List of dict containing
                annotation information.
        """
        self.data_list = [
            data for data_list in data_lists for data in data_list
        ]

    @staticmethod
    def _duplicate_to_list(x: Any, num: int) -> list:
        """Repeat x `num` times to form a list."""
        return [x for _ in range(num)]

    def _check_ann_files(
            self, ann_file: Union[List[str], Tuple[str], str]) -> List[str]:
        """Make ann_files as list/tuple.

        Args:
            ann_file (List[str] | Tuple[str] | str): One or several
                annotation file paths.

        Returns:
            List[str]: The input wrapped in a list if it was a bare str.
        """
        # ann_file could be str
        if not isinstance(ann_file, (list, tuple)):
            ann_file = self._duplicate_to_list(ann_file, 1)
        return ann_file

    def _check_scene_idxs(self, scene_idx: Union[str, List[Union[list, tuple,
                                                                 np.ndarray]],
                                                 List[str], None],
                          num: int) -> List[np.ndarray]:
        """Make scene_idxs as list/tuple.

        Args:
            scene_idx (str | List[list | tuple | np.ndarray] | List[str] |
                None): Scene index input in any supported form.
            num (int): Number of areas, i.e. required list length when a
                single value needs to be duplicated.

        Returns:
            List: One scene index entry per annotation file.
        """
        if scene_idx is None:
            return self._duplicate_to_list(scene_idx, num)
        # scene_idx could be str, np.ndarray, list or tuple
        if isinstance(scene_idx, str):  # str
            return self._duplicate_to_list(scene_idx, num)
        if isinstance(scene_idx[0], str):  # list of str
            return scene_idx
        if isinstance(scene_idx[0], (list, tuple, np.ndarray)):  # list of idx
            return scene_idx
        # single idx
        return self._duplicate_to_list(scene_idx, num)