Commit 38474d6c authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

Support using multiple worker processes (CPU/CUDA) to generate the gt_database of WOD

parent 6a1f253a
......@@ -13,6 +13,8 @@ import SharedArray
import torch.distributed as dist
from tqdm import tqdm
from pathlib import Path
from functools import partial
from ...ops.roiaware_pool3d import roiaware_pool3d_utils
from ...utils import box_utils, common_utils
from ..dataset import DatasetTemplate
......@@ -140,7 +142,6 @@ class WaymoDataset(DatasetTemplate):
return sequence_file
def get_infos(self, raw_data_path, save_path, num_workers=multiprocessing.cpu_count(), has_label=True, sampled_interval=1, update_info_only=False):
from functools import partial
from . import waymo_utils
print('---------------The waymo sample interval is %d, total sequecnes is %d-----------------'
% (sampled_interval, len(self.sample_sequence_list)))
......@@ -413,7 +414,7 @@ class WaymoDataset(DatasetTemplate):
point_offset_cnt = 0
stacked_gt_points = []
for k in range(0, len(infos), sampled_interval):
for k in tqdm(range(0, len(infos), sampled_interval)):
print('gt_database sample: %d/%d' % (k + 1, len(infos)))
info = infos[k]
......@@ -487,6 +488,118 @@ class WaymoDataset(DatasetTemplate):
stacked_gt_points = np.concatenate(stacked_gt_points, axis=0)
np.save(db_data_save_path, stacked_gt_points)
def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False, used_classes=None, use_cuda=False):
    """Extract per-object GT point clouds for a single scene and dump each to a .bin file.

    Designed to be mapped over a worker pool, hence the single packed argument.

    Args:
        info_with_idx: tuple (info, info_idx) — the scene's info dict and its index in
            the full info list (the index drives the class subsampling below).
        database_save_path: Path to the directory receiving the per-object .bin files.
        use_sequence_data: if True, merge multi-frame points via self.get_sequence_data.
        used_classes: optional whitelist of class names to include in the database.
        use_cuda: use the GPU points-in-boxes kernel instead of the CPU one.

    Returns:
        dict: class name -> list of db_info dicts for this scene ({} if no objects).
    """
    info, info_idx = info_with_idx
    all_db_infos = {}

    pc_info = info['point_cloud']
    sequence_name = pc_info['lidar_sequence']
    sample_idx = pc_info['sample_idx']
    points = self.get_lidar(sequence_name, sample_idx)

    if use_sequence_data:
        points, num_points_all, sample_idx_pre_list = self.get_sequence_data(
            info, points, sequence_name, sample_idx, self.dataset_cfg.SEQUENCE_CONFIG
        )

    annos = info['annos']
    names = annos['name']
    difficulty = annos['difficulty']
    gt_boxes = annos['gt_boxes_lidar']

    # Subsample the dominant classes to limit database size: keep Vehicle boxes
    # only on every 4th frame and Pedestrian boxes only on every 2nd frame.
    if info_idx % 4 != 0 and len(names) > 0:
        mask = (names == 'Vehicle')
        names = names[~mask]
        difficulty = difficulty[~mask]
        gt_boxes = gt_boxes[~mask]

    if info_idx % 2 != 0 and len(names) > 0:
        mask = (names == 'Pedestrian')
        names = names[~mask]
        difficulty = difficulty[~mask]
        gt_boxes = gt_boxes[~mask]

    num_obj = gt_boxes.shape[0]
    if num_obj == 0:
        return {}

    if use_cuda:
        # (num_points,) index of the box containing each point (-1 when outside all boxes)
        box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
            torch.from_numpy(points[:, 0:3]).unsqueeze(dim=0).float().cuda(),
            torch.from_numpy(gt_boxes[:, 0:7]).unsqueeze(dim=0).float().cuda()
        ).long().squeeze(dim=0).cpu().numpy()
    else:
        box_point_mask = roiaware_pool3d_utils.points_in_boxes_cpu(
            torch.from_numpy(points[:, 0:3]).float(),
            torch.from_numpy(gt_boxes[:, 0:7]).float()
        ).long().numpy()  # (num_boxes, num_points)

    for i in range(num_obj):
        filename = '%s_%04d_%s_%d.bin' % (sequence_name, sample_idx, names[i], i)
        filepath = database_save_path / filename
        if use_cuda:
            gt_points = points[box_idxs_of_pts == i]
        else:
            gt_points = points[box_point_mask[i] > 0]

        # Shift the object's points into the box-local frame (relative to box center).
        gt_points[:, :3] -= gt_boxes[i, :3]

        if (used_classes is None) or names[i] in used_classes:
            # BUGFIX: open in binary mode — ndarray.tofile requires a binary handle;
            # text mode ('w') corrupts the dump on platforms that translate newlines.
            with open(filepath, 'wb') as f:
                gt_points.tofile(f)

            db_path = str(filepath.relative_to(self.root_path))  # gt_database/xxxxx.bin
            db_info = {'name': names[i], 'path': db_path, 'sequence_name': sequence_name,
                       'sample_idx': sample_idx, 'gt_idx': i, 'box3d_lidar': gt_boxes[i],
                       'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i]}

            if names[i] in all_db_infos:
                all_db_infos[names[i]].append(db_info)
            else:
                all_db_infos[names[i]] = [db_info]
    return all_db_infos
def create_groundtruth_database_parallel(self, info_path, save_path, used_classes=None, split='train', sampled_interval=10, processed_data_tag=None, num_workers=16):
    """Build the GT-sampling database using a multiprocessing pool (one task per scene).

    Writes per-object point clouds under `database_save_path` and a merged
    db-info pickle at `db_info_save_path`. Unlike the sequential
    create_groundtruth_database, this parallel variant does NOT produce the
    global stacked-points .npy file.

    Args:
        info_path: path to the pickled list of scene infos.
        save_path: root output directory (Path).
        used_classes: optional whitelist of class names to include.
        split: dataset split name used in output file names.
        sampled_interval: frame sampling interval encoded in output file names.
        processed_data_tag: tag prefix for the output file names.
        num_workers: number of pool workers.
    """
    use_sequence_data = self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED

    if use_sequence_data:
        st_frame, ed_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0], self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[1]
        # at least we use 5 frames for generating gt database to support various sequence configs (<= 5 frames)
        st_frame = min(-4, st_frame)
        database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_parallel' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
        db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_multiframe_%s_to_%s_parallel.pkl' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
    else:
        database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_parallel' % (processed_data_tag, split, sampled_interval))
        db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_parallel.pkl' % (processed_data_tag, split, sampled_interval))

    database_save_path.mkdir(parents=True, exist_ok=True)

    with open(info_path, 'rb') as f:
        infos = pickle.load(f)

    # NOTE(review): use_cuda=True inside forked pool workers typically fails with
    # "Cannot re-initialize CUDA in forked subprocess" unless the 'spawn' start
    # method is configured — confirm the intended start method before relying on it.
    create_gt_database_of_single_scene = partial(
        self.create_gt_database_of_single_scene,
        use_sequence_data=use_sequence_data, database_save_path=database_save_path,
        used_classes=used_classes, use_cuda=True
    )
    with multiprocessing.Pool(num_workers) as p:
        all_db_infos_list = list(tqdm(p.imap(create_gt_database_of_single_scene, zip(infos, np.arange(len(infos)))), total=len(infos)))

    # Merge the per-scene dicts into one class -> db_info-list mapping.
    all_db_infos = {}
    for cur_db_infos in all_db_infos_list:
        for key, val in cur_db_infos.items():
            if key not in all_db_infos:
                all_db_infos[key] = val
            else:
                all_db_infos[key].extend(val)

    for k, v in all_db_infos.items():
        print('Database %s: %d' % (k, len(v)))

    with open(db_info_save_path, 'wb') as f:
        pickle.dump(all_db_infos, f)
def create_waymo_infos(dataset_cfg, class_names, data_path, save_path,
raw_data_tag='raw_data', processed_data_tag='waymo_processed_data',
......@@ -538,7 +651,7 @@ def create_waymo_infos(dataset_cfg, class_names, data_path, save_path,
def create_waymo_gt_database(
dataset_cfg, class_names, data_path, save_path, processed_data_tag='waymo_processed_data',
workers=min(16, multiprocessing.cpu_count())):
workers=min(16, multiprocessing.cpu_count()), use_parallel=False):
dataset = WaymoDataset(
dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path,
training=False, logger=common_utils.create_logger()
......@@ -549,10 +662,17 @@ def create_waymo_gt_database(
print('---------------Start create groundtruth database for data augmentation---------------')
dataset.set_split(train_split)
dataset.create_groundtruth_database(
info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag
)
if use_parallel:
dataset.create_groundtruth_database_parallel(
info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag,
num_workers=workers
)
else:
dataset.create_groundtruth_database(
info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag
)
print('---------------Data preparation Done---------------')
......@@ -566,6 +686,7 @@ if __name__ == '__main__':
parser.add_argument('--func', type=str, default='create_waymo_infos', help='')
parser.add_argument('--processed_data_tag', type=str, default='waymo_processed_data_v0_5_0', help='')
parser.add_argument('--update_info_only', action='store_true', default=False, help='')
parser.add_argument('--use_parallel', action='store_true', default=False, help='')
args = parser.parse_args()
......@@ -599,7 +720,8 @@ if __name__ == '__main__':
class_names=['Vehicle', 'Pedestrian', 'Cyclist'],
data_path=ROOT_DIR / 'data' / 'waymo',
save_path=ROOT_DIR / 'data' / 'waymo',
processed_data_tag=args.processed_data_tag
processed_data_tag=args.processed_data_tag,
use_parallel=args.use_parallel
)
else:
raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment