Commit 8bf333f9 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support to generate WOD gt_database with multi-frame tail

parent 60e019e3
...@@ -226,7 +226,7 @@ class WaymoDataset(DatasetTemplate): ...@@ -226,7 +226,7 @@ class WaymoDataset(DatasetTemplate):
points_pre_all.append(points_pre) points_pre_all.append(points_pre)
num_points_pre.append(points_pre.shape[0]) num_points_pre.append(points_pre.shape[0])
points = np.concatenate([points] + points_pre_all, axis=0).astype(np.float32) points = np.concatenate([points] + points_pre_all, axis=0).astype(np.float32)
num_points_all = np.array([num_pts_cur] + num_points_pre).astype(np.int) num_points_all = np.array([num_pts_cur] + num_points_pre).astype(np.int32)
return points, num_points_all, sample_idx_pre_list return points, num_points_all, sample_idx_pre_list
def __len__(self): def __len__(self):
...@@ -490,7 +490,8 @@ class WaymoDataset(DatasetTemplate): ...@@ -490,7 +490,8 @@ class WaymoDataset(DatasetTemplate):
stacked_gt_points = np.concatenate(stacked_gt_points, axis=0) stacked_gt_points = np.concatenate(stacked_gt_points, axis=0)
np.save(db_data_save_path, stacked_gt_points) np.save(db_data_save_path, stacked_gt_points)
def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False, used_classes=None, total_samples=0, use_cuda=False): def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False, used_classes=None,
total_samples=0, use_cuda=False, crop_gt_with_tail=False):
info, info_idx = info_with_idx info, info_idx = info_with_idx
print('gt_database sample: %d/%d' % (info_idx, total_samples)) print('gt_database sample: %d/%d' % (info_idx, total_samples))
...@@ -526,15 +527,34 @@ class WaymoDataset(DatasetTemplate): ...@@ -526,15 +527,34 @@ class WaymoDataset(DatasetTemplate):
if num_obj == 0: if num_obj == 0:
return {} return {}
if use_sequence_data and crop_gt_with_tail:
assert gt_boxes.shape[1] == 9
speed = gt_boxes[:, 7:9]
sequence_cfg = self.dataset_cfg.SEQUENCE_CONFIG
assert sequence_cfg.SAMPLE_OFFSET[1] == 0
assert sequence_cfg.SAMPLE_OFFSET[0] < 0
num_frames = sequence_cfg.SAMPLE_OFFSET[1] - sequence_cfg.SAMPLE_OFFSET[0] + 1
assert num_frames > 1
latest_center = gt_boxes[:, 0:2]
oldest_center = latest_center - speed * (num_frames - 1) * 0.1
new_center = (latest_center + oldest_center) * 0.5
new_length = gt_boxes[:, 3] + np.linalg.norm(latest_center - oldest_center, axis=-1)
gt_boxes_crop = gt_boxes.copy()
gt_boxes_crop[:, 0:2] = new_center
gt_boxes_crop[:, 3] = new_length
else:
gt_boxes_crop = gt_boxes
if use_cuda: if use_cuda:
box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu( box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
torch.from_numpy(points[:, 0:3]).unsqueeze(dim=0).float().cuda(), torch.from_numpy(points[:, 0:3]).unsqueeze(dim=0).float().cuda(),
torch.from_numpy(gt_boxes[:, 0:7]).unsqueeze(dim=0).float().cuda() torch.from_numpy(gt_boxes_crop[:, 0:7]).unsqueeze(dim=0).float().cuda()
).long().squeeze(dim=0).cpu().numpy() ).long().squeeze(dim=0).cpu().numpy()
else: else:
box_point_mask = roiaware_pool3d_utils.points_in_boxes_cpu( box_point_mask = roiaware_pool3d_utils.points_in_boxes_cpu(
torch.from_numpy(points[:, 0:3]).float(), torch.from_numpy(points[:, 0:3]).float(),
torch.from_numpy(gt_boxes[:, 0:7]).float() torch.from_numpy(gt_boxes_crop[:, 0:7]).float()
).long().numpy() # (num_boxes, num_points) ).long().numpy() # (num_boxes, num_points)
for i in range(num_obj): for i in range(num_obj):
...@@ -556,7 +576,8 @@ class WaymoDataset(DatasetTemplate): ...@@ -556,7 +576,8 @@ class WaymoDataset(DatasetTemplate):
db_path = str(filepath.relative_to(self.root_path)) # gt_database/xxxxx.bin db_path = str(filepath.relative_to(self.root_path)) # gt_database/xxxxx.bin
db_info = {'name': names[i], 'path': db_path, 'sequence_name': sequence_name, db_info = {'name': names[i], 'path': db_path, 'sequence_name': sequence_name,
'sample_idx': sample_idx, 'gt_idx': i, 'box3d_lidar': gt_boxes[i], 'sample_idx': sample_idx, 'gt_idx': i, 'box3d_lidar': gt_boxes[i],
'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i]} 'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i],
'box3d_crop': gt_boxes_crop[i]}
if names[i] in all_db_infos: if names[i] in all_db_infos:
all_db_infos[names[i]].append(db_info) all_db_infos[names[i]].append(db_info)
...@@ -564,15 +585,16 @@ class WaymoDataset(DatasetTemplate): ...@@ -564,15 +585,16 @@ class WaymoDataset(DatasetTemplate):
all_db_infos[names[i]] = [db_info] all_db_infos[names[i]] = [db_info]
return all_db_infos return all_db_infos
def create_groundtruth_database_parallel(self, info_path, save_path, used_classes=None, split='train', sampled_interval=10, processed_data_tag=None, num_workers=16): def create_groundtruth_database_parallel(self, info_path, save_path, used_classes=None, split='train', sampled_interval=10,
processed_data_tag=None, num_workers=16, crop_gt_with_tail=False):
use_sequence_data = self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED use_sequence_data = self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED
if use_sequence_data: if use_sequence_data:
st_frame, ed_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0], self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[1] st_frame, ed_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0], self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[1]
st_frame = min(-4, st_frame) # at least we use 5 frames for generating gt database to support various sequence configs (<= 5 frames) self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0] = min(-4, st_frame) # at least we use 5 frames for generating gt database to support various sequence configs (<= 5 frames)
database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_parallel' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame)) st_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0]
db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_multiframe_%s_to_%s_parallel.pkl' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame)) database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_%sparallel' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame, 'tail_' if crop_gt_with_tail else ''))
db_data_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_global_parallel.npy' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame)) db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_multiframe_%s_to_%s_%sparallel.pkl' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame, 'tail_' if crop_gt_with_tail else ''))
db_data_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_%sglobal_parallel.npy' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame, 'tail_' if crop_gt_with_tail else ''))
else: else:
database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_parallel' % (processed_data_tag, split, sampled_interval)) database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_parallel' % (processed_data_tag, split, sampled_interval))
db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_parallel.pkl' % (processed_data_tag, split, sampled_interval)) db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_parallel.pkl' % (processed_data_tag, split, sampled_interval))
...@@ -587,9 +609,10 @@ class WaymoDataset(DatasetTemplate): ...@@ -587,9 +609,10 @@ class WaymoDataset(DatasetTemplate):
create_gt_database_of_single_scene = partial( create_gt_database_of_single_scene = partial(
self.create_gt_database_of_single_scene, self.create_gt_database_of_single_scene,
use_sequence_data=use_sequence_data, database_save_path=database_save_path, use_sequence_data=use_sequence_data, database_save_path=database_save_path,
used_classes=used_classes, total_samples=len(infos), use_cuda=False used_classes=used_classes, total_samples=len(infos), use_cuda=False,
crop_gt_with_tail=crop_gt_with_tail
) )
# create_gt_database_of_single_scene((infos[0], 0)) # create_gt_database_of_single_scene((infos[300], 0))
with multiprocessing.Pool(num_workers) as p: with multiprocessing.Pool(num_workers) as p:
all_db_infos_list = list(p.map(create_gt_database_of_single_scene, zip(infos, np.arange(len(infos))))) all_db_infos_list = list(p.map(create_gt_database_of_single_scene, zip(infos, np.arange(len(infos)))))
...@@ -658,7 +681,7 @@ def create_waymo_infos(dataset_cfg, class_names, data_path, save_path, ...@@ -658,7 +681,7 @@ def create_waymo_infos(dataset_cfg, class_names, data_path, save_path,
def create_waymo_gt_database( def create_waymo_gt_database(
dataset_cfg, class_names, data_path, save_path, processed_data_tag='waymo_processed_data', dataset_cfg, class_names, data_path, save_path, processed_data_tag='waymo_processed_data',
workers=min(16, multiprocessing.cpu_count()), use_parallel=False): workers=min(16, multiprocessing.cpu_count()), use_parallel=False, crop_gt_with_tail=False):
dataset = WaymoDataset( dataset = WaymoDataset(
dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path, dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path,
training=False, logger=common_utils.create_logger() training=False, logger=common_utils.create_logger()
...@@ -673,7 +696,7 @@ def create_waymo_gt_database( ...@@ -673,7 +696,7 @@ def create_waymo_gt_database(
dataset.create_groundtruth_database_parallel( dataset.create_groundtruth_database_parallel(
info_path=train_filename, save_path=save_path, split='train', sampled_interval=1, info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag, used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag,
num_workers=workers num_workers=workers, crop_gt_with_tail=crop_gt_with_tail
) )
else: else:
dataset.create_groundtruth_database( dataset.create_groundtruth_database(
...@@ -694,6 +717,7 @@ if __name__ == '__main__': ...@@ -694,6 +717,7 @@ if __name__ == '__main__':
parser.add_argument('--processed_data_tag', type=str, default='waymo_processed_data_v0_5_0', help='') parser.add_argument('--processed_data_tag', type=str, default='waymo_processed_data_v0_5_0', help='')
parser.add_argument('--update_info_only', action='store_true', default=False, help='') parser.add_argument('--update_info_only', action='store_true', default=False, help='')
parser.add_argument('--use_parallel', action='store_true', default=False, help='') parser.add_argument('--use_parallel', action='store_true', default=False, help='')
parser.add_argument('--crop_gt_with_tail', action='store_true', default=False, help='')
args = parser.parse_args() args = parser.parse_args()
...@@ -728,7 +752,8 @@ if __name__ == '__main__': ...@@ -728,7 +752,8 @@ if __name__ == '__main__':
data_path=ROOT_DIR / 'data' / 'waymo', data_path=ROOT_DIR / 'data' / 'waymo',
save_path=ROOT_DIR / 'data' / 'waymo', save_path=ROOT_DIR / 'data' / 'waymo',
processed_data_tag=args.processed_data_tag, processed_data_tag=args.processed_data_tag,
use_parallel=args.use_parallel use_parallel=args.use_parallel,
crop_gt_with_tail=args.crop_gt_with_tail
) )
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -3,6 +3,7 @@ import pickle as pkl ...@@ -3,6 +3,7 @@ import pickle as pkl
from pathlib import Path from pathlib import Path
import tqdm import tqdm
import copy import copy
import os
def create_integrated_db_with_infos(args, root_path): def create_integrated_db_with_infos(args, root_path):
...@@ -13,8 +14,8 @@ def create_integrated_db_with_infos(args, root_path): ...@@ -13,8 +14,8 @@ def create_integrated_db_with_infos(args, root_path):
""" """
# prepare # prepare
db_infos_path = root_path / args.src_db_info db_infos_path = args.src_db_info
db_info_global_path = str(db_infos_path)[:-4] + '_global' + '.pkl' db_info_global_path = db_infos_path
global_db_path = root_path / (args.new_db_name + '.npy') global_db_path = root_path / (args.new_db_name + '.npy')
db_infos = pkl.load(open(db_infos_path, 'rb')) db_infos = pkl.load(open(db_infos_path, 'rb'))
...@@ -71,17 +72,14 @@ if __name__ == '__main__': ...@@ -71,17 +72,14 @@ if __name__ == '__main__':
import argparse import argparse
parser = argparse.ArgumentParser(description='arg parser') parser = argparse.ArgumentParser(description='arg parser')
parser.add_argument('--root_path', type=str, default=None, help='specify the root path') parser.add_argument('--src_db_info', type=str, default='../../data/waymo/waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1_multiframe_-4_to_0_tail_parallel.pkl', help='')
parser.add_argument('--src_db_info', type=str, default='waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1.pkl', help='') parser.add_argument('--new_db_name', type=str, default='waymo_processed_data_v0_5_0_gt_database_train_sampled_1_multiframe_-4_to_0_tail_parallel_global', help='')
parser.add_argument('--new_db_name', type=str, default='waymo_processed_data_v0_5_0_gt_database_train_sampled_1_global', help='') parser.add_argument('--num_point_features', type=int, default=6, help='number of feature channels for points')
parser.add_argument('--num_point_features', type=int, default=5, parser.add_argument('--class_name', type=str, default='Vehicle', help='category name for verification')
help='number of feature channels for points')
parser.add_argument('--class_name', type=str, default='Vehicle',
help='category name for verification')
args = parser.parse_args() args = parser.parse_args()
root_path = Path(args.root_path) root_path = Path(os.path.dirname(args.src_db_info))
db_infos_global, whole_db = create_integrated_db_with_infos(args, root_path) db_infos_global, whole_db = create_integrated_db_with_infos(args, root_path)
# simple verify # simple verify
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment