scannet_dataset.py 12.7 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import warnings
zhangwenwei's avatar
zhangwenwei committed
3
from os import path as osp
ZCMax's avatar
ZCMax committed
4
from typing import Callable, List, Optional, Union
5

6
7
import numpy as np

8
from mmdet3d.registry import DATASETS
zhangshilong's avatar
zhangshilong committed
9
from mmdet3d.structures import DepthInstance3DBoxes
jshilong's avatar
jshilong committed
10
from .det3d_dataset import Det3DDataset
ZCMax's avatar
ZCMax committed
11
from .seg3d_dataset import Seg3DDataset
12
13
14


@DATASETS.register_module()
class ScanNetDataset(Det3DDataset):
    r"""ScanNet Dataset for Detection Task.

    This class serves as the API for experiments on the ScanNet Dataset.

    Please refer to the `github repo <https://github.com/ScanNet/ScanNet>`_
    for data downloading.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_prefix (dict): Prefix for data. Defaults to
            dict(pts='points',
                 pts_instance_mask='instance_mask',
                 pts_semantic_mask='semantic_mask').
        pipeline (list[dict]): Pipeline used for data processing.
            Defaults to [].
        modality (dict): Modality to specify the sensor data used as input.
            Defaults to dict(use_camera=False, use_lidar=True).
        box_type_3d (str): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            to its original format then converted them to `box_type_3d`.
            Defaults to 'Depth' in this dataset. Available options includes:

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
            - 'Camera': Box in camera coordinates.
        filter_empty_gt (bool): Whether to filter the data with empty GT.
            If it's set to be True, the example with empty annotations after
            data pipeline will be dropped and a random example will be chosen
            in `__getitem__`. Defaults to True.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
    """
    METAINFO = {
        'CLASSES':
        ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
         'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
         'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin'),
        # the valid ids of segmentation annotations
        'seg_valid_class_ids':
        (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39),
        'seg_all_class_ids':
        tuple(range(1, 41))
    }

    def __init__(self,
                 data_root: str,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 data_prefix: dict = dict(
                     pts='points',
                     pts_instance_mask='instance_mask',
                     pts_semantic_mask='semantic_mask'),
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_camera=False, use_lidar=True),
                 box_type_3d: str = 'Depth',
                 filter_empty_gt: bool = True,
                 test_mode: bool = False,
                 **kwargs) -> None:

        # construct seg_label_mapping for semantic mask:
        # raw semantic ids span `seg_all_class_ids`; each valid id maps to
        # its index in `seg_valid_class_ids`, every other id maps to
        # `neg_label` (= number of valid classes).
        seg_max_cat_id = len(self.METAINFO['seg_all_class_ids'])
        seg_valid_cat_ids = self.METAINFO['seg_valid_class_ids']
        neg_label = len(seg_valid_cat_ids)
        # `np.int` was deprecated in NumPy 1.20 and removed in 1.24;
        # use the explicit `np.int64` instead.
        seg_label_mapping = np.ones(
            seg_max_cat_id + 1, dtype=np.int64) * neg_label
        for cls_idx, cat_id in enumerate(seg_valid_cat_ids):
            seg_label_mapping[cat_id] = cls_idx
        self.seg_label_mapping = seg_label_mapping

        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            metainfo=metainfo,
            data_prefix=data_prefix,
            pipeline=pipeline,
            modality=modality,
            box_type_3d=box_type_3d,
            filter_empty_gt=filter_empty_gt,
            test_mode=test_mode,
            **kwargs)

        # expose the mapping through metainfo so downstream consumers
        # (e.g. evaluators) can access it alongside class names
        self.metainfo['seg_label_mapping'] = self.seg_label_mapping
        assert 'use_camera' in self.modality and \
               'use_lidar' in self.modality
        assert self.modality['use_camera'] or self.modality['use_lidar']

    @staticmethod
    def _get_axis_align_matrix(info: dict) -> np.ndarray:
        """Get axis_align_matrix from info. If not exist, return identity mat.

        Args:
            info (dict): Info of a single sample data.

        Returns:
            np.ndarray: 4x4 transformation matrix.
        """
        if 'axis_align_matrix' in info:
            return np.array(info['axis_align_matrix'])
        else:
            # old pre-processed infos may lack the matrix; warn and fall
            # back to identity so loading still works
            warnings.warn(
                'axis_align_matrix is not found in ScanNet data info, please '
                'use new pre-process scripts to re-generate ScanNet data')
            return np.eye(4).astype(np.float32)

    def parse_data_info(self, info: dict) -> dict:
        """Process the raw data info.

        The only difference with it in `Det3DDataset`
        is the specific process for `axis_align_matrix'.

        Args:
            info (dict): Raw info dict.

        Returns:
            dict: Has `ann_info` in training stage. And
            all path has been converted to absolute path.
        """
        info['axis_align_matrix'] = self._get_axis_align_matrix(info)
        # prepend the configured prefixes before the base class resolves
        # the remaining paths
        info['pts_instance_mask_path'] = osp.join(
            self.data_prefix.get('pts_instance_mask', ''),
            info['pts_instance_mask_path'])
        info['pts_semantic_mask_path'] = osp.join(
            self.data_prefix.get('pts_semantic_mask', ''),
            info['pts_semantic_mask_path'])

        info = super().parse_data_info(info)
        # only be used in `PointSegClassMapping` in pipeline
        # to map original semantic class to valid category ids.
        info['seg_label_mapping'] = self.seg_label_mapping
        return info

    def parse_ann_info(self, info: dict) -> dict:
        """Process the `instances` in data info to `ann_info`.

        Args:
            info (dict): Info dict.

        Returns:
            dict: Processed `ann_info`.
        """
        ann_info = super().parse_ann_info(info)
        # empty gt: fabricate empty arrays so the box wrapper below and the
        # pipeline always see consistent keys
        if ann_info is None:
            ann_info = dict()
            ann_info['gt_bboxes_3d'] = np.zeros((0, 6), dtype=np.float32)
            ann_info['gt_labels_3d'] = np.zeros((0, ), dtype=np.int64)
        # to target box structure; ScanNet boxes are axis-aligned
        # (with_yaw=False) with gravity center origin (0.5, 0.5, 0.5)
        ann_info['gt_bboxes_3d'] = DepthInstance3DBoxes(
            ann_info['gt_bboxes_3d'],
            box_dim=ann_info['gt_bboxes_3d'].shape[-1],
            with_yaw=False,
            origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)

        return ann_info
174

175
176

@DATASETS.register_module()
class ScanNetSegDataset(Seg3DDataset):
    r"""ScanNet Dataset for Semantic Segmentation Task.

    This class serves as the API for experiments on the ScanNet Dataset.

    Please refer to the `github repo <https://github.com/ScanNet/ScanNet>`_
    for data downloading.

    Args:
        data_root (str, optional): Path of dataset root. Defaults to None.
        ann_file (str): Path of annotation file. Defaults to ''.
        pipeline (list[dict]): Pipeline used for data processing.
            Defaults to [].
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_prefix (dict): Prefix for training data. Defaults to
            dict(pts='points', img='', instance_mask='', semantic_mask='').
        modality (dict): Modality to specify the sensor data used as input.
            Defaults to dict(use_lidar=True, use_camera=False).
        ignore_index (int, optional): The label index to be ignored, e.g.
            unannotated points. If None is given, set to len(self.CLASSES) to
            be consistent with PointSegClassMapping function in pipeline.
            Defaults to None.
        scene_idxs (np.ndarray | str, optional): Precomputed index to load
            data. For scenes with many points, we may sample it several times.
            Defaults to None.
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
    """
    METAINFO = {
        'CLASSES':
        ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door',
         'window', 'bookshelf', 'picture', 'counter', 'desk', 'curtain',
         'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
         'otherfurniture'),
        'PALETTE': [
            [174, 199, 232],
            [152, 223, 138],
            [31, 119, 180],
            [255, 187, 120],
            [188, 189, 34],
            [140, 86, 75],
            [255, 152, 150],
            [214, 39, 40],
            [197, 176, 213],
            [148, 103, 189],
            [196, 156, 148],
            [23, 190, 207],
            [247, 182, 210],
            [219, 219, 141],
            [255, 127, 14],
            [158, 218, 229],
            [44, 160, 44],
            [112, 128, 144],
            [227, 119, 194],
            [82, 84, 163],
        ],
        'seg_valid_class_ids': (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16,
                                24, 28, 33, 34, 36, 39),
        'seg_all_class_ids':
        tuple(range(41)),
    }

    def __init__(self,
                 data_root: Optional[str] = None,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_prefix: dict = dict(
                     pts='points', img='', instance_mask='', semantic_mask=''),
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_lidar=True, use_camera=False),
                 ignore_index: Optional[int] = None,
                 scene_idxs: Optional[Union[np.ndarray, str]] = None,
                 test_mode: bool = False,
                 **kwargs) -> None:
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            metainfo=metainfo,
            data_prefix=data_prefix,
            pipeline=pipeline,
            modality=modality,
            ignore_index=ignore_index,
            scene_idxs=scene_idxs,
            test_mode=test_mode,
            **kwargs)

    def get_scene_idxs(self, scene_idxs):
        """Compute scene_idxs for data sampling.

        We sample more times for scenes with more points.

        Args:
            scene_idxs (np.ndarray | str, optional): Precomputed scene
                indexes, or a path to load them from. Required (non-None)
                during training.

        Returns:
            np.ndarray: Scene indexes used for sampling.
        """
        # when testing, we load one whole scene every time
        if not self.test_mode and scene_idxs is None:
            raise NotImplementedError(
                'please provide re-sampled scene indexes for training')

        return super().get_scene_idxs(scene_idxs)
275

276
277

@DATASETS.register_module()
class ScanNetInstanceSegDataset(Seg3DDataset):
    r"""ScanNet Dataset for Instance Segmentation Task.

    This class serves as the API for experiments on the ScanNet Dataset.

    Please refer to the `github repo <https://github.com/ScanNet/ScanNet>`_
    for data downloading.

    Args:
        data_root (str, optional): Path of dataset root. Defaults to None.
        ann_file (str): Path of annotation file. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_prefix (dict): Prefix for training data. Defaults to
            dict(pts='points', img='', instance_mask='', semantic_mask='').
        pipeline (list[dict]): Pipeline used for data processing.
            Defaults to [].
        modality (dict): Modality to specify the sensor data used as input.
            Defaults to dict(use_lidar=True, use_camera=False).
        test_mode (bool): Whether the dataset is in test mode.
            Defaults to False.
        ignore_index (int, optional): The label index to be ignored, e.g.
            unannotated points. If None is given, set to len(self.CLASSES) to
            be consistent with PointSegClassMapping function in pipeline.
            Defaults to None.
        scene_idxs (np.ndarray | str, optional): Precomputed index to load
            data. For scenes with many points, we may sample it several times.
            Defaults to None.
        file_client_args (dict): Configuration of file client.
            Defaults to dict(backend='disk').
    """
    METAINFO = {
        'CLASSES':
        ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
         'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
         'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin'),
        # fixed key typo: was 'PLATTE', which no consumer would look up;
        # 'PALETTE' matches the convention used by ScanNetSegDataset.
        # NOTE(review): palette lists 20 colors for 18 classes — the tail
        # entries appear unused; confirm against visualization code.
        'PALETTE': [
            [174, 199, 232],
            [152, 223, 138],
            [31, 119, 180],
            [255, 187, 120],
            [188, 189, 34],
            [140, 86, 75],
            [255, 152, 150],
            [214, 39, 40],
            [197, 176, 213],
            [148, 103, 189],
            [196, 156, 148],
            [23, 190, 207],
            [247, 182, 210],
            [219, 219, 141],
            [255, 127, 14],
            [158, 218, 229],
            [44, 160, 44],
            [112, 128, 144],
            [227, 119, 194],
            [82, 84, 163],
        ],
        'seg_valid_class_ids':
        (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39),
        'seg_all_class_ids':
        tuple(range(41))
    }

    def __init__(self,
                 data_root: Optional[str] = None,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_prefix: dict = dict(
                     pts='points', img='', instance_mask='', semantic_mask=''),
                 pipeline: List[Union[dict, Callable]] = [],
                 modality: dict = dict(use_lidar=True, use_camera=False),
                 test_mode: bool = False,
                 ignore_index: Optional[int] = None,
                 scene_idxs: Optional[Union[np.ndarray, str]] = None,
                 file_client_args: dict = dict(backend='disk'),
                 **kwargs) -> None:
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            metainfo=metainfo,
            pipeline=pipeline,
            data_prefix=data_prefix,
            modality=modality,
            test_mode=test_mode,
            ignore_index=ignore_index,
            scene_idxs=scene_idxs,
            file_client_args=file_client_args,
            **kwargs)