custom_3d.py 15.9 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
zhangwenwei's avatar
zhangwenwei committed
2
import tempfile
3
import warnings
zhangwenwei's avatar
zhangwenwei committed
4
from os import path as osp
5
6
7

import mmcv
import numpy as np
zhangwenwei's avatar
zhangwenwei committed
8
from torch.utils.data import Dataset
9

wuyuefeng's avatar
Demo  
wuyuefeng committed
10
from ..core.bbox import get_box_type
11
from .builder import DATASETS
12
from .pipelines import Compose
13
from .utils import extract_result_dict, get_loading_pipeline
14
15
16


@DATASETS.register_module()
class Custom3DDataset(Dataset):
    """Customized 3D dataset.

    This is the base dataset of SUNRGB-D, ScanNet, nuScenes, and KITTI
    dataset.

    .. code-block:: none

    [
        {'sample_idx':
         'lidar_points': {'lidar_path': velodyne_path,
                           ....
                         },
         'annos': {'box_type_3d':  (str)  'LiDAR/Camera/Depth'
                   'gt_bboxes_3d':  <np.ndarray> (n, 7)
                   'gt_names':  [list]
                   ....
               }
         'calib': { .....}
         'images': { .....}
        }
    ]

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to None.
        classes (tuple[str], optional): Classes used in the dataset.
            Defaults to None.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to None.
        box_type_3d (str, optional): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            to its original format then converted them to `box_type_3d`.
            Defaults to 'LiDAR'. Available options includes

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
            - 'Camera': Box in camera coordinates.
        filter_empty_gt (bool, optional): Whether to filter empty GT.
            Defaults to True.
        test_mode (bool, optional): Whether the dataset is in test mode.
            Defaults to False.
        file_client_args (dict, optional): Arguments forwarded to
            ``mmcv.FileClient``. Defaults to ``dict(backend='disk')``.
    """

    def __init__(self,
                 data_root,
                 ann_file,
                 pipeline=None,
                 classes=None,
                 modality=None,
                 box_type_3d='LiDAR',
                 filter_empty_gt=True,
                 test_mode=False,
                 file_client_args=dict(backend='disk')):
        super().__init__()
        self.data_root = data_root
        self.ann_file = ann_file
        self.test_mode = test_mode
        self.modality = modality
        self.filter_empty_gt = filter_empty_gt
        self.box_type_3d, self.box_mode_3d = get_box_type(box_type_3d)

        self.CLASSES = self.get_classes(classes)
        self.file_client = mmcv.FileClient(**file_client_args)
        self.cat2id = {name: i for i, name in enumerate(self.CLASSES)}

        # load annotations
        if hasattr(self.file_client, 'get_local_path'):
            with self.file_client.get_local_path(self.ann_file) as local_path:
                # FIX: the file was previously opened without ever being
                # closed; use a context manager so the handle is released
                # even if load_annotations raises.
                with open(local_path, 'rb') as f:
                    self.data_infos = self.load_annotations(f)
        else:
            warnings.warn(
                'The used MMCV version does not have get_local_path. '
                f'We treat the {self.ann_file} as local paths and it '
                'might cause errors if the path is not a local path. '
                'Please use MMCV>= 1.3.16 if you meet errors.')
            self.data_infos = self.load_annotations(self.ann_file)

        # process pipeline
        if pipeline is not None:
            self.pipeline = Compose(pipeline)

        # set group flag for the samplers
        if not self.test_mode:
            self._set_group_flag()

    def load_annotations(self, ann_file):
        """Load annotations from ann_file.

        Args:
            ann_file (str): Path of the annotation file (or an opened
                file-like object).

        Returns:
            list[dict]: List of annotations.
        """
        # a file-like object carries no filename extension, so the pkl
        # format has to be stated explicitly
        annotations = mmcv.load(ann_file, file_format='pkl')
        return annotations
116
117

    def get_data_info(self, index):
118
119
120
121
122
123
        """Get data info according to the given index.

        Args:
            index (int): Index of the sample data to get.

        Returns:
124
            dict: Data information that will be passed to the data
zhangwenwei's avatar
zhangwenwei committed
125
                preprocessing pipelines. It includes the following keys:
126

wangtai's avatar
wangtai committed
127
128
129
130
                - sample_idx (str): Sample index.
                - pts_filename (str): Filename of point clouds.
                - file_name (str): Filename of point clouds.
                - ann_info (dict): Annotation info.
131
        """
132
        info = self.data_infos[index]
133
134
135
        sample_idx = info['sample_idx']
        pts_filename = osp.join(self.data_root,
                                info['lidar_points']['lidar_path'])
136

liyinhao's avatar
liyinhao committed
137
138
139
140
        input_dict = dict(
            pts_filename=pts_filename,
            sample_idx=sample_idx,
            file_name=pts_filename)
141

zhangwenwei's avatar
zhangwenwei committed
142
        if not self.test_mode:
liyinhao's avatar
liyinhao committed
143
            annos = self.get_ann_info(index)
zhangwenwei's avatar
zhangwenwei committed
144
            input_dict['ann_info'] = annos
145
            if self.filter_empty_gt and ~(annos['gt_labels_3d'] != -1).any():
zhangwenwei's avatar
zhangwenwei committed
146
                return None
147
148
        return input_dict

149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
    def get_ann_info(self, index):
        """Get annotation info according to the given index.

        Args:
            index (int): Index of the annotation data to get.

        Returns:
            dict: Annotation information consists of the following keys:

                - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`):
                    3D ground truth bboxes
                - gt_labels_3d (np.ndarray): Labels of ground truths.
                - gt_names (list[str]): Class names of ground truths.
        """
        annos = self.data_infos[index]['annos']
        gt_bboxes_3d = annos['gt_bboxes_3d']
        gt_names_3d = annos['gt_names']
        # map class names onto indices; unknown names are labelled -1
        gt_labels_3d = np.array([
            self.CLASSES.index(name) if name in self.CLASSES else -1
            for name in gt_names_3d
        ])

        # obtain the original box 3d type recorded in the info file
        ori_box_type_3d, _ = get_box_type(annos['box_type_3d'])

        # turn original box type to target box type
        gt_bboxes_3d = ori_box_type_3d(
            gt_bboxes_3d,
            box_dim=gt_bboxes_3d.shape[-1],
            origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)

        return dict(
            gt_bboxes_3d=gt_bboxes_3d,
            gt_labels_3d=gt_labels_3d,
            gt_names=gt_names_3d)

zhangwenwei's avatar
zhangwenwei committed
190
    def pre_pipeline(self, results):
191
192
193
        """Initialization before data preparation.

        Args:
194
            results (dict): Dict before data preprocessing.
195

wangtai's avatar
wangtai committed
196
197
198
199
200
201
202
203
204
                - img_fields (list): Image fields.
                - bbox3d_fields (list): 3D bounding boxes fields.
                - pts_mask_fields (list): Mask fields of points.
                - pts_seg_fields (list): Mask fields of point segments.
                - bbox_fields (list): Fields of bounding boxes.
                - mask_fields (list): Fields of masks.
                - seg_fields (list): Segment fields.
                - box_type_3d (str): 3D box type.
                - box_mode_3d (str): 3D box mode.
205
        """
zhangwenwei's avatar
zhangwenwei committed
206
        results['img_fields'] = []
zhangwenwei's avatar
zhangwenwei committed
207
208
209
        results['bbox3d_fields'] = []
        results['pts_mask_fields'] = []
        results['pts_seg_fields'] = []
zhangwenwei's avatar
zhangwenwei committed
210
211
212
        results['bbox_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []
213
214
        results['box_type_3d'] = self.box_type_3d
        results['box_mode_3d'] = self.box_mode_3d
215

liyinhao's avatar
liyinhao committed
216
    def prepare_train_data(self, index):
        """Training data preparation.

        Args:
            index (int): Index for accessing the target data.

        Returns:
            dict: Training data dict of the corresponding index.
        """
        input_dict = self.get_data_info(index)
        if input_dict is None:
            return None
        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)
        if self.filter_empty_gt:
            # discard samples that end up without any valid GT after the
            # pipeline has run (all labels equal to -1, or no example)
            no_valid_gt = example is None or \
                not (example['gt_labels_3d']._data != -1).any()
            if no_valid_gt:
                return None
        return example

236
    def prepare_test_data(self, index):
        """Prepare data for testing.

        Args:
            index (int): Index for accessing the target data.

        Returns:
            dict: Testing data dict of the corresponding index.
        """
        input_dict = self.get_data_info(index)
        self.pre_pipeline(input_dict)
        return self.pipeline(input_dict)
249

liyinhao's avatar
liyinhao committed
250
251
    @classmethod
    def get_classes(cls, classes=None):
252
253
        """Get class names of current dataset.

liyinhao's avatar
liyinhao committed
254
        Args:
255
            classes (Sequence[str] | str): If classes is None, use
liyinhao's avatar
liyinhao committed
256
257
258
259
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.
zhangwenwei's avatar
zhangwenwei committed
260
261

        Return:
wangtai's avatar
wangtai committed
262
            list[str]: A list of class names.
liyinhao's avatar
liyinhao committed
263
264
265
266
267
268
269
270
271
272
273
274
275
276
        """
        if classes is None:
            return cls.CLASSES

        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return class_names

liyinhao's avatar
liyinhao committed
277
278
279
280
    def format_results(self,
                       outputs,
                       pklfile_prefix=None,
                       submission_prefix=None):
        """Format the results to pkl file.

        Args:
            outputs (list[dict]): Testing results of the dataset.
            pklfile_prefix (str): The prefix of pkl files. It includes
                the file path and the prefix of filename, e.g., "a/b/prefix".
                If not specified, a temp file will be created. Default: None.
            submission_prefix (str, optional): Unused by the base class;
                kept for interface compatibility with subclasses.

        Returns:
            tuple: (outputs, tmp_dir), outputs is the detection results,
                tmp_dir is the temporal directory created for saving pkl
                files when ``pklfile_prefix`` is not specified, otherwise
                None.
        """
        if pklfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            pklfile_prefix = osp.join(tmp_dir.name, 'results')
        else:
            # FIX: `out` and `tmp_dir` used to be assigned only inside the
            # `pklfile_prefix is None` branch, so passing an explicit prefix
            # raised UnboundLocalError at `mmcv.dump(outputs, out)`.
            tmp_dir = None
        out = f'{pklfile_prefix}.pkl'
        mmcv.dump(outputs, out)
        return outputs, tmp_dir
300

liyinhao's avatar
liyinhao committed
301
302
303
304
305
306
    def evaluate(self,
                 results,
                 metric=None,
                 iou_thr=(0.25, 0.5),
                 logger=None,
                 show=False,
                 out_dir=None,
                 pipeline=None):
        """Evaluate.

        Evaluation in indoor protocol.

        Args:
            results (list[dict]): List of results.
            metric (str | list[str], optional): Metrics to be evaluated.
                Defaults to None.
            iou_thr (list[float]): AP IoU thresholds. Defaults to (0.25, 0.5).
            logger (logging.Logger | str, optional): Logger used for printing
                related information during evaluation. Defaults to None.
            show (bool, optional): Whether to visualize.
                Default: False.
            out_dir (str, optional): Path to save the visualization results.
                Default: None.
            pipeline (list[dict], optional): raw data loading for showing.
                Default: None.

        Returns:
            dict: Evaluation results.
        """
        from mmdet3d.core.evaluation import indoor_eval

        # sanity-check the structure of `results` before evaluating
        assert isinstance(
            results, list), f'Expect results to be list, got {type(results)}.'
        assert len(results) > 0, 'Expect length of results > 0.'
        assert len(results) == len(self.data_infos)
        assert isinstance(
            results[0], dict
        ), f'Expect elements in results to be dict, got {type(results[0])}.'

        ground_truths = [info['annos'] for info in self.data_infos]
        label2cat = dict(enumerate(self.CLASSES))
        eval_results = indoor_eval(
            ground_truths,
            results,
            iou_thr,
            label2cat,
            logger=logger,
            box_type_3d=self.box_type_3d,
            box_mode_3d=self.box_mode_3d)
        if show:
            self.show(results, out_dir, pipeline=pipeline)

        return eval_results
zhangwenwei's avatar
zhangwenwei committed
352

353
354
355
356
357
358
359
360
361
    def _build_default_pipeline(self):
        """Build the default pipeline for this dataset."""
        raise NotImplementedError('_build_default_pipeline is not implemented '
                                  f'for dataset {self.__class__.__name__}')

    def _get_pipeline(self, pipeline):
        """Get data loading pipeline in self.show/evaluate function.

        Args:
            pipeline (list[dict]): Input pipeline. If None is given,
                get from self.pipeline.
        """
        if pipeline is not None:
            return Compose(pipeline)
        # no explicit pipeline: fall back to this dataset's own one,
        # or the default pipeline when none was configured
        if getattr(self, 'pipeline', None) is None:
            warnings.warn(
                'Use default pipeline for data loading, this may cause '
                'errors when data is on ceph')
            return self._build_default_pipeline()
        return Compose(get_loading_pipeline(self.pipeline.transforms))

    def _extract_data(self, index, pipeline, key, load_annos=False):
        """Load data using input pipeline and extract data according to key.

        Args:
            index (int): Index for accessing the target data.
            pipeline (:obj:`Compose`): Composed data loading pipeline.
            key (str | list[str]): One single or a list of data key.
            load_annos (bool): Whether to load data annotations.
                If True, need to set self.test_mode as False before loading.

        Returns:
            np.ndarray | torch.Tensor | list[np.ndarray | torch.Tensor]:
                A single or a list of loaded data.
        """
        assert pipeline is not None, 'data loading pipeline is not provided'
        # ground truth (e.g. bbox, seg mask) is only attached when
        # test_mode is False, so switch it off while loading annotations
        if load_annos:
            original_test_mode = self.test_mode
            self.test_mode = False
        input_dict = self.get_data_info(index)
        self.pre_pipeline(input_dict)
        example = pipeline(input_dict)

        # a single string key yields one item, a list of keys yields a list
        single = isinstance(key, str)
        keys = [key] if single else key
        extracted = [extract_result_dict(example, k) for k in keys]
        data = extracted[0] if single else extracted

        if load_annos:
            self.test_mode = original_test_mode

        return data

zhangwenwei's avatar
zhangwenwei committed
409
    def __len__(self):
410
411
412
413
414
        """Return the length of data infos.

        Returns:
            int: Length of data infos.
        """
zhangwenwei's avatar
zhangwenwei committed
415
416
417
        return len(self.data_infos)

    def _rand_another(self, idx):
418
419
420
421
422
        """Randomly get another item with the same flag.

        Returns:
            int: Another index of item with the same flag.
        """
zhangwenwei's avatar
zhangwenwei committed
423
424
425
426
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def __getitem__(self, idx):
        """Get item from infos according to the given index.

        Returns:
            dict: Data dictionary of the corresponding index.
        """
        if self.test_mode:
            return self.prepare_test_data(idx)
        # in training mode, resample from the same group until a sample
        # with valid GT is obtained
        data = None
        while data is None:
            data = self.prepare_train_data(idx)
            if data is None:
                idx = self._rand_another(idx)
        return data

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
445
446
        otherwise group 0. In 3D datasets, they are all the same, thus are all
        zeros.
zhangwenwei's avatar
zhangwenwei committed
447
448
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)