# scannet_data_utils.py
import concurrent.futures as futures
import os

import mmcv
import numpy as np


class ScanNetData(object):
    """ScanNet data.

    Generate ScanNet infos for scannet_converter.

    Args:
        root_path (str): Root path of the raw data.
        split (str): Set split type of the data; one of 'train', 'val'
            and 'test'. Default: 'train'.
    """

    def __init__(self, root_path, split='train'):
        # Fail fast on an unknown split, before any attribute is built or
        # any file is read.
        assert split in ['train', 'val', 'test'], \
            f"split must be one of 'train', 'val', 'test', got {split!r}"
        self.root_dir = root_path
        self.split = split
        # Not used within this class; kept as a public attribute.
        self.split_dir = os.path.join(root_path)
        # NOTE: 'showercurtrain' is misspelled but kept as-is, because these
        # strings end up in the generated infos as annotation names.
        self.classes = [
            'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
            'bookshelf', 'picture', 'counter', 'desk', 'curtain',
            'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
            'garbagebin'
        ]
        self.cat2label = {cat: label for label, cat in enumerate(self.classes)}
        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
        # NYU40 ids of the classes above, listed in the same order, so the
        # i-th id maps to label i (and label i maps to self.classes[i]).
        self.cat_ids = np.array(
            [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
        self.cat_ids2class = {
            nyu40id: label
            for label, nyu40id in enumerate(self.cat_ids)
        }

        split_file = os.path.join(self.root_dir, 'meta_data',
                                  f'scannetv2_{split}.txt')
        mmcv.check_file_exist(split_file)
        self.sample_id_list = mmcv.list_from_file(split_file)

    def __len__(self):
        """Return the number of sample ids read from the split file."""
        return len(self.sample_id_list)

    def get_box_label(self, idx):
        """Load the boxes and class ids of one scene.

        Args:
            idx (str): Sample id of the scene.

        Returns:
            np.ndarray: Array loaded from ``<idx>_bbox.npy``; each row holds
                6 box values followed by the NYU40 class id.

        Raises:
            FileNotFoundError: If the box file does not exist.
        """
        box_file = os.path.join(self.root_dir, 'scannet_train_instance_data',
                                f'{idx}_bbox.npy')
        # An explicit check, unlike ``assert``, is not stripped by
        # ``python -O`` and reports which file is missing.
        if not os.path.exists(box_file):
            raise FileNotFoundError(f'box file "{box_file}" does not exist')
        return np.load(box_file)

    def get_infos(self, num_workers=4, has_label=True, sample_id_list=None):
        """Get data infos.

        This method gets information from the raw data.

        Args:
            num_workers (int): Number of threads to be used. Default: 4.
            has_label (bool): Whether the data has label. Default: True.
            sample_id_list (list[str] | None): Ids of the samples to process.
                Default: None, meaning ``self.sample_id_list``.

        Returns:
            list[dict]: One info dict per sample, in the order of
                ``sample_id_list``.
        """

        def process_single_scene(sample_idx):
            print(f'{self.split} sample_idx: {sample_idx}')
            info = dict()
            pc_info = {'num_features': 6, 'lidar_idx': sample_idx}
            info['point_cloud'] = pc_info

            if has_label:
                annotations = {}
                # (k, 6 + 1): 6 box values followed by the class id per row.
                boxes_with_classes = self.get_box_label(sample_idx)
                annotations['gt_num'] = boxes_with_classes.shape[0]
                if annotations['gt_num'] != 0:
                    minmax_boxes3d = boxes_with_classes[:, :-1]  # (k, 6)
                    classes = boxes_with_classes[:, -1]  # (k,)
                    # Map each raw NYU40 id to its contiguous label once and
                    # derive the class names from those labels.
                    labels = [self.cat_ids2class[cat_id] for cat_id in classes]
                    annotations['name'] = np.array(
                        [self.label2cat[label] for label in labels])
                    annotations['location'] = minmax_boxes3d[:, :3]
                    annotations['dimensions'] = minmax_boxes3d[:, 3:6]
                    annotations['gt_boxes_upright_depth'] = minmax_boxes3d
                    annotations['index'] = np.arange(
                        annotations['gt_num'], dtype=np.int32)
                    annotations['class'] = np.array(labels)
                info['annos'] = annotations
            return info

        if sample_id_list is None:
            sample_id_list = self.sample_id_list
        with futures.ThreadPoolExecutor(num_workers) as executor:
            # Executor.map yields results in input order.
            infos = executor.map(process_single_scene, sample_id_list)
        return list(infos)