scannet_data_utils.py

import concurrent.futures as futures
import os.path as osp

import mmcv
import numpy as np


class ScanNetData(object):
    """ScanNet data.

    Generate scannet infos for scannet_converter.

    Args:
        root_path (str): Root path of the raw data.
        split (str): Set split type of the data. Default: 'train'.
    """

    def __init__(self, root_path, split='train'):
        self.root_dir = root_path
        self.split = split
        self.split_dir = osp.join(root_path)
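        # The 18 ScanNet benchmark classes.  `classes` and `cat_ids` (their
        # NYU40 label ids) are kept in the same order, so `cat_ids2class`
        # maps an NYU40 id to the index of the matching class name.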
        self.classes = [
            'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
            'bookshelf', 'picture', 'counter', 'desk', 'curtain',
            'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
            'garbagebin'
        ]
        self.cat2label = {cat: self.classes.index(cat) for cat in self.classes}
        self.label2cat = {self.cat2label[t]: t for t in self.cat2label}
        self.cat_ids = np.array(
            [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
        self.cat_ids2class = {
            nyu40id: i
            for i, nyu40id in enumerate(list(self.cat_ids))
        }
        assert split in ['train', 'val', 'test']
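        # The split file is expected to list one scan id per line; those ids
        # are later used as file-name prefixes for the per-scene .npy files.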
        split_file = osp.join(self.root_dir, 'meta_data',
                              f'scannetv2_{split}.txt')
        mmcv.check_file_exist(split_file)
        self.sample_id_list = mmcv.list_from_file(split_file)

    def __len__(self):
        return len(self.sample_id_list)

    def get_box_label(self, idx):
        """Load the bounding boxes (k, 6 + class id) of one scan."""
        box_file = osp.join(self.root_dir, 'scannet_train_instance_data',
                            f'{idx}_bbox.npy')
        mmcv.check_file_exist(box_file)
        return np.load(box_file)

    def get_infos(self, num_workers=4, has_label=True, sample_id_list=None):
        """Get data infos.

        This method gets information from the raw data.

        Args:
            num_workers (int): Number of threads to be used. Default: 4.
            has_label (bool): Whether the data has label. Default: True.
            sample_id_list (list[int]): Index list of the sample.
                Default: None.

        Returns:
            infos (list[dict]): Information of the raw data.
        """

        def process_single_scene(sample_idx):
            print(f'{self.split} sample_idx: {sample_idx}')
            info = dict()
            pc_info = {'num_features': 6, 'lidar_idx': sample_idx}
            info['point_cloud'] = pc_info
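            # Load the per-scene vertices and the point-wise instance/semantic
            # labels produced by the ScanNet preprocessing step, then re-save
            # them as raw .bin files (points/, instance_mask/, semantic_mask/)
            # that can be read back quickly with np.fromfile.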
            pts_filename = osp.join(self.root_dir,
                                    'scannet_train_instance_data',
                                    f'{sample_idx}_vert.npy')
            pts_instance_mask_path = osp.join(self.root_dir,
                                              'scannet_train_instance_data',
                                              f'{sample_idx}_ins_label.npy')
            pts_semantic_mask_path = osp.join(self.root_dir,
                                              'scannet_train_instance_data',
                                              f'{sample_idx}_sem_label.npy')

            points = np.load(pts_filename)
            pts_instance_mask = np.load(pts_instance_mask_path).astype(
                np.int64)
            pts_semantic_mask = np.load(pts_semantic_mask_path).astype(
                np.int64)

            mmcv.mkdir_or_exist(osp.join(self.root_dir, 'points'))
            mmcv.mkdir_or_exist(osp.join(self.root_dir, 'instance_mask'))
            mmcv.mkdir_or_exist(osp.join(self.root_dir, 'semantic_mask'))

            points.tofile(
                osp.join(self.root_dir, 'points', f'{sample_idx}.bin'))
            pts_instance_mask.tofile(
                osp.join(self.root_dir, 'instance_mask', f'{sample_idx}.bin'))
            pts_semantic_mask.tofile(
                osp.join(self.root_dir, 'semantic_mask', f'{sample_idx}.bin'))

            info['pts_path'] = osp.join('points', f'{sample_idx}.bin')
            info['pts_instance_mask_path'] = osp.join('instance_mask',
                                                      f'{sample_idx}.bin')
            info['pts_semantic_mask_path'] = osp.join('semantic_mask',
                                                      f'{sample_idx}.bin')

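            # Each row of the box file holds 6 box values followed by an
            # NYU40 class id (see `get_box_label`): the first three values go
            # into 'location', the next three into 'dimensions', the full
            # 6-value row into 'gt_boxes_upright_depth', and the class id is
            # remapped to a 0-17 label through `cat_ids2class`.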
            if has_label:
                annotations = {}
                boxes_with_classes = self.get_box_label(
                    sample_idx)  # k, 6 + class
                annotations['gt_num'] = boxes_with_classes.shape[0]
                if annotations['gt_num'] != 0:
                    minmax_boxes3d = boxes_with_classes[:, :-1]  # k, 6
                    classes = boxes_with_classes[:, -1]  # k, 1
                    annotations['name'] = np.array([
                        self.label2cat[self.cat_ids2class[classes[i]]]
                        for i in range(annotations['gt_num'])
                    ])
                    annotations['location'] = minmax_boxes3d[:, :3]
                    annotations['dimensions'] = minmax_boxes3d[:, 3:6]
                    annotations['gt_boxes_upright_depth'] = minmax_boxes3d
                    annotations['index'] = np.arange(
                        annotations['gt_num'], dtype=np.int32)
                    annotations['class'] = np.array([
                        self.cat_ids2class[classes[i]]
                        for i in range(annotations['gt_num'])
                    ])
                info['annos'] = annotations
            return info

        sample_id_list = sample_id_list if sample_id_list is not None \
            else self.sample_id_list
        with futures.ThreadPoolExecutor(num_workers) as executor:
            infos = executor.map(process_single_scene, sample_id_list)
        return list(infos)
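

# ---------------------------------------------------------------------------
# Usage sketch: one way this class can be driven to build an info file.  The
# data root below is a hypothetical placeholder, and dumping the result to a
# .pkl via mmcv.dump is just one reasonable choice of output format rather
# than something ScanNetData itself prescribes.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    data_root = './data/scannet'  # hypothetical data root
    for split in ('train', 'val'):
        scannet_data = ScanNetData(root_path=data_root, split=split)
        infos = scannet_data.get_infos(num_workers=4, has_label=True)
        mmcv.dump(infos, osp.join(data_root, f'scannet_infos_{split}.pkl'))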