Commit 38474d6c authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

Support using multiple workers (CPU/CUDA) to generate the gt_database of WOD

parent 6a1f253a
...@@ -13,6 +13,8 @@ import SharedArray ...@@ -13,6 +13,8 @@ import SharedArray
import torch.distributed as dist import torch.distributed as dist
from tqdm import tqdm from tqdm import tqdm
from pathlib import Path from pathlib import Path
from functools import partial
from ...ops.roiaware_pool3d import roiaware_pool3d_utils from ...ops.roiaware_pool3d import roiaware_pool3d_utils
from ...utils import box_utils, common_utils from ...utils import box_utils, common_utils
from ..dataset import DatasetTemplate from ..dataset import DatasetTemplate
...@@ -140,7 +142,6 @@ class WaymoDataset(DatasetTemplate): ...@@ -140,7 +142,6 @@ class WaymoDataset(DatasetTemplate):
return sequence_file return sequence_file
def get_infos(self, raw_data_path, save_path, num_workers=multiprocessing.cpu_count(), has_label=True, sampled_interval=1, update_info_only=False): def get_infos(self, raw_data_path, save_path, num_workers=multiprocessing.cpu_count(), has_label=True, sampled_interval=1, update_info_only=False):
from functools import partial
from . import waymo_utils from . import waymo_utils
print('---------------The waymo sample interval is %d, total sequecnes is %d-----------------' print('---------------The waymo sample interval is %d, total sequecnes is %d-----------------'
% (sampled_interval, len(self.sample_sequence_list))) % (sampled_interval, len(self.sample_sequence_list)))
...@@ -413,7 +414,7 @@ class WaymoDataset(DatasetTemplate): ...@@ -413,7 +414,7 @@ class WaymoDataset(DatasetTemplate):
point_offset_cnt = 0 point_offset_cnt = 0
stacked_gt_points = [] stacked_gt_points = []
for k in range(0, len(infos), sampled_interval): for k in tqdm(range(0, len(infos), sampled_interval)):
print('gt_database sample: %d/%d' % (k + 1, len(infos))) print('gt_database sample: %d/%d' % (k + 1, len(infos)))
info = infos[k] info = infos[k]
...@@ -487,6 +488,118 @@ class WaymoDataset(DatasetTemplate): ...@@ -487,6 +488,118 @@ class WaymoDataset(DatasetTemplate):
stacked_gt_points = np.concatenate(stacked_gt_points, axis=0) stacked_gt_points = np.concatenate(stacked_gt_points, axis=0)
np.save(db_data_save_path, stacked_gt_points) np.save(db_data_save_path, stacked_gt_points)
def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False, used_classes=None, use_cuda=False):
    """Extract per-object ground-truth point clips for one frame and save them as .bin files.

    Args:
        info_with_idx: tuple (info_dict, scene_index); the index drives class subsampling.
        database_save_path: directory (pathlib.Path) that receives the per-object .bin files.
        use_sequence_data: when True, aggregate points over the configured frame sequence.
        used_classes: optional whitelist of class names to store in the database.
        use_cuda: run the points-in-boxes test on GPU instead of CPU.

    Returns:
        dict mapping class name -> list of db_info dicts; {} when the frame has no boxes left.
    """
    info, scene_idx = info_with_idx
    db_infos = {}

    pc_info = info['point_cloud']
    sequence_name = pc_info['lidar_sequence']
    sample_idx = pc_info['sample_idx']

    points = self.get_lidar(sequence_name, sample_idx)
    if use_sequence_data:
        points, num_points_all, sample_idx_pre_list = self.get_sequence_data(
            info, points, sequence_name, sample_idx, self.dataset_cfg.SEQUENCE_CONFIG
        )

    annos = info['annos']
    names, difficulty, gt_boxes = annos['name'], annos['difficulty'], annos['gt_boxes_lidar']

    # Subsample the dominant classes across scenes: keep Vehicle boxes only in every
    # 4th scene and Pedestrian boxes only in every 2nd scene (Vehicle filtered first).
    for modulus, class_name in ((4, 'Vehicle'), (2, 'Pedestrian')):
        if scene_idx % modulus != 0 and len(names) > 0:
            keep = ~(names == class_name)
            names, difficulty, gt_boxes = names[keep], difficulty[keep], gt_boxes[keep]

    num_obj = gt_boxes.shape[0]
    if num_obj == 0:
        return {}

    if use_cuda:
        # (num_points,) index of the box each point falls into
        box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
            torch.from_numpy(points[:, 0:3]).unsqueeze(dim=0).float().cuda(),
            torch.from_numpy(gt_boxes[:, 0:7]).unsqueeze(dim=0).float().cuda()
        ).long().squeeze(dim=0).cpu().numpy()
    else:
        box_point_mask = roiaware_pool3d_utils.points_in_boxes_cpu(
            torch.from_numpy(points[:, 0:3]).float(),
            torch.from_numpy(gt_boxes[:, 0:7]).float()
        ).long().numpy()  # (num_boxes, num_points)

    for i in range(num_obj):
        filename = '%s_%04d_%s_%d.bin' % (sequence_name, sample_idx, names[i], i)
        filepath = database_save_path / filename
        gt_points = points[box_idxs_of_pts == i] if use_cuda else points[box_point_mask[i] > 0]
        # shift object points into a box-centered frame
        gt_points[:, :3] -= gt_boxes[i, :3]

        if used_classes is not None and names[i] not in used_classes:
            continue
        with open(filepath, 'w') as f:
            gt_points.tofile(f)
        db_path = str(filepath.relative_to(self.root_path))  # gt_database/xxxxx.bin
        db_info = {'name': names[i], 'path': db_path, 'sequence_name': sequence_name,
                   'sample_idx': sample_idx, 'gt_idx': i, 'box3d_lidar': gt_boxes[i],
                   'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i]}
        db_infos.setdefault(names[i], []).append(db_info)

    return db_infos
def create_groundtruth_database_parallel(self, info_path, save_path, used_classes=None, split='train', sampled_interval=10, processed_data_tag=None, num_workers=16):
    """Build the ground-truth database by fanning scenes out to a multiprocessing pool.

    Args:
        info_path: pickle file holding the per-frame info dicts.
        save_path: root directory for the database folder and the dbinfos pickle.
        used_classes: optional whitelist of class names to keep.
        split: dataset split tag used in the output file names.
        sampled_interval: sampling interval recorded in the output file names.
        processed_data_tag: prefix tag for the output file/folder names.
        num_workers: number of worker processes in the pool.
    """
    use_sequence_data = self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None \
        and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED
    if use_sequence_data:
        st_frame, ed_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0], self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[1]
        # at least we use 5 frames for generating gt database to support various sequence configs (<= 5 frames)
        st_frame = min(-4, st_frame)
        database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_parallel' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
        db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_multiframe_%s_to_%s_parallel.pkl' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
        db_data_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_global_parallel.npy' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
    else:
        database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_parallel' % (processed_data_tag, split, sampled_interval))
        db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_parallel.pkl' % (processed_data_tag, split, sampled_interval))
        db_data_save_path = save_path / ('%s_gt_database_%s_sampled_%d_global_parallel.npy' % (processed_data_tag, split, sampled_interval))
    # NOTE(review): db_data_save_path is computed but never written in the parallel
    # path (no global stacked-points .npy is produced here) — confirm if intentional.

    database_save_path.mkdir(parents=True, exist_ok=True)

    with open(info_path, 'rb') as f:
        infos = pickle.load(f)

    # NOTE(review): use_cuda=True inside forked pool workers assumes CUDA can be
    # initialized per worker process — confirm on the target platform.
    single_scene_worker = partial(
        self.create_gt_database_of_single_scene,
        use_sequence_data=use_sequence_data, database_save_path=database_save_path,
        used_classes=used_classes, use_cuda=True
    )
    # single_scene_worker((infos[0], 0))  # handy single-process debug call
    with multiprocessing.Pool(num_workers) as pool:
        per_scene_infos = list(tqdm(
            pool.imap(single_scene_worker, zip(infos, np.arange(len(infos)))),
            total=len(infos)
        ))

    # Merge the per-scene dicts into one class-name -> db_info-list mapping.
    all_db_infos = {}
    for scene_infos in per_scene_infos:
        for class_name, items in scene_infos.items():
            all_db_infos.setdefault(class_name, []).extend(items)

    for class_name, items in all_db_infos.items():
        print('Database %s: %d' % (class_name, len(items)))

    with open(db_info_save_path, 'wb') as f:
        pickle.dump(all_db_infos, f)
def create_waymo_infos(dataset_cfg, class_names, data_path, save_path, def create_waymo_infos(dataset_cfg, class_names, data_path, save_path,
raw_data_tag='raw_data', processed_data_tag='waymo_processed_data', raw_data_tag='raw_data', processed_data_tag='waymo_processed_data',
...@@ -538,7 +651,7 @@ def create_waymo_infos(dataset_cfg, class_names, data_path, save_path, ...@@ -538,7 +651,7 @@ def create_waymo_infos(dataset_cfg, class_names, data_path, save_path,
def create_waymo_gt_database( def create_waymo_gt_database(
dataset_cfg, class_names, data_path, save_path, processed_data_tag='waymo_processed_data', dataset_cfg, class_names, data_path, save_path, processed_data_tag='waymo_processed_data',
workers=min(16, multiprocessing.cpu_count())): workers=min(16, multiprocessing.cpu_count()), use_parallel=False):
dataset = WaymoDataset( dataset = WaymoDataset(
dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path, dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path,
training=False, logger=common_utils.create_logger() training=False, logger=common_utils.create_logger()
...@@ -549,10 +662,17 @@ def create_waymo_gt_database( ...@@ -549,10 +662,17 @@ def create_waymo_gt_database(
print('---------------Start create groundtruth database for data augmentation---------------') print('---------------Start create groundtruth database for data augmentation---------------')
dataset.set_split(train_split) dataset.set_split(train_split)
dataset.create_groundtruth_database( if use_parallel:
info_path=train_filename, save_path=save_path, split='train', sampled_interval=1, dataset.create_groundtruth_database_parallel(
used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
) used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag,
num_workers=workers
)
else:
dataset.create_groundtruth_database(
info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag
)
print('---------------Data preparation Done---------------') print('---------------Data preparation Done---------------')
...@@ -566,6 +686,7 @@ if __name__ == '__main__': ...@@ -566,6 +686,7 @@ if __name__ == '__main__':
parser.add_argument('--func', type=str, default='create_waymo_infos', help='') parser.add_argument('--func', type=str, default='create_waymo_infos', help='')
parser.add_argument('--processed_data_tag', type=str, default='waymo_processed_data_v0_5_0', help='') parser.add_argument('--processed_data_tag', type=str, default='waymo_processed_data_v0_5_0', help='')
parser.add_argument('--update_info_only', action='store_true', default=False, help='') parser.add_argument('--update_info_only', action='store_true', default=False, help='')
parser.add_argument('--use_parallel', action='store_true', default=False, help='')
args = parser.parse_args() args = parser.parse_args()
...@@ -599,7 +720,8 @@ if __name__ == '__main__': ...@@ -599,7 +720,8 @@ if __name__ == '__main__':
class_names=['Vehicle', 'Pedestrian', 'Cyclist'], class_names=['Vehicle', 'Pedestrian', 'Cyclist'],
data_path=ROOT_DIR / 'data' / 'waymo', data_path=ROOT_DIR / 'data' / 'waymo',
save_path=ROOT_DIR / 'data' / 'waymo', save_path=ROOT_DIR / 'data' / 'waymo',
processed_data_tag=args.processed_data_tag processed_data_tag=args.processed_data_tag,
use_parallel=args.use_parallel
) )
else: else:
raise NotImplementedError raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment