Unverified Commit dd53c2ac authored by Xiang Xu's avatar Xiang Xu Committed by GitHub
Browse files

[Feature] Update test set to v2 version (#2285)

* support test set

* update waymo version
parent 216c1642
......@@ -35,9 +35,11 @@ def kitti_data_prep(root_path,
info_train_path = osp.join(out_dir, f'{info_prefix}_infos_train.pkl')
info_val_path = osp.join(out_dir, f'{info_prefix}_infos_val.pkl')
info_trainval_path = osp.join(out_dir, f'{info_prefix}_infos_trainval.pkl')
info_test_path = osp.join(out_dir, f'{info_prefix}_infos_test.pkl')
update_pkl_infos('kitti', out_dir=out_dir, pkl_path=info_train_path)
update_pkl_infos('kitti', out_dir=out_dir, pkl_path=info_val_path)
update_pkl_infos('kitti', out_dir=out_dir, pkl_path=info_trainval_path)
update_pkl_infos('kitti', out_dir=out_dir, pkl_path=info_test_path)
create_groundtruth_database(
'KittiDataset',
root_path,
......@@ -122,11 +124,11 @@ def scannet_data_prep(root_path, info_prefix, out_dir, workers):
indoor.create_indoor_info_file(
root_path, info_prefix, out_dir, workers=workers)
info_train_path = osp.join(out_dir, f'{info_prefix}_infos_train.pkl')
info_test_path = osp.join(out_dir, f'{info_prefix}_infos_test.pkl')
info_val_path = osp.join(out_dir, f'{info_prefix}_infos_val.pkl')
info_test_path = osp.join(out_dir, f'{info_prefix}_infos_test.pkl')
update_pkl_infos('scannet', out_dir=out_dir, pkl_path=info_train_path)
update_pkl_infos('scannet', out_dir=out_dir, pkl_path=info_test_path)
update_pkl_infos('scannet', out_dir=out_dir, pkl_path=info_val_path)
update_pkl_infos('scannet', out_dir=out_dir, pkl_path=info_test_path)
def s3dis_data_prep(root_path, info_prefix, out_dir, workers):
......@@ -210,11 +212,11 @@ def waymo_data_prep(root_path,
info_train_path = osp.join(out_dir, f'{info_prefix}_infos_train.pkl')
info_val_path = osp.join(out_dir, f'{info_prefix}_infos_val.pkl')
info_trainval_path = osp.join(out_dir, f'{info_prefix}_infos_trainval.pkl')
test_path = osp.join(out_dir, f'{info_prefix}_infos_test.pkl')
info_test_path = osp.join(out_dir, f'{info_prefix}_infos_test.pkl')
update_pkl_infos('waymo', out_dir=out_dir, pkl_path=info_train_path)
update_pkl_infos('waymo', out_dir=out_dir, pkl_path=info_val_path)
update_pkl_infos('waymo', out_dir=out_dir, pkl_path=info_trainval_path)
update_pkl_infos('waymo', out_dir=out_dir, pkl_path=test_path)
update_pkl_infos('waymo', out_dir=out_dir, pkl_path=info_test_path)
GTDatabaseCreater(
'WaymoDataset',
out_dir,
......
......@@ -337,8 +337,9 @@ def update_nuscenes_infos(pkl_path, out_dir):
empty_img_info['lidar2cam'] = lidar2sensor.astype(
np.float32).tolist()
temp_data_info['images'][cam] = empty_img_info
num_instances = ori_info_dict['gt_boxes'].shape[0]
ignore_class_name = set()
if 'gt_boxes' in ori_info_dict:
num_instances = ori_info_dict['gt_boxes'].shape[0]
for i in range(num_instances):
empty_instance = get_empty_instance()
empty_instance['bbox_3d'] = ori_info_dict['gt_boxes'][
......@@ -353,12 +354,16 @@ def update_nuscenes_infos(pkl_path, out_dir):
empty_instance['bbox_label'])
empty_instance['velocity'] = ori_info_dict['gt_velocity'][
i, :].tolist()
empty_instance['num_lidar_pts'] = ori_info_dict['num_lidar_pts'][i]
empty_instance['num_radar_pts'] = ori_info_dict['num_radar_pts'][i]
empty_instance['bbox_3d_isvalid'] = ori_info_dict['valid_flag'][i]
empty_instance['num_lidar_pts'] = ori_info_dict[
'num_lidar_pts'][i]
empty_instance['num_radar_pts'] = ori_info_dict[
'num_radar_pts'][i]
empty_instance['bbox_3d_isvalid'] = ori_info_dict[
'valid_flag'][i]
empty_instance = clear_instance_unused_keys(empty_instance)
temp_data_info['instances'].append(empty_instance)
temp_data_info['cam_instances'] = generate_nuscenes_camera_instances(
temp_data_info[
'cam_instances'] = generate_nuscenes_camera_instances(
ori_info_dict, nusc)
temp_data_info, _ = clear_data_info_unused_keys(temp_data_info)
converted_list.append(temp_data_info)
......@@ -444,11 +449,12 @@ def update_kitti_infos(pkl_path, out_dir):
temp_data_info['lidar_points']['Tr_imu_to_velo'] = ori_info_dict[
'calib']['Tr_imu_to_velo'].astype(np.float32).tolist()
anns = ori_info_dict['annos']
num_instances = len(anns['name'])
cam2img = ori_info_dict['calib']['P2']
anns = ori_info_dict.get('annos', None)
ignore_class_name = set()
if anns is not None:
num_instances = len(anns['name'])
instance_list = []
for instance_id in range(num_instances):
empty_instance = get_empty_instance()
......@@ -484,12 +490,13 @@ def update_kitti_infos(pkl_path, out_dir):
empty_instance['bbox'] = anns['bbox'][instance_id].tolist()
empty_instance['truncated'] = anns['truncated'][
instance_id].tolist()
empty_instance['occluded'] = anns['occluded'][instance_id].tolist()
empty_instance['occluded'] = anns['occluded'][
instance_id].tolist()
empty_instance['alpha'] = anns['alpha'][instance_id].tolist()
empty_instance['score'] = anns['score'][instance_id].tolist()
empty_instance['index'] = anns['index'][instance_id].tolist()
empty_instance['group_id'] = anns['group_ids'][instance_id].tolist(
)
empty_instance['group_id'] = anns['group_ids'][
instance_id].tolist()
empty_instance['difficulty'] = anns['difficulty'][
instance_id].tolist()
empty_instance['num_lidar_pts'] = anns['num_points_in_gt'][
......@@ -537,8 +544,10 @@ def update_s3dis_infos(pkl_path, out_dir):
'point_cloud']['num_features']
temp_data_info['lidar_points']['lidar_path'] = Path(
ori_info_dict['pts_path']).name
if 'pts_semantic_mask_path' in ori_info_dict:
temp_data_info['pts_semantic_mask_path'] = Path(
ori_info_dict['pts_semantic_mask_path']).name
if 'pts_instance_mask_path' in ori_info_dict:
temp_data_info['pts_instance_mask_path'] = Path(
ori_info_dict['pts_instance_mask_path']).name
......@@ -611,16 +620,19 @@ def update_scannet_infos(pkl_path, out_dir):
'point_cloud']['num_features']
temp_data_info['lidar_points']['lidar_path'] = Path(
ori_info_dict['pts_path']).name
if 'pts_semantic_mask_path' in ori_info_dict:
temp_data_info['pts_semantic_mask_path'] = Path(
ori_info_dict['pts_semantic_mask_path']).name
if 'pts_instance_mask_path' in ori_info_dict:
temp_data_info['pts_instance_mask_path'] = Path(
ori_info_dict['pts_instance_mask_path']).name
# TODO support camera
# np.linalg.inv(info['axis_align_matrix'] @ extrinsic): depth2cam
anns = ori_info_dict['annos']
temp_data_info['axis_align_matrix'] = anns['axis_align_matrix'].tolist(
)
anns = ori_info_dict.get('annos', None)
if anns is not None:
temp_data_info['axis_align_matrix'] = anns[
'axis_align_matrix'].tolist()
if anns['gt_num'] == 0:
instance_list = []
else:
......@@ -696,7 +708,8 @@ def update_sunrgbd_infos(pkl_path, out_dir):
temp_data_info['images']['CAM0']['height'] = h
temp_data_info['images']['CAM0']['width'] = w
anns = ori_info_dict['annos']
anns = ori_info_dict.get('annos', None)
if anns is not None:
if anns['gt_num'] == 0:
instance_list = []
else:
......@@ -818,8 +831,9 @@ def update_lyft_infos(pkl_path, out_dir):
empty_img_info['lidar2cam'] = lidar2sensor.astype(
np.float32).tolist()
temp_data_info['images'][cam] = empty_img_info
num_instances = ori_info_dict['gt_boxes'].shape[0]
ignore_class_name = set()
if 'gt_boxes' in ori_info_dict:
num_instances = ori_info_dict['gt_boxes'].shape[0]
for i in range(num_instances):
empty_instance = get_empty_instance()
empty_instance['bbox_3d'] = ori_info_dict['gt_boxes'][
......@@ -954,7 +968,7 @@ def update_waymo_infos(pkl_path, out_dir):
temp_data_info['lidar_sweeps'].append(lidar_sweep)
temp_data_info['image_sweeps'].append(image_sweep)
anns = ori_info_dict.get('annos')
anns = ori_info_dict.get('annos', None)
ignore_class_name = set()
if anns is not None:
num_instances = len(anns['name'])
......@@ -1001,7 +1015,7 @@ def update_waymo_infos(pkl_path, out_dir):
temp_data_info['instances'] = instance_list
# waymo provide the labels that sync with cam
anns = ori_info_dict.get('cam_sync_annos')
anns = ori_info_dict.get('cam_sync_annos', None)
ignore_class_name = set()
if anns is not None:
num_instances = len(anns['name'])
......@@ -1060,7 +1074,7 @@ def update_waymo_infos(pkl_path, out_dir):
for ignore_class in ignore_class_name:
metainfo['categories'][ignore_class] = -1
metainfo['dataset'] = 'waymo'
metainfo['version'] = '1.2'
metainfo['version'] = '1.4'
metainfo['info_version'] = '1.1'
converted_data_info = dict(metainfo=metainfo, data_list=converted_list)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment