scannet_dataset.py 13.3 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import warnings
zhangwenwei's avatar
zhangwenwei committed
3
from os import path as osp
ZCMax's avatar
ZCMax committed
4
from typing import Callable, List, Optional, Union
5

6
7
import numpy as np

8
from mmdet3d.registry import DATASETS
zhangshilong's avatar
zhangshilong committed
9
from mmdet3d.structures import DepthInstance3DBoxes
jshilong's avatar
jshilong committed
10
from .det3d_dataset import Det3DDataset
ZCMax's avatar
ZCMax committed
11
from .seg3d_dataset import Seg3DDataset
12
13
14


@DATASETS.register_module()
class ScanNetDataset(Det3DDataset):
    r"""ScanNet Dataset for Detection Task.

    This class serves as the API for experiments on the ScanNet Dataset.

    Please refer to the `github repo <https://github.com/ScanNet/ScanNet>`_
    for data downloading.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_prefix (dict): Prefix for data. Defaults to
            dict(pts='points',
                 pts_instance_mask='instance_mask',
                 pts_semantic_mask='semantic_mask').
        pipeline (List[dict]): Pipeline used for data processing.
            Defaults to [].
        modality (dict): Modality to specify the sensor data used as input.
            Defaults to dict(use_camera=False, use_lidar=True).
        box_type_3d (str): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            to its original format then converted them to `box_type_3d`.
            Defaults to 'Depth' in this dataset. Available options includes:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
            - 'Camera': Box in camera coordinates.
        filter_empty_gt (bool): Whether to filter the data with empty GT.
            If it's set to be True, the example with empty annotations after
            data pipeline will be dropped and a random example will be chosen
            in `__getitem__`. Defaults to True.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
    """
    METAINFO = {
        'classes':
        ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
         'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
         'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin'),
        # the valid ids of segmentation annotations
        'seg_valid_class_ids':
        (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39),
        'seg_all_class_ids':
        tuple(range(1, 41)),
        'palette': [(31, 119, 180), (255, 187, 120), (188, 189, 34),
                    (140, 86, 75), (255, 152, 150), (214, 39, 40),
                    (197, 176, 213), (148, 103, 189), (196, 156, 148),
                    (23, 190, 207), (247, 182, 210), (219, 219, 141),
                    (255, 127, 14), (158, 218, 229), (44, 160, 44),
                    (112, 128, 144), (227, 119, 194), (82, 84, 163)]
    }

    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 data_prefix: dict = dict(
                     pts='points',
                     pts_instance_mask='instance_mask',
                     pts_semantic_mask='semantic_mask'),
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_camera=False, use_lidar=True),
                 box_type_3d: str = 'Depth',
                 filter_empty_gt: bool = True,
                 test_mode: bool = False,
                 **kwargs) -> None:

        # Construct seg_label_mapping for the semantic mask: a lookup table
        # from raw annotation category id to contiguous training label.
        # Use the largest raw id (not the tuple length) to size the table,
        # so the mapping stays valid even if `seg_all_class_ids` were ever
        # non-contiguous. For `range(1, 41)` both values coincide (40).
        seg_max_cat_id = max(self.METAINFO['seg_all_class_ids'])
        seg_valid_cat_ids = self.METAINFO['seg_valid_class_ids']
        # Ids not in `seg_valid_class_ids` all map to the negative label.
        neg_label = len(seg_valid_cat_ids)
        seg_label_mapping = np.ones(
            seg_max_cat_id + 1, dtype=np.int64) * neg_label
        for cls_idx, cat_id in enumerate(seg_valid_cat_ids):
            seg_label_mapping[cat_id] = cls_idx
        self.seg_label_mapping = seg_label_mapping

        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            metainfo=metainfo,
            data_prefix=data_prefix,
            pipeline=pipeline,
            modality=modality,
            box_type_3d=box_type_3d,
            filter_empty_gt=filter_empty_gt,
            test_mode=test_mode,
            **kwargs)

        # Expose the mapping through metainfo (after super().__init__ so the
        # merged metainfo dict exists).
        self.metainfo['seg_label_mapping'] = self.seg_label_mapping
        assert 'use_camera' in self.modality and \
               'use_lidar' in self.modality
        assert self.modality['use_camera'] or self.modality['use_lidar']

    @staticmethod
    def _get_axis_align_matrix(info: dict) -> np.ndarray:
        """Get axis_align_matrix from info. If not exist, return identity mat.

        Args:
            info (dict): Info of a single sample data.

        Returns:
            np.ndarray: 4x4 transformation matrix.
        """
        if 'axis_align_matrix' in info:
            return np.array(info['axis_align_matrix'])
        else:
            warnings.warn(
                'axis_align_matrix is not found in ScanNet data info, please '
                'use new pre-process scripts to re-generate ScanNet data')
            return np.eye(4).astype(np.float32)

    def parse_data_info(self, info: dict) -> dict:
        """Process the raw data info.

        The only difference with it in `Det3DDataset`
        is the specific process for `axis_align_matrix'.

        Args:
            info (dict): Raw info dict.

        Returns:
            dict: Has `ann_info` in training stage. And
            all path has been converted to absolute path.
        """
        info['axis_align_matrix'] = self._get_axis_align_matrix(info)
        # Prepend the configured prefixes to the (relative) mask paths.
        info['pts_instance_mask_path'] = osp.join(
            self.data_prefix.get('pts_instance_mask', ''),
            info['pts_instance_mask_path'])
        info['pts_semantic_mask_path'] = osp.join(
            self.data_prefix.get('pts_semantic_mask', ''),
            info['pts_semantic_mask_path'])

        info = super().parse_data_info(info)
        # only be used in `PointSegClassMapping` in pipeline
        # to map original semantic class to valid category ids.
        info['seg_label_mapping'] = self.seg_label_mapping
        return info

    def parse_ann_info(self, info: dict) -> dict:
        """Process the `instances` in data info to `ann_info`.

        Args:
            info (dict): Info dict.

        Returns:
            dict: Processed `ann_info`.
        """
        ann_info = super().parse_ann_info(info)
        # empty gt: fall back to zero-length boxes/labels so downstream
        # code always sees the expected keys and dtypes.
        if ann_info is None:
            ann_info = dict()
            ann_info['gt_bboxes_3d'] = np.zeros((0, 6), dtype=np.float32)
            ann_info['gt_labels_3d'] = np.zeros((0, ), dtype=np.int64)
        # to target box structure; ScanNet boxes are axis-aligned
        # (no yaw) with a gravity-centered origin.
        ann_info['gt_bboxes_3d'] = DepthInstance3DBoxes(
            ann_info['gt_bboxes_3d'],
            box_dim=ann_info['gt_bboxes_3d'].shape[-1],
            with_yaw=False,
            origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)

        return ann_info
180

181
182

@DATASETS.register_module()
class ScanNetSegDataset(Seg3DDataset):
    r"""ScanNet Dataset for Semantic Segmentation Task.

    This class serves as the API for experiments on the ScanNet Dataset.

    Please refer to the `github repo <https://github.com/ScanNet/ScanNet>`_
    for data downloading.

    Args:
        data_root (str, optional): Path of dataset root. Defaults to None.
        ann_file (str): Path of annotation file. Defaults to ''.
        pipeline (List[dict]): Pipeline used for data processing.
            Defaults to [].
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_prefix (dict): Prefix for training data. Defaults to
            dict(pts='points',
                 img='',
                 pts_instance_mask='',
                 pts_semantic_mask='').
        modality (dict): Modality to specify the sensor data used as input.
            Defaults to dict(use_lidar=True, use_camera=False).
        ignore_index (int, optional): The label index to be ignored, e.g.
            unannotated points. If None is given, set to len(self.classes) to
            be consistent with PointSegClassMapping function in pipeline.
            Defaults to None.
        scene_idxs (np.ndarray or str, optional): Precomputed index to load
            data. For scenes with many points, we may sample it several times.
            Defaults to None.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
    """
    METAINFO = {
        'classes':
        ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door',
         'window', 'bookshelf', 'picture', 'counter', 'desk', 'curtain',
         'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
         'otherfurniture'),
        'palette': [
            [174, 199, 232],
            [152, 223, 138],
            [31, 119, 180],
            [255, 187, 120],
            [188, 189, 34],
            [140, 86, 75],
            [255, 152, 150],
            [214, 39, 40],
            [197, 176, 213],
            [148, 103, 189],
            [196, 156, 148],
            [23, 190, 207],
            [247, 182, 210],
            [219, 219, 141],
            [255, 127, 14],
            [158, 218, 229],
            [44, 160, 44],
            [112, 128, 144],
            [227, 119, 194],
            [82, 84, 163],
        ],
        'seg_valid_class_ids': (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16,
                                24, 28, 33, 34, 36, 39),
        'seg_all_class_ids':
        tuple(range(41)),
    }

    def __init__(self,
                 data_root: Optional[str] = None,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_prefix: dict = dict(
                     pts='points',
                     img='',
                     pts_instance_mask='',
                     pts_semantic_mask=''),
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_lidar=True, use_camera=False),
                 ignore_index: Optional[int] = None,
                 scene_idxs: Optional[Union[np.ndarray, str]] = None,
                 test_mode: bool = False,
                 **kwargs) -> None:
        # Everything is handled by the generic Seg3DDataset; this subclass
        # only supplies ScanNet-specific METAINFO and defaults.
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            metainfo=metainfo,
            data_prefix=data_prefix,
            pipeline=pipeline,
            modality=modality,
            ignore_index=ignore_index,
            scene_idxs=scene_idxs,
            test_mode=test_mode,
            **kwargs)

    def get_scene_idxs(self, scene_idxs: Union[np.ndarray, str,
                                               None]) -> np.ndarray:
        """Compute scene_idxs for data sampling.

        We sample more times for scenes with more points.
        """
        if scene_idxs is None and not self.test_mode:
            # Training needs a precomputed resampling index; when testing we
            # simply load one whole scene every time, so None is acceptable.
            raise NotImplementedError(
                'please provide re-sampled scene indexes for training')

        return super().get_scene_idxs(scene_idxs)
288

289
290

@DATASETS.register_module()
class ScanNetInstanceSegDataset(Seg3DDataset):
    """ScanNet dataset for the instance segmentation task.

    A thin wrapper around `Seg3DDataset` that only provides the
    ScanNet-specific class names, palette and valid segmentation ids;
    all loading logic lives in the base class.
    """
    METAINFO = {
        'classes':
        ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
         'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
         'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin'),
        'palette': [
            [174, 199, 232],
            [152, 223, 138],
            [31, 119, 180],
            [255, 187, 120],
            [188, 189, 34],
            [140, 86, 75],
            [255, 152, 150],
            [214, 39, 40],
            [197, 176, 213],
            [148, 103, 189],
            [196, 156, 148],
            [23, 190, 207],
            [247, 182, 210],
            [219, 219, 141],
            [255, 127, 14],
            [158, 218, 229],
            [44, 160, 44],
            [112, 128, 144],
            [227, 119, 194],
            [82, 84, 163],
        ],
        'seg_valid_class_ids':
        (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39),
        'seg_all_class_ids':
        tuple(range(41))
    }

    def __init__(self,
                 data_root: Optional[str] = None,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_prefix: dict = dict(
                     pts='points',
                     img='',
                     pts_instance_mask='',
                     pts_semantic_mask=''),
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_lidar=True, use_camera=False),
                 test_mode: bool = False,
                 ignore_index: Optional[int] = None,
                 scene_idxs: Optional[Union[np.ndarray, str]] = None,
                 backend_args: Optional[dict] = None,
                 **kwargs) -> None:
        # Pure pass-through constructor: forward every argument to the base.
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            metainfo=metainfo,
            pipeline=pipeline,
            data_prefix=data_prefix,
            modality=modality,
            test_mode=test_mode,
            ignore_index=ignore_index,
            scene_idxs=scene_idxs,
            backend_args=backend_args,
            **kwargs)