Commit dff4b320 authored by liyinhao's avatar liyinhao
Browse files

Change paths

parent fba104af
...@@ -186,11 +186,10 @@ def create_groundtruth_database(dataset_class_name, ...@@ -186,11 +186,10 @@ def create_groundtruth_database(dataset_class_name,
dataset = build_dataset(dataset_cfg) dataset = build_dataset(dataset_cfg)
if database_save_path is None: if database_save_path is None:
database_save_path = osp.join(data_path, database_save_path = osp.join(data_path, f'{info_prefix}_gt_database')
'{}_gt_database'.format(info_prefix))
if db_info_save_path is None: if db_info_save_path is None:
db_info_save_path = osp.join( db_info_save_path = osp.join(data_path,
data_path, '{}_dbinfos_train.pkl'.format(info_prefix)) f'{info_prefix}_dbinfos_train.pkl')
mmcv.mkdir_or_exist(database_save_path) mmcv.mkdir_or_exist(database_save_path)
all_db_infos = dict() all_db_infos = dict()
...@@ -229,7 +228,7 @@ def create_groundtruth_database(dataset_class_name, ...@@ -229,7 +228,7 @@ def create_groundtruth_database(dataset_class_name,
gt_boxes = annos['gt_bboxes'] gt_boxes = annos['gt_bboxes']
img_path = osp.split(example['img_info']['filename'])[-1] img_path = osp.split(example['img_info']['filename'])[-1]
if img_path not in file2id.keys(): if img_path not in file2id.keys():
print('skip image {} for empty mask'.format(img_path)) print(f'skip image {img_path} for empty mask')
continue continue
img_id = file2id[img_path] img_id = file2id[img_path]
kins_annIds = coco.getAnnIds(imgIds=img_id) kins_annIds = coco.getAnnIds(imgIds=img_id)
...@@ -257,7 +256,8 @@ def create_groundtruth_database(dataset_class_name, ...@@ -257,7 +256,8 @@ def create_groundtruth_database(dataset_class_name,
for i in range(num_obj): for i in range(num_obj):
filename = f'{image_idx}_{names[i]}_{i}.bin' filename = f'{image_idx}_{names[i]}_{i}.bin'
filepath = osp.join(database_save_path, filename) abs_filepath = osp.join(database_save_path, filename)
rel_filepath = osp.join(f'{info_prefix}_gt_database', filename)
# save point clouds and image patches for each object # save point clouds and image patches for each object
gt_points = points[point_indices[:, i]] gt_points = points[point_indices[:, i]]
...@@ -267,18 +267,18 @@ def create_groundtruth_database(dataset_class_name, ...@@ -267,18 +267,18 @@ def create_groundtruth_database(dataset_class_name,
if object_masks[i].sum() == 0 or not valid_inds[i]: if object_masks[i].sum() == 0 or not valid_inds[i]:
# Skip object for empty or invalid mask # Skip object for empty or invalid mask
continue continue
img_patch_path = filepath + '.png' img_patch_path = abs_filepath + '.png'
mask_patch_path = filepath + '.mask.png' mask_patch_path = abs_filepath + '.mask.png'
mmcv.imwrite(object_img_patches[i], img_patch_path) mmcv.imwrite(object_img_patches[i], img_patch_path)
mmcv.imwrite(object_masks[i], mask_patch_path) mmcv.imwrite(object_masks[i], mask_patch_path)
with open(filepath, 'w') as f: with open(abs_filepath, 'w') as f:
gt_points.tofile(f) gt_points.tofile(f)
if (used_classes is None) or names[i] in used_classes: if (used_classes is None) or names[i] in used_classes:
db_info = { db_info = {
'name': names[i], 'name': names[i],
'path': osp.join(f'{info_prefix}_gt_database', filename), 'path': rel_filepath,
'image_idx': image_idx, 'image_idx': image_idx,
'gt_idx': i, 'gt_idx': i,
'box3d_lidar': gt_boxes_3d[i], 'box3d_lidar': gt_boxes_3d[i],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment