# Modified from
# https://github.com/facebookresearch/votenet/blob/master/scannet/batch_load_scannet_data.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Batch mode in loading Scannet scenes with vertices and ground truth labels
for semantic and instance segmentations.

Usage example: python ./batch_load_scannet_data.py
"""
import argparse
import datetime
import os
from os import path as osp

import numpy as np
from load_scannet_data import export

DONOTCARE_CLASS_IDS = np.array([])
OBJ_CLASS_IDS = np.array(
    [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])


def export_one_scan(scan_name,
                    output_filename_prefix,
                    max_num_point,
                    label_map_file,
                    scannet_dir,
                    test_mode=False):
31
32
33
34
35
36
37
    mesh_file = osp.join(scannet_dir, scan_name, scan_name + '_vh_clean_2.ply')
    agg_file = osp.join(scannet_dir, scan_name,
                        scan_name + '.aggregation.json')
    seg_file = osp.join(scannet_dir, scan_name,
                        scan_name + '_vh_clean_2.0.010000.segs.json')
    # includes axisAlignment info for the train set scans.
    meta_file = osp.join(scannet_dir, scan_name, f'{scan_name}.txt')
38
39
40
41
    mesh_vertices, semantic_labels, instance_labels, unaligned_bboxes, \
        aligned_bboxes, instance2semantic, axis_align_matrix = export(
            mesh_file, agg_file, seg_file, meta_file, label_map_file, None,
            test_mode)
liyinhao's avatar
liyinhao committed
42

43
44
45
46
47
    if not test_mode:
        mask = np.logical_not(np.in1d(semantic_labels, DONOTCARE_CLASS_IDS))
        mesh_vertices = mesh_vertices[mask, :]
        semantic_labels = semantic_labels[mask]
        instance_labels = instance_labels[mask]
liyinhao's avatar
liyinhao committed
48

49
50
        num_instances = len(np.unique(instance_labels))
        print(f'Num of instances: {num_instances}')
liyinhao's avatar
liyinhao committed
51

52
53
54
55
56
57
        bbox_mask = np.in1d(unaligned_bboxes[:, -1], OBJ_CLASS_IDS)
        unaligned_bboxes = unaligned_bboxes[bbox_mask, :]
        bbox_mask = np.in1d(aligned_bboxes[:, -1], OBJ_CLASS_IDS)
        aligned_bboxes = aligned_bboxes[bbox_mask, :]
        assert unaligned_bboxes.shape[0] == aligned_bboxes.shape[0]
        print(f'Num of care instances: {unaligned_bboxes.shape[0]}')
liyinhao's avatar
liyinhao committed
58

59
    if max_num_point is not None:
60
        max_num_point = int(max_num_point)
61
62
63
64
        N = mesh_vertices.shape[0]
        if N > max_num_point:
            choices = np.random.choice(N, max_num_point, replace=False)
            mesh_vertices = mesh_vertices[choices, :]
65
66
67
            if not test_mode:
                semantic_labels = semantic_labels[choices]
                instance_labels = instance_labels[choices]
liyinhao's avatar
liyinhao committed
68

69
    np.save(f'{output_filename_prefix}_vert.npy', mesh_vertices)
70
71
72
    if not test_mode:
        np.save(f'{output_filename_prefix}_sem_label.npy', semantic_labels)
        np.save(f'{output_filename_prefix}_ins_label.npy', instance_labels)
73
74
75
76
77
        np.save(f'{output_filename_prefix}_unaligned_bbox.npy',
                unaligned_bboxes)
        np.save(f'{output_filename_prefix}_aligned_bbox.npy', aligned_bboxes)
        np.save(f'{output_filename_prefix}_axis_align_matrix.npy',
                axis_align_matrix)
78
79
80
81
82
83
84
85
86
87
88


def batch_export(max_num_point,
                 output_folder,
                 scan_names_file,
                 label_map_file,
                 scannet_dir,
                 test_mode=False):
    if test_mode and not os.path.exists(scannet_dir):
        # test data preparation is optional
        return
89
    if not os.path.exists(output_folder):
90
        print(f'Creating new data folder: {output_folder}')
91
        os.mkdir(output_folder)
liyinhao's avatar
liyinhao committed
92

93
94
    scan_names = [line.rstrip() for line in open(scan_names_file)]
    for scan_name in scan_names:
liyinhao's avatar
liyinhao committed
95
96
97
        print('-' * 20 + 'begin')
        print(datetime.datetime.now())
        print(scan_name)
98
99
        output_filename_prefix = osp.join(output_folder, scan_name)
        if osp.isfile(f'{output_filename_prefix}_vert.npy'):
liyinhao's avatar
liyinhao committed
100
101
102
103
            print('File already exists. skipping.')
            print('-' * 20 + 'done')
            continue
        try:
104
            export_one_scan(scan_name, output_filename_prefix, max_num_point,
105
                            label_map_file, scannet_dir, test_mode)
liyinhao's avatar
liyinhao committed
106
        except Exception:
liyinhao's avatar
liyinhao committed
107
            print(f'Failed export scan: {scan_name}')
liyinhao's avatar
liyinhao committed
108
109
110
        print('-' * 20 + 'done')


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--max_num_point',
115
        default=None,
116
117
118
        help='The maximum number of the points.')
    parser.add_argument(
        '--output_folder',
119
        default='./scannet_instance_data',
120
121
        help='output folder of the result.')
    parser.add_argument(
122
123
124
125
126
        '--train_scannet_dir', default='scans', help='scannet data directory.')
    parser.add_argument(
        '--test_scannet_dir',
        default='scans_test',
        help='scannet data directory.')
127
128
129
130
131
132
133
134
    parser.add_argument(
        '--label_map_file',
        default='meta_data/scannetv2-labels.combined.tsv',
        help='The path of label map file.')
    parser.add_argument(
        '--train_scan_names_file',
        default='meta_data/scannet_train.txt',
        help='The path of the file that stores the scan names.')
135
136
137
138
    parser.add_argument(
        '--test_scan_names_file',
        default='meta_data/scannetv2_test.txt',
        help='The path of the file that stores the scan names.')
139
    args = parser.parse_args()
140
141
142
143
144
145
146
147
148
149
150
151
152
153
    batch_export(
        args.max_num_point,
        args.output_folder,
        args.train_scan_names_file,
        args.label_map_file,
        args.train_scannet_dir,
        test_mode=False)
    batch_export(
        args.max_num_point,
        args.output_folder,
        args.test_scan_names_file,
        args.label_map_file,
        args.test_scannet_dir,
        test_mode=True)
154
155


liyinhao's avatar
liyinhao committed
156
if __name__ == '__main__':
157
    main()