import mmcv
import numpy as np
from concurrent import futures as futures
from os import path as osp


class ScanNetData(object):
    """ScanNet data.

    Generate scannet infos for scannet_converter.

    The raw per-scene arrays are read from
    ``<root_path>/scannet_train_instance_data`` and the list of scene ids
    from ``<root_path>/meta_data/scannetv2_<split>.txt``.

    Args:
        root_path (str): Root path of the raw data.
        split (str): Set split type of the data. Must be one of 'train',
            'val' or 'test'. Default: 'train'.
    """

    def __init__(self, root_path, split='train'):
        assert split in ['train', 'val', 'test'], \
            f"split must be 'train', 'val' or 'test', but got {split!r}"
        self.root_dir = root_path
        self.split = split
        self.split_dir = osp.join(root_path)

        # NOTE: the spelling 'showercurtrain' is kept verbatim; the names are
        # runtime data that end up in the generated infos.
        self.classes = [
            'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
            'bookshelf', 'picture', 'counter', 'desk', 'curtain',
            'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
            'garbagebin'
        ]
        # Class name <-> contiguous label (position in ``self.classes``).
        self.cat2label = {cat: label for label, cat in enumerate(self.classes)}
        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
        # Raw category ids as stored in the last column of the box files;
        # the i-th id maps to contiguous label i.
        self.cat_ids = np.array(
            [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
        self.cat_ids2class = {
            nyu40id: label
            for label, nyu40id in enumerate(self.cat_ids)
        }

        split_file = osp.join(self.root_dir, 'meta_data',
                              f'scannetv2_{split}.txt')
        mmcv.check_file_exist(split_file)
        self.sample_id_list = mmcv.list_from_file(split_file)

    def __len__(self):
        """Return the number of sample ids listed in the split file."""
        return len(self.sample_id_list)

    def get_box_label(self, idx):
        """Load the box array of one scene.

        Args:
            idx (str): Sample id; the array is read from
                ``scannet_train_instance_data/<idx>_bbox.npy``.

        Returns:
            np.ndarray: Content of the ``_bbox.npy`` file. ``get_infos``
                treats it as ``(k, 6 + 1)``: six box values followed by the
                raw category id.
        """
        box_file = osp.join(self.root_dir, 'scannet_train_instance_data',
                            f'{idx}_bbox.npy')
        mmcv.check_file_exist(box_file)
        return np.load(box_file)

    def _build_annotations(self, boxes_with_classes):
        """Turn a ``(k, 6 + 1)`` box array into the annotation dict.

        Args:
            boxes_with_classes (np.ndarray): Boxes with the raw category id
                in the last column.

        Returns:
            dict: Always contains 'gt_num'. When at least one box exists it
                also contains 'name', 'location', 'dimensions',
                'gt_boxes_upright_depth', 'index' and 'class'.
        """
        annotations = {'gt_num': boxes_with_classes.shape[0]}
        if annotations['gt_num'] == 0:
            return annotations

        minmax_boxes3d = boxes_with_classes[:, :-1]  # k, 6
        # Map each raw category id to its contiguous label once and derive
        # the class names from those labels.
        labels = [
            self.cat_ids2class[cat_id] for cat_id in boxes_with_classes[:, -1]
        ]
        annotations['name'] = np.array(
            [self.label2cat[label] for label in labels])
        annotations['location'] = minmax_boxes3d[:, :3]
        annotations['dimensions'] = minmax_boxes3d[:, 3:6]
        annotations['gt_boxes_upright_depth'] = minmax_boxes3d
        annotations['index'] = np.arange(
            annotations['gt_num'], dtype=np.int32)
        annotations['class'] = np.array(labels)
        return annotations

    def get_infos(self, num_workers=4, has_label=True, sample_id_list=None):
        """Get data infos.

        This method gets information from the raw data. As a side effect it
        writes the points and both masks of every processed scene as raw
        ``.bin`` files into the ``points``, ``instance_mask`` and
        ``semantic_mask`` sub-directories of the root path.

        Args:
            num_workers (int): Number of threads to be used. Default: 4.
            has_label (bool): Whether the data has label. Default: True.
            sample_id_list (list[str], optional): Sample ids to process.
                Default: None, in which case the ids read from the split
                file are used.

        Returns:
            infos (list[dict]): Information of the raw data, one dict per
                sample id, in the order of the ids.
        """
        raw_dir = osp.join(self.root_dir, 'scannet_train_instance_data')

        # The output directories do not depend on the scene, so create them
        # once here instead of once per scene inside the workers.
        for sub_dir in ('points', 'instance_mask', 'semantic_mask'):
            mmcv.mkdir_or_exist(osp.join(self.root_dir, sub_dir))

        def process_single_scene(sample_idx):
            print(f'{self.split} sample_idx: {sample_idx}')
            info = dict()
            info['point_cloud'] = {'num_features': 6, 'lidar_idx': sample_idx}

            points = np.load(osp.join(raw_dir, f'{sample_idx}_vert.npy'))
            # np.int64 instead of the removed/ambiguous ``np.long`` so the
            # dtype written to disk is the same on every platform and NumPy
            # version.
            pts_instance_mask = np.load(
                osp.join(raw_dir,
                         f'{sample_idx}_ins_label.npy')).astype(np.int64)
            pts_semantic_mask = np.load(
                osp.join(raw_dir,
                         f'{sample_idx}_sem_label.npy')).astype(np.int64)

            # (info key, output sub-directory, array to dump)
            exports = (
                ('pts_path', 'points', points),
                ('pts_instance_mask_path', 'instance_mask',
                 pts_instance_mask),
                ('pts_semantic_mask_path', 'semantic_mask',
                 pts_semantic_mask),
            )
            for info_key, sub_dir, array in exports:
                # The info stores the path relative to the root directory.
                rel_path = osp.join(sub_dir, f'{sample_idx}.bin')
                array.tofile(osp.join(self.root_dir, rel_path))
                info[info_key] = rel_path

            if has_label:
                info['annos'] = self._build_annotations(
                    self.get_box_label(sample_idx))  # k, 6 + class
            return info

        if sample_id_list is None:
            sample_id_list = self.sample_id_list
        with futures.ThreadPoolExecutor(num_workers) as executor:
            # Materialise inside the ``with`` so an exception raised in a
            # worker surfaces here.
            return list(executor.map(process_single_scene, sample_id_list))