Commit 96ba76a3 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

database should save with np.float32

parent 7e2d56b4
...@@ -225,7 +225,7 @@ class WaymoDataset(DatasetTemplate): ...@@ -225,7 +225,7 @@ class WaymoDataset(DatasetTemplate):
points_pre = remove_ego_points(points_pre, 1.0) points_pre = remove_ego_points(points_pre, 1.0)
points_pre_all.append(points_pre) points_pre_all.append(points_pre)
num_points_pre.append(points_pre.shape[0]) num_points_pre.append(points_pre.shape[0])
points = np.concatenate([points] + points_pre_all, axis=0) points = np.concatenate([points] + points_pre_all, axis=0).astype(np.float32)
num_points_all = np.array([num_pts_cur] + num_points_pre).astype(np.int) num_points_all = np.array([num_pts_cur] + num_points_pre).astype(np.int)
return points, num_points_all, sample_idx_pre_list return points, num_points_all, sample_idx_pre_list
...@@ -461,6 +461,8 @@ class WaymoDataset(DatasetTemplate): ...@@ -461,6 +461,8 @@ class WaymoDataset(DatasetTemplate):
gt_points[:, :3] -= gt_boxes[i, :3] gt_points[:, :3] -= gt_boxes[i, :3]
if (used_classes is None) or names[i] in used_classes: if (used_classes is None) or names[i] in used_classes:
gt_points = gt_points.astype(np.float32)
assert gt_points.dtype == np.float32
with open(filepath, 'w') as f: with open(filepath, 'w') as f:
gt_points.tofile(f) gt_points.tofile(f)
...@@ -491,7 +493,7 @@ class WaymoDataset(DatasetTemplate): ...@@ -491,7 +493,7 @@ class WaymoDataset(DatasetTemplate):
def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False, used_classes=None, total_samples=0, use_cuda=False): def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False, used_classes=None, total_samples=0, use_cuda=False):
info, info_idx = info_with_idx info, info_idx = info_with_idx
print('gt_database sample: %d/%d' % (info_idx, total_samples)) print('gt_database sample: %d/%d' % (info_idx, total_samples))
all_db_infos = {} all_db_infos = {}
pc_info = info['point_cloud'] pc_info = info['point_cloud']
sequence_name = pc_info['lidar_sequence'] sequence_name = pc_info['lidar_sequence']
...@@ -542,10 +544,12 @@ class WaymoDataset(DatasetTemplate): ...@@ -542,10 +544,12 @@ class WaymoDataset(DatasetTemplate):
gt_points = points[box_idxs_of_pts == i] gt_points = points[box_idxs_of_pts == i]
else: else:
gt_points = points[box_point_mask[i] > 0] gt_points = points[box_point_mask[i] > 0]
gt_points[:, :3] -= gt_boxes[i, :3] gt_points[:, :3] -= gt_boxes[i, :3]
if (used_classes is None) or names[i] in used_classes: if (used_classes is None) or names[i] in used_classes:
gt_points = gt_points.astype(np.float32)
assert gt_points.dtype == np.float32
with open(filepath, 'w') as f: with open(filepath, 'w') as f:
gt_points.tofile(f) gt_points.tofile(f)
...@@ -581,8 +585,8 @@ class WaymoDataset(DatasetTemplate): ...@@ -581,8 +585,8 @@ class WaymoDataset(DatasetTemplate):
print(f'Number workers: {num_workers}') print(f'Number workers: {num_workers}')
create_gt_database_of_single_scene = partial( create_gt_database_of_single_scene = partial(
self.create_gt_database_of_single_scene, self.create_gt_database_of_single_scene,
use_sequence_data=use_sequence_data, database_save_path=database_save_path, use_sequence_data=use_sequence_data, database_save_path=database_save_path,
used_classes=used_classes, total_samples=len(infos), use_cuda=False used_classes=used_classes, total_samples=len(infos), use_cuda=False
) )
# create_gt_database_of_single_scene((infos[0], 0)) # create_gt_database_of_single_scene((infos[0], 0))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment