Commit 8bf333f9 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

Support generating the WOD gt_database with a multi-frame tail

parent 60e019e3
......@@ -226,7 +226,7 @@ class WaymoDataset(DatasetTemplate):
points_pre_all.append(points_pre)
num_points_pre.append(points_pre.shape[0])
points = np.concatenate([points] + points_pre_all, axis=0).astype(np.float32)
num_points_all = np.array([num_pts_cur] + num_points_pre).astype(np.int)
num_points_all = np.array([num_pts_cur] + num_points_pre).astype(np.int32)
return points, num_points_all, sample_idx_pre_list
def __len__(self):
......@@ -490,7 +490,8 @@ class WaymoDataset(DatasetTemplate):
stacked_gt_points = np.concatenate(stacked_gt_points, axis=0)
np.save(db_data_save_path, stacked_gt_points)
def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False, used_classes=None, total_samples=0, use_cuda=False):
def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False, used_classes=None,
total_samples=0, use_cuda=False, crop_gt_with_tail=False):
info, info_idx = info_with_idx
print('gt_database sample: %d/%d' % (info_idx, total_samples))
......@@ -525,16 +526,35 @@ class WaymoDataset(DatasetTemplate):
num_obj = gt_boxes.shape[0]
if num_obj == 0:
return {}
if use_sequence_data and crop_gt_with_tail:
assert gt_boxes.shape[1] == 9
speed = gt_boxes[:, 7:9]
sequence_cfg = self.dataset_cfg.SEQUENCE_CONFIG
assert sequence_cfg.SAMPLE_OFFSET[1] == 0
assert sequence_cfg.SAMPLE_OFFSET[0] < 0
num_frames = sequence_cfg.SAMPLE_OFFSET[1] - sequence_cfg.SAMPLE_OFFSET[0] + 1
assert num_frames > 1
latest_center = gt_boxes[:, 0:2]
oldest_center = latest_center - speed * (num_frames - 1) * 0.1
new_center = (latest_center + oldest_center) * 0.5
new_length = gt_boxes[:, 3] + np.linalg.norm(latest_center - oldest_center, axis=-1)
gt_boxes_crop = gt_boxes.copy()
gt_boxes_crop[:, 0:2] = new_center
gt_boxes_crop[:, 3] = new_length
else:
gt_boxes_crop = gt_boxes
if use_cuda:
box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
torch.from_numpy(points[:, 0:3]).unsqueeze(dim=0).float().cuda(),
torch.from_numpy(gt_boxes[:, 0:7]).unsqueeze(dim=0).float().cuda()
torch.from_numpy(gt_boxes_crop[:, 0:7]).unsqueeze(dim=0).float().cuda()
).long().squeeze(dim=0).cpu().numpy()
else:
box_point_mask = roiaware_pool3d_utils.points_in_boxes_cpu(
torch.from_numpy(points[:, 0:3]).float(),
torch.from_numpy(gt_boxes[:, 0:7]).float()
torch.from_numpy(gt_boxes_crop[:, 0:7]).float()
).long().numpy() # (num_boxes, num_points)
for i in range(num_obj):
......@@ -556,7 +576,8 @@ class WaymoDataset(DatasetTemplate):
db_path = str(filepath.relative_to(self.root_path)) # gt_database/xxxxx.bin
db_info = {'name': names[i], 'path': db_path, 'sequence_name': sequence_name,
'sample_idx': sample_idx, 'gt_idx': i, 'box3d_lidar': gt_boxes[i],
'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i]}
'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i],
'box3d_crop': gt_boxes_crop[i]}
if names[i] in all_db_infos:
all_db_infos[names[i]].append(db_info)
......@@ -564,15 +585,16 @@ class WaymoDataset(DatasetTemplate):
all_db_infos[names[i]] = [db_info]
return all_db_infos
def create_groundtruth_database_parallel(self, info_path, save_path, used_classes=None, split='train', sampled_interval=10, processed_data_tag=None, num_workers=16):
def create_groundtruth_database_parallel(self, info_path, save_path, used_classes=None, split='train', sampled_interval=10,
processed_data_tag=None, num_workers=16, crop_gt_with_tail=False):
use_sequence_data = self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED
if use_sequence_data:
st_frame, ed_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0], self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[1]
st_frame = min(-4, st_frame) # at least we use 5 frames for generating gt database to support various sequence configs (<= 5 frames)
database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_parallel' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_multiframe_%s_to_%s_parallel.pkl' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
db_data_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_global_parallel.npy' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0] = min(-4, st_frame) # at least we use 5 frames for generating gt database to support various sequence configs (<= 5 frames)
st_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0]
database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_%sparallel' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame, 'tail_' if crop_gt_with_tail else ''))
db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_multiframe_%s_to_%s_%sparallel.pkl' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame, 'tail_' if crop_gt_with_tail else ''))
db_data_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_%sglobal_parallel.npy' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame, 'tail_' if crop_gt_with_tail else ''))
else:
database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_parallel' % (processed_data_tag, split, sampled_interval))
db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_parallel.pkl' % (processed_data_tag, split, sampled_interval))
......@@ -587,9 +609,10 @@ class WaymoDataset(DatasetTemplate):
create_gt_database_of_single_scene = partial(
self.create_gt_database_of_single_scene,
use_sequence_data=use_sequence_data, database_save_path=database_save_path,
used_classes=used_classes, total_samples=len(infos), use_cuda=False
used_classes=used_classes, total_samples=len(infos), use_cuda=False,
crop_gt_with_tail=crop_gt_with_tail
)
# create_gt_database_of_single_scene((infos[0], 0))
# create_gt_database_of_single_scene((infos[300], 0))
with multiprocessing.Pool(num_workers) as p:
all_db_infos_list = list(p.map(create_gt_database_of_single_scene, zip(infos, np.arange(len(infos)))))
......@@ -658,7 +681,7 @@ def create_waymo_infos(dataset_cfg, class_names, data_path, save_path,
def create_waymo_gt_database(
dataset_cfg, class_names, data_path, save_path, processed_data_tag='waymo_processed_data',
workers=min(16, multiprocessing.cpu_count()), use_parallel=False):
workers=min(16, multiprocessing.cpu_count()), use_parallel=False, crop_gt_with_tail=False):
dataset = WaymoDataset(
dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path,
training=False, logger=common_utils.create_logger()
......@@ -673,7 +696,7 @@ def create_waymo_gt_database(
dataset.create_groundtruth_database_parallel(
info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag,
num_workers=workers
num_workers=workers, crop_gt_with_tail=crop_gt_with_tail
)
else:
dataset.create_groundtruth_database(
......@@ -694,6 +717,7 @@ if __name__ == '__main__':
parser.add_argument('--processed_data_tag', type=str, default='waymo_processed_data_v0_5_0', help='')
parser.add_argument('--update_info_only', action='store_true', default=False, help='')
parser.add_argument('--use_parallel', action='store_true', default=False, help='')
parser.add_argument('--crop_gt_with_tail', action='store_true', default=False, help='')
args = parser.parse_args()
......@@ -728,7 +752,8 @@ if __name__ == '__main__':
data_path=ROOT_DIR / 'data' / 'waymo',
save_path=ROOT_DIR / 'data' / 'waymo',
processed_data_tag=args.processed_data_tag,
use_parallel=args.use_parallel
use_parallel=args.use_parallel,
crop_gt_with_tail=args.crop_gt_with_tail
)
else:
raise NotImplementedError
......@@ -3,6 +3,7 @@ import pickle as pkl
from pathlib import Path
import tqdm
import copy
import os
def create_integrated_db_with_infos(args, root_path):
......@@ -13,8 +14,8 @@ def create_integrated_db_with_infos(args, root_path):
"""
# prepare
db_infos_path = root_path / args.src_db_info
db_info_global_path = str(db_infos_path)[:-4] + '_global' + '.pkl'
db_infos_path = args.src_db_info
db_info_global_path = db_infos_path
global_db_path = root_path / (args.new_db_name + '.npy')
db_infos = pkl.load(open(db_infos_path, 'rb'))
......@@ -71,17 +72,14 @@ if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='arg parser')
parser.add_argument('--root_path', type=str, default=None, help='specify the root path')
parser.add_argument('--src_db_info', type=str, default='waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1.pkl', help='')
parser.add_argument('--new_db_name', type=str, default='waymo_processed_data_v0_5_0_gt_database_train_sampled_1_global', help='')
parser.add_argument('--num_point_features', type=int, default=5,
help='number of feature channels for points')
parser.add_argument('--class_name', type=str, default='Vehicle',
help='category name for verification')
parser.add_argument('--src_db_info', type=str, default='../../data/waymo/waymo_processed_data_v0_5_0_waymo_dbinfos_train_sampled_1_multiframe_-4_to_0_tail_parallel.pkl', help='')
parser.add_argument('--new_db_name', type=str, default='waymo_processed_data_v0_5_0_gt_database_train_sampled_1_multiframe_-4_to_0_tail_parallel_global', help='')
parser.add_argument('--num_point_features', type=int, default=6, help='number of feature channels for points')
parser.add_argument('--class_name', type=str, default='Vehicle', help='category name for verification')
args = parser.parse_args()
root_path = Path(args.root_path)
root_path = Path(os.path.dirname(args.src_db_info))
db_infos_global, whole_db = create_integrated_db_with_infos(args, root_path)
# simple verify
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment