"tools/vscode:/vscode.git/clone" did not exist on "a7a32a9e5e0612a9a36c9663a03b8d3e59b924e9"
indoor_converter.py 4.96 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
import os

import mmcv
import numpy as np

from tools.dataset_converters.s3dis_data_utils import S3DISData, S3DISSegData
from tools.dataset_converters.scannet_data_utils import (ScanNetData,
                                                         ScanNetSegData)
from tools.dataset_converters.sunrgbd_data_utils import SUNRGBDData


def create_indoor_info_file(data_path,
                            pkl_prefix='sunrgbd',
                            save_path=None,
                            use_v1=False,
                            workers=4):
    """Create indoor information file.

    Get information of the raw data and save it to the pkl file.

    Args:
        data_path (str): Path of the data.
        pkl_prefix (str, optional): Prefix of the pkl to be saved.
            Default: 'sunrgbd'.
        save_path (str, optional): Path of the pkl to be saved. Default: None.
        use_v1 (bool, optional): Whether to use v1. Default: False.
        workers (int, optional): Number of threads to be used. Default: 4.
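
    Examples:
        A minimal sketch of a typical call; the data root below is a
        placeholder, not a path this module provides:

        >>> create_indoor_info_file(
        ...     './data/scannet', pkl_prefix='scannet', workers=8)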
    """
    assert os.path.exists(data_path)
    assert pkl_prefix in ['sunrgbd', 'scannet', 's3dis'], \
        f'unsupported indoor dataset {pkl_prefix}'
    save_path = data_path if save_path is None else save_path
    assert os.path.exists(save_path)

    # generate infos for both detection and segmentation task
    if pkl_prefix in ['sunrgbd', 'scannet']:
        train_filename = os.path.join(save_path,
                                      f'{pkl_prefix}_infos_train.pkl')
        val_filename = os.path.join(save_path, f'{pkl_prefix}_infos_val.pkl')
        if pkl_prefix == 'sunrgbd':
            # SUN RGB-D has a train-val split
            train_dataset = SUNRGBDData(
                root_path=data_path, split='train', use_v1=use_v1)
            val_dataset = SUNRGBDData(
                root_path=data_path, split='val', use_v1=use_v1)
        else:
            # ScanNet has a train-val-test split
            train_dataset = ScanNetData(root_path=data_path, split='train')
            val_dataset = ScanNetData(root_path=data_path, split='val')
            test_dataset = ScanNetData(root_path=data_path, split='test')
            test_filename = os.path.join(save_path,
                                         f'{pkl_prefix}_infos_test.pkl')

        infos_train = train_dataset.get_infos(
            num_workers=workers, has_label=True)
        mmcv.dump(infos_train, train_filename, 'pkl')
        print(f'{pkl_prefix} info train file is saved to {train_filename}')

        infos_val = val_dataset.get_infos(num_workers=workers, has_label=True)
        mmcv.dump(infos_val, val_filename, 'pkl')
        print(f'{pkl_prefix} info val file is saved to {val_filename}')

    if pkl_prefix == 'scannet':
        infos_test = test_dataset.get_infos(
            num_workers=workers, has_label=False)
        mmcv.dump(infos_test, test_filename, 'pkl')
        print(f'{pkl_prefix} info test file is saved to {test_filename}')

    # generate infos for the semantic segmentation task,
    # e.g. re-sampled scene indexes and label weights
    # scene indexes are used to re-sample rooms with different point counts
    # label weights are used to balance classes with different point counts
    if pkl_prefix == 'scannet':
        # label weight computation function is adopted from
        # https://github.com/charlesq34/pointnet2/blob/master/scannet/scannet_dataset.py#L24
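        # 1 / log(1.2 + x) maps a class's point fraction x to a weight
        # that shrinks as x grows, so frequent classes are down-weighted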
        train_dataset = ScanNetSegData(
            data_root=data_path,
            ann_file=train_filename,
            split='train',
            num_points=8192,
            label_weight_func=lambda x: 1.0 / np.log(1.2 + x))
        # TODO: do we need to generate on val set?
        val_dataset = ScanNetSegData(
            data_root=data_path,
            ann_file=val_filename,
            split='val',
            num_points=8192,
            label_weight_func=lambda x: 1.0 / np.log(1.2 + x))
        # no need to generate for test set
        train_dataset.get_seg_infos()
        val_dataset.get_seg_infos()
    elif pkl_prefix == 's3dis':
        # S3DIS doesn't have a fixed train-val split
        # it has 6 areas instead, so we generate an info file for each of them
        # in training, a dataset wrapper is used to combine different areas
        splits = [f'Area_{i}' for i in [1, 2, 3, 4, 5, 6]]
        for split in splits:
            dataset = S3DISData(root_path=data_path, split=split)
            info = dataset.get_infos(num_workers=workers, has_label=True)
            filename = os.path.join(save_path,
                                    f'{pkl_prefix}_infos_{split}.pkl')
            mmcv.dump(info, filename, 'pkl')
            print(f'{pkl_prefix} info {split} file is saved to {filename}')
            seg_dataset = S3DISSegData(
                data_root=data_path,
                ann_file=filename,
                split=split,
                num_points=4096,
                label_weight_func=lambda x: 1.0 / np.log(1.2 + x))
            seg_dataset.get_seg_infos()
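

# A minimal command-line entry point, sketched for illustration: in
# MMDetection3D these converters are normally driven by tools/create_data.py,
# so the flag names below are assumptions, not the project's official CLI.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Create indoor info files.')
    parser.add_argument('data_path', help='root directory of the raw dataset')
    parser.add_argument(
        '--pkl-prefix',
        default='sunrgbd',
        choices=['sunrgbd', 'scannet', 's3dis'],
        help='which indoor dataset to convert')
    parser.add_argument(
        '--save-path', default=None, help='where to write the pkl files')
    parser.add_argument(
        '--use-v1', action='store_true', help='use v1 labels for SUN RGB-D')
    parser.add_argument(
        '--workers', type=int, default=4, help='number of workers to use')
    args = parser.parse_args()
    create_indoor_info_file(
        args.data_path,
        pkl_prefix=args.pkl_prefix,
        save_path=args.save_path,
        use_v1=args.use_v1,
        workers=args.workers)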