scannet_data_utils.py 4 KB
Newer Older
liyinhao's avatar
liyinhao committed
1
import concurrent.futures as futures
2
3
4
5
6
import os

import numpy as np


liyinhao's avatar
liyinhao committed
7
class ScanNetData(object):
    '''ScanNet Data.

    Generate scannet infos for scannet_converter.

    Args:
        root_path (str): Root path of the raw data.
        split (str): Set split type of the data. Must be one of
            'train', 'val' or 'test'. Default: 'train'.
    '''

    def __init__(self, root_path, split='train'):
        assert split in ['train', 'val', 'test'], \
            f'split must be one of train/val/test, got {split!r}'
        self.root_dir = root_path
        self.split = split
        # Kept for backward compatibility; os.path.join with a single
        # argument simply returns root_path.
        self.split_dir = os.path.join(root_path)
        self.classes = [
            'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
            'bookshelf', 'picture', 'counter', 'desk', 'curtain',
            'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
            'garbagebin'
        ]
        # enumerate avoids the O(n^2) list.index lookup per class.
        self.cat2label = {cat: i for i, cat in enumerate(self.classes)}
        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
        # Raw NYU40 ids of the kept classes; position i corresponds to
        # contiguous label i (see the label2cat lookup in
        # get_scannet_infos).
        self.cat_ids = np.array(
            [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
        # Keys are plain ints so lookups do not depend on numpy scalar /
        # float hashing coincidences.
        self.cat_ids2class = {
            int(nyu40id): i
            for i, nyu40id in enumerate(self.cat_ids)
        }
        split_file = os.path.join(self.root_dir, 'meta_data',
                                  f'scannetv2_{split}.txt')
        if os.path.exists(split_file):
            # Close the handle deterministically and drop blank lines,
            # which would otherwise become bogus '' sample ids.
            with open(split_file) as f:
                self.sample_id_list = [
                    line.strip() for line in f if line.strip()
                ]
        else:
            self.sample_id_list = None

    def __len__(self):
        '''Return the number of sample ids read from the split file.

        Raises:
            TypeError: If the split file did not exist, in which case
                ``sample_id_list`` is None.
        '''
        return len(self.sample_id_list)

    def get_box_label(self, idx):
        '''Load the box array of one scene.

        Args:
            idx (str): Sample id, e.g. an entry of ``sample_id_list``.

        Returns:
            np.ndarray: Contents of
                ``<root>/scannet_train_instance_data/<idx>_bbox.npy``. Each
                row is read as 6 box values followed by a class id.
        '''
        box_file = os.path.join(self.root_dir, 'scannet_train_instance_data',
                                f'{idx}_bbox.npy')
        assert os.path.exists(box_file), f'box file not found: {box_file}'
        return np.load(box_file)

    def get_scannet_infos(self,
                          num_workers=4,
                          has_label=True,
                          sample_id_list=None):
        '''Get scannet infos.

        This method gets information from the raw data.

        Args:
            num_workers (int): Number of threads to be used. Default: 4.
            has_label (bool): Whether the data has label. Default: True.
            sample_id_list (list[str] | None): Sample ids to process.
                Falls back to ``self.sample_id_list`` when None.
                Default: None.

        Returns:
            list[dict]: One info dict per sample id, in input order.
        '''

        def process_single_scene(sample_idx):
            print(f'{self.split} sample_idx: {sample_idx}')
            info = dict()
            info['point_cloud'] = {'num_features': 6, 'lidar_idx': sample_idx}
            if not has_label:
                return info

            boxes_with_classes = self.get_box_label(
                sample_idx)  # k, 6 + class
            annotations = {'gt_num': boxes_with_classes.shape[0]}
            if annotations['gt_num'] != 0:
                minmax_boxes3d = boxes_with_classes[:, :-1]  # k, 6
                # Map raw class ids (last column) to contiguous labels.
                labels = [
                    self.cat_ids2class[int(raw_id)]
                    for raw_id in boxes_with_classes[:, -1]
                ]
                annotations['name'] = np.array(
                    [self.label2cat[label] for label in labels])
                annotations['location'] = minmax_boxes3d[:, :3]
                annotations['dimensions'] = minmax_boxes3d[:, 3:6]
                annotations['gt_boxes_upright_depth'] = minmax_boxes3d
                annotations['index'] = np.arange(
                    annotations['gt_num'], dtype=np.int32)
                annotations['class'] = np.array(labels)
            info['annos'] = annotations
            return info

        if sample_id_list is None:
            sample_id_list = self.sample_id_list
        with futures.ThreadPoolExecutor(num_workers) as executor:
            infos = list(executor.map(process_single_scene, sample_id_list))
        return infos