Commit 4996eb46 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

bugfixed: waymo_eval bug when training with speed

parent 1541d269
...@@ -290,14 +290,13 @@ class WaymoDataset(DatasetTemplate): ...@@ -290,14 +290,13 @@ class WaymoDataset(DatasetTemplate):
data_dict.pop('num_points_in_gt', None) data_dict.pop('num_points_in_gt', None)
return data_dict return data_dict
@staticmethod def generate_prediction_dicts(self, batch_dict, pred_dicts, class_names, output_path=None):
def generate_prediction_dicts(batch_dict, pred_dicts, class_names, output_path=None):
""" """
Args: Args:
batch_dict: batch_dict:
frame_id: frame_id:
pred_dicts: list of pred_dicts pred_dicts: list of pred_dicts
pred_boxes: (N, 7), Tensor pred_boxes: (N, 7 or 9), Tensor
pred_scores: (N), Tensor pred_scores: (N), Tensor
pred_labels: (N), Tensor pred_labels: (N), Tensor
class_names: class_names:
...@@ -308,9 +307,10 @@ class WaymoDataset(DatasetTemplate): ...@@ -308,9 +307,10 @@ class WaymoDataset(DatasetTemplate):
""" """
def get_template_prediction(num_samples): def get_template_prediction(num_samples):
box_dim = 9 if self.dataset_cfg.get('TRAIN_WITH_SPEED', False) else 7
ret_dict = { ret_dict = {
'name': np.zeros(num_samples), 'score': np.zeros(num_samples), 'name': np.zeros(num_samples), 'score': np.zeros(num_samples),
'boxes_lidar': np.zeros([num_samples, 7]) 'boxes_lidar': np.zeros([num_samples, box_dim])
} }
return ret_dict return ret_dict
......
...@@ -60,6 +60,9 @@ class OpenPCDetWaymoDetectionMetricsEstimator(tf.test.TestCase): ...@@ -60,6 +60,9 @@ class OpenPCDetWaymoDetectionMetricsEstimator(tf.test.TestCase):
if fake_gt_infos: if fake_gt_infos:
info['gt_boxes_lidar'] = boxes3d_kitti_fakelidar_to_lidar(info['gt_boxes_lidar']) info['gt_boxes_lidar'] = boxes3d_kitti_fakelidar_to_lidar(info['gt_boxes_lidar'])
if info['gt_boxes_lidar'].shape[-1] == 9:
boxes3d.append(info['gt_boxes_lidar'][box_mask][:, 0:7])
else:
boxes3d.append(info['gt_boxes_lidar'][box_mask]) boxes3d.append(info['gt_boxes_lidar'][box_mask])
else: else:
num_boxes = len(info['boxes_lidar']) num_boxes = len(info['boxes_lidar'])
...@@ -67,6 +70,8 @@ class OpenPCDetWaymoDetectionMetricsEstimator(tf.test.TestCase): ...@@ -67,6 +70,8 @@ class OpenPCDetWaymoDetectionMetricsEstimator(tf.test.TestCase):
score.append(info['score']) score.append(info['score'])
boxes3d.append(np.array(info['boxes_lidar'])) boxes3d.append(np.array(info['boxes_lidar']))
box_name = info['name'] box_name = info['name']
if boxes3d[-1].shape[-1] == 9:
boxes3d[-1] = boxes3d[-1][", 0:7"]
obj_type += [self.WAYMO_CLASSES.index(name) for i, name in enumerate(box_name)] obj_type += [self.WAYMO_CLASSES.index(name) for i, name in enumerate(box_name)]
frame_id.append(np.array([frame_index] * num_boxes)) frame_id.append(np.array([frame_index] * num_boxes))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment