# Modified from
# https://github.com/facebookresearch/votenet/blob/master/scannet/batch_load_scannet_data.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Batch-load ScanNet scenes with their vertices and ground-truth labels.

The exported ground truth is used for semantic and instance segmentation.

Usage example: python ./batch_load_scannet_data.py
"""
import argparse
import datetime
import os
import traceback
from os import path as osp

import numpy as np

from load_scannet_data import export

# Semantic class ids whose points are removed before export. The array is
# empty, so by default no point is filtered out.
DONOTCARE_CLASS_IDS = np.array(())

# Semantic class ids whose boxes are kept in the exported box arrays
# (presumably the ScanNet detection categories in the label map's id
# space — confirm against the label map file).
OBJ_CLASS_IDS = np.array((
    3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
    14, 16, 24, 28, 33, 34, 36, 39,
))


def export_one_scan(scan_name,
                    output_filename_prefix,
                    max_num_point,
                    label_map_file,
                    scannet_dir,
                    test_mode=False):
    """Export a single ScanNet scan to ``.npy`` files.

    The raw files of ``scan_name`` are read from ``scannet_dir`` by
    ``export`` (from ``load_scannet_data``). Unless ``test_mode`` is set,
    points of "do not care" classes are dropped and boxes are restricted to
    ``OBJ_CLASS_IDS``. The points are then optionally subsampled and saved.

    Args:
        scan_name (str): Name of the scan; also the name of its
            sub-directory inside ``scannet_dir``.
        output_filename_prefix (str): Path prefix of the output files.
            ``<prefix>_vert.npy`` is always written. Outside test mode, these
            files are also written: ``_sem_label.npy``, ``_ins_label.npy``,
            ``_unaligned_bbox.npy``, ``_aligned_bbox.npy`` and
            ``_axis_align_matrix.npy``.
        max_num_point (int | str | None): If not None, a scan with more
            points than this is randomly subsampled (without replacement)
            to this many points. The value is converted with ``int``.
        label_map_file (str): Path of the label map file, passed to
            ``export``.
        scannet_dir (str): Directory with one sub-directory per scan.
        test_mode (bool): If True, only the (optionally subsampled) vertices
            are saved. Labels and boxes are neither filtered nor written.
    """
    scan_dir = osp.join(scannet_dir, scan_name)
    mesh_file = osp.join(scan_dir, scan_name + '_vh_clean_2.ply')
    agg_file = osp.join(scan_dir, scan_name + '.aggregation.json')
    seg_file = osp.join(scan_dir,
                        scan_name + '_vh_clean_2.0.010000.segs.json')
    # includes axisAlignment info for the train set scans.
    meta_file = osp.join(scan_dir, f'{scan_name}.txt')

    mesh_vertices, semantic_labels, instance_labels, unaligned_bboxes, \
        aligned_bboxes, instance2semantic, axis_align_matrix = export(
            mesh_file, agg_file, seg_file, meta_file, label_map_file, None,
            test_mode)

    if not test_mode:
        # Drop the points of "do not care" classes. ``np.isin`` replaces
        # ``np.in1d``, which is deprecated since NumPy 2.0 (assumes the
        # per-point labels are 1-D, as the row masking below requires).
        mask = np.logical_not(np.isin(semantic_labels, DONOTCARE_CLASS_IDS))
        mesh_vertices = mesh_vertices[mask, :]
        semantic_labels = semantic_labels[mask]
        instance_labels = instance_labels[mask]

        num_instances = len(np.unique(instance_labels))
        print(f'Num of instances: {num_instances}')

        # Keep only boxes whose class id (last column) is in OBJ_CLASS_IDS;
        # both box sets must end up with the same number of boxes.
        bbox_mask = np.isin(unaligned_bboxes[:, -1], OBJ_CLASS_IDS)
        unaligned_bboxes = unaligned_bboxes[bbox_mask, :]
        bbox_mask = np.isin(aligned_bboxes[:, -1], OBJ_CLASS_IDS)
        aligned_bboxes = aligned_bboxes[bbox_mask, :]
        assert unaligned_bboxes.shape[0] == aligned_bboxes.shape[0]
        print(f'Num of care instances: {unaligned_bboxes.shape[0]}')

    if max_num_point is not None:
        # Command-line values may arrive as strings.
        max_num_point = int(max_num_point)
        num_points = mesh_vertices.shape[0]
        if num_points > max_num_point:
            choices = np.random.choice(
                num_points, max_num_point, replace=False)
            mesh_vertices = mesh_vertices[choices, :]
            if not test_mode:
                # Keep the per-point labels aligned with the sampled points.
                semantic_labels = semantic_labels[choices]
                instance_labels = instance_labels[choices]

    np.save(f'{output_filename_prefix}_vert.npy', mesh_vertices)

    if not test_mode:
        np.save(f'{output_filename_prefix}_sem_label.npy', semantic_labels)
        np.save(f'{output_filename_prefix}_ins_label.npy', instance_labels)
        np.save(f'{output_filename_prefix}_unaligned_bbox.npy',
                unaligned_bboxes)
        np.save(f'{output_filename_prefix}_aligned_bbox.npy', aligned_bboxes)
        np.save(f'{output_filename_prefix}_axis_align_matrix.npy',
                axis_align_matrix)


def batch_export(max_num_point,
                 output_folder,
                 scan_names_file,
                 label_map_file,
                 scannet_dir,
                 test_mode=False):
    """Export every scan listed in ``scan_names_file`` into ``output_folder``.

    A scan whose ``<output_folder>/<scan_name>_vert.npy`` already exists is
    skipped, so an interrupted run can be resumed. A scan that fails to
    export is reported together with its traceback, and the batch continues
    with the next scan.

    Args:
        max_num_point (int | str | None): Forwarded to ``export_one_scan``.
        output_folder (str): Directory the ``.npy`` files are written to.
            It is created, including missing parents, if it does not exist.
        scan_names_file (str): Text file with one scan name per line. Blank
            lines are ignored.
        label_map_file (str): Forwarded to ``export_one_scan``.
        scannet_dir (str): Directory holding one sub-directory per scan.
        test_mode (bool): Whether these are test scans. Test data is
            optional, so in test mode a missing ``scannet_dir`` makes this
            call a no-op.
    """
    if test_mode and not osp.exists(scannet_dir):
        # test data preparation is optional
        return
    if not osp.exists(output_folder):
        print(f'Creating new data folder: {output_folder}')
        # Unlike os.mkdir, makedirs also creates missing parent directories
        # and tolerates the folder being created concurrently.
        os.makedirs(output_folder, exist_ok=True)

    # Close the names file deterministically. Skip blank lines (e.g. a
    # trailing empty line); an empty scan name would only yield a bogus
    # export attempt.
    with open(scan_names_file) as names_file:
        scan_names = [
            name for name in (line.rstrip() for line in names_file) if name
        ]

    for scan_name in scan_names:
        print('-' * 20 + 'begin')
        print(datetime.datetime.now())
        print(scan_name)
        output_filename_prefix = osp.join(output_folder, scan_name)
        if osp.isfile(f'{output_filename_prefix}_vert.npy'):
            print('File already exists. skipping.')
            print('-' * 20 + 'done')
            continue
        try:
            export_one_scan(scan_name, output_filename_prefix, max_num_point,
                            label_map_file, scannet_dir, test_mode)
        except Exception:
            # Best effort: keep exporting the remaining scans, but surface
            # the cause instead of swallowing it silently.
            print(f'Failed export scan: {scan_name}')
            traceback.print_exc()
        print('-' * 20 + 'done')


def main():
    """Parse the command line and export the train and test splits.

    The train split is exported with ``test_mode=False`` and the test split
    with ``test_mode=True``. The test split is skipped when
    ``--test_scannet_dir`` does not exist. Both splits write into
    ``--output_folder``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--max_num_point',
        # Parse as int so a malformed value fails here, instead of failing
        # (and being reported) once per scan during export.
        type=int,
        default=None,
        help='The maximum number of the points.')
    parser.add_argument(
        '--output_folder',
        default='./scannet_instance_data',
        help='output folder of the result.')
    parser.add_argument(
        '--train_scannet_dir',
        default='scans',
        help='scannet data directory of the train scans.')
    parser.add_argument(
        '--test_scannet_dir',
        default='scans_test',
        help='scannet data directory of the test scans.')
    parser.add_argument(
        '--label_map_file',
        default='meta_data/scannetv2-labels.combined.tsv',
        help='The path of label map file.')
    parser.add_argument(
        '--train_scan_names_file',
        default='meta_data/scannet_train.txt',
        help='The path of the file that stores the train scan names.')
    parser.add_argument(
        '--test_scan_names_file',
        default='meta_data/scannetv2_test.txt',
        help='The path of the file that stores the test scan names.')
    args = parser.parse_args()

    batch_export(
        args.max_num_point,
        args.output_folder,
        args.train_scan_names_file,
        args.label_map_file,
        args.train_scannet_dir,
        test_mode=False)
    batch_export(
        args.max_num_point,
        args.output_folder,
        args.test_scan_names_file,
        args.label_map_file,
        args.test_scannet_dir,
        test_mode=True)


if __name__ == '__main__':
    # Run the export only when executed as a script, not on import.
    main()