# create_integrated_database.py
import numpy as np
import pickle as pkl
from pathlib import Path
import tqdm
import copy


def create_integrated_db_with_infos(args, root_path):
    """Merge every per-object point file into one flat database array.

    Iterates over all categories in the source db-info pickle, loads each
    object's points from its binary file, and stacks them into a single
    array saved as ``<new_db_name>.npy``. A deep copy of the db infos is
    augmented with a ``global_data_offset`` (start, end) row range per
    object, so each object's points can later be sliced back out of the
    merged array, and is saved next to the source info file with a
    ``_global`` suffix.

    Args:
        args: parsed CLI namespace; uses ``src_db_info``, ``new_db_name``
            and ``num_point_features``.
        root_path: ``Path`` that ``src_db_info`` and each object's
            ``info['path']`` are resolved against.

    Returns:
        Tuple of (augmented db-info dict, merged float32 array of shape
        (total_points, num_point_features)).
    """
    # prepare output locations
    db_infos_path = root_path / args.src_db_info
    # NOTE(review): assumes the info filename ends in '.pkl' (last 4 chars
    # stripped) — confirm against caller conventions.
    db_info_global_path = str(db_infos_path)[:-4] + '_global' + '.pkl'
    global_db_path = root_path / (args.new_db_name + '.npy')

    # Use a context manager so the pickle file handle is closed promptly
    # (the original open() call leaked the handle).
    with open(db_infos_path, 'rb') as f:
        db_infos = pkl.load(f)
    db_info_global = copy.deepcopy(db_infos)
    start_idx = 0
    global_db_list = []

    for category, class_info in db_infos.items():
        print('>>> Start processing %s' % category)
        for idx, info in tqdm.tqdm(enumerate(class_info), total=len(class_info)):
            obj_path = root_path / info['path']
            obj_points = np.fromfile(str(obj_path), dtype=np.float32).reshape(
                [-1, args.num_point_features])
            num_points = obj_points.shape[0]
            # Record the [start, end) row range of this object inside the
            # merged array so it can be sliced back out later.
            db_info_global[category][idx]['global_data_offset'] = (start_idx, start_idx + num_points)
            start_idx += num_points
            global_db_list.append(obj_points)

    # Guard the empty case: np.concatenate raises ValueError on an empty
    # sequence, so fall back to an empty (0, C) array instead of crashing.
    if global_db_list:
        global_db = np.concatenate(global_db_list)
    else:
        global_db = np.zeros((0, args.num_point_features), dtype=np.float32)

    with open(global_db_path, 'wb') as f:
        np.save(f, global_db)

    with open(db_info_global_path, 'wb') as f:
        pkl.dump(db_info_global, f)

    print(f"Successfully create integrated database at {global_db_path}")
    print(f"Successfully create integrated database info at {db_info_global_path}")

    return db_info_global, global_db


def verify(info, whole_db, root_path, num_point_features):
    """Check one object's slice of the merged DB against its source file.

    Loads the object's original points from disk and compares them with
    the rows recorded at ``info['global_data_offset']`` inside
    ``whole_db``.

    Args:
        info: single db-info dict containing 'path' (binary point file,
            relative to ``root_path``) and 'global_data_offset'
            (start, end) row range.
        whole_db: merged (N, num_point_features) array produced by
            ``create_integrated_db_with_infos``.
        root_path: ``Path`` that ``info['path']`` is resolved against.
        num_point_features: number of feature channels per point.

    Raises:
        AssertionError: if the slice does not match the original points.
    """
    obj_path = root_path / info['path']
    obj_points = np.fromfile(str(obj_path), dtype=np.float32).reshape([-1, num_point_features])

    start_idx, end_idx = info['global_data_offset']
    obj_points_new = whole_db[start_idx:end_idx]

    # Compare the full arrays element-wise instead of only their means:
    # two different point sets can share a mean, and exact float equality
    # on a reduced statistic is fragile. The merged rows are byte-for-byte
    # copies of the source file, so exact equality is the right check.
    assert np.array_equal(obj_points, obj_points_new)

    print("Verification pass!")


if __name__ == '__main__':
    import argparse

    # CLI: source db-info pickle, output database name, and the point
    # feature width, all resolved relative to --root_path.
    arg_parser = argparse.ArgumentParser(description='arg parser')
    arg_parser.add_argument('--root_path', type=str, default=None, help='specify the root path')
    arg_parser.add_argument('--src_db_info', type=str,
                            default='waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1.pkl', help='')
    arg_parser.add_argument('--new_db_name', type=str,
                            default='waymo_processed_data_v0_5_0_gt_database_train_sampled_1_global', help='')
    arg_parser.add_argument('--num_point_features', type=int, default=5,
                            help='number of feature channels for points')
    arg_parser.add_argument('--class_name', type=str, default='Vehicle',
                            help='category name for verification')
    cli_args = arg_parser.parse_args()

    data_root = Path(cli_args.root_path)

    # Build the merged database, then spot-check the first object of the
    # chosen class against its original point file.
    db_infos_global, whole_db = create_integrated_db_with_infos(cli_args, data_root)
    verify(db_infos_global[cli_args.class_name][0], whole_db, data_root, cli_args.num_point_features)