Commit 6bd8be71 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

use p.map instead of p.imap for generating WOD gt_database

parent 38474d6c
...@@ -488,8 +488,10 @@ class WaymoDataset(DatasetTemplate): ...@@ -488,8 +488,10 @@ class WaymoDataset(DatasetTemplate):
stacked_gt_points = np.concatenate(stacked_gt_points, axis=0) stacked_gt_points = np.concatenate(stacked_gt_points, axis=0)
np.save(db_data_save_path, stacked_gt_points) np.save(db_data_save_path, stacked_gt_points)
def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False, used_classes=None, use_cuda=False): def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False, used_classes=None, total_samples=0, use_cuda=False):
info, info_idx = info_with_idx info, info_idx = info_with_idx
print('gt_database sample: %d/%d' % (info_idx, total_samples))
all_db_infos = {} all_db_infos = {}
pc_info = info['point_cloud'] pc_info = info['point_cloud']
sequence_name = pc_info['lidar_sequence'] sequence_name = pc_info['lidar_sequence']
...@@ -577,14 +579,15 @@ class WaymoDataset(DatasetTemplate): ...@@ -577,14 +579,15 @@ class WaymoDataset(DatasetTemplate):
with open(info_path, 'rb') as f: with open(info_path, 'rb') as f:
infos = pickle.load(f) infos = pickle.load(f)
print(f'Number workers: {num_workers}')
create_gt_database_of_single_scene = partial( create_gt_database_of_single_scene = partial(
self.create_gt_database_of_single_scene, self.create_gt_database_of_single_scene,
use_sequence_data=use_sequence_data, database_save_path=database_save_path, use_sequence_data=use_sequence_data, database_save_path=database_save_path,
used_classes=used_classes, use_cuda=True used_classes=used_classes, total_samples=len(infos), use_cuda=False
) )
# create_gt_database_of_single_scene((infos[0], 0)) # create_gt_database_of_single_scene((infos[0], 0))
with multiprocessing.Pool(num_workers) as p: with multiprocessing.Pool(num_workers) as p:
all_db_infos_list = list(tqdm(p.imap(create_gt_database_of_single_scene, zip(infos, np.arange(len(infos)))), total=len(infos))) all_db_infos_list = list(p.map(create_gt_database_of_single_scene, zip(infos, np.arange(len(infos)))))
all_db_infos = {} all_db_infos = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment