Commit 4996eb46 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

bugfixed: waymo_eval bug when training with speed

parent 1541d269
......@@ -272,7 +272,7 @@ class WaymoDataset(DatasetTemplate):
assert gt_boxes_lidar.shape[-1] == 9
else:
gt_boxes_lidar = gt_boxes_lidar[:, 0:7]
if self.training and self.dataset_cfg.get('FILTER_EMPTY_BOXES_FOR_TRAIN', False):
mask = (annos['num_points_in_gt'] > 0) # filter empty boxes
annos['name'] = annos['name'][mask]
......@@ -290,14 +290,13 @@ class WaymoDataset(DatasetTemplate):
data_dict.pop('num_points_in_gt', None)
return data_dict
@staticmethod
def generate_prediction_dicts(batch_dict, pred_dicts, class_names, output_path=None):
def generate_prediction_dicts(self, batch_dict, pred_dicts, class_names, output_path=None):
"""
Args:
batch_dict:
frame_id:
pred_dicts: list of pred_dicts
pred_boxes: (N, 7), Tensor
pred_boxes: (N, 7 or 9), Tensor
pred_scores: (N), Tensor
pred_labels: (N), Tensor
class_names:
......@@ -308,9 +307,10 @@ class WaymoDataset(DatasetTemplate):
"""
def get_template_prediction(num_samples):
box_dim = 9 if self.dataset_cfg.get('TRAIN_WITH_SPEED', False) else 7
ret_dict = {
'name': np.zeros(num_samples), 'score': np.zeros(num_samples),
'boxes_lidar': np.zeros([num_samples, 7])
'boxes_lidar': np.zeros([num_samples, box_dim])
}
return ret_dict
......
......@@ -60,13 +60,18 @@ class OpenPCDetWaymoDetectionMetricsEstimator(tf.test.TestCase):
if fake_gt_infos:
info['gt_boxes_lidar'] = boxes3d_kitti_fakelidar_to_lidar(info['gt_boxes_lidar'])
boxes3d.append(info['gt_boxes_lidar'][box_mask])
if info['gt_boxes_lidar'].shape[-1] == 9:
boxes3d.append(info['gt_boxes_lidar'][box_mask][:, 0:7])
else:
boxes3d.append(info['gt_boxes_lidar'][box_mask])
else:
num_boxes = len(info['boxes_lidar'])
difficulty.append([0] * num_boxes)
score.append(info['score'])
boxes3d.append(np.array(info['boxes_lidar']))
box_name = info['name']
if boxes3d[-1].shape[-1] == 9:
boxes3d[-1] = boxes3d[-1][", 0:7"]
obj_type += [self.WAYMO_CLASSES.index(name) for i, name in enumerate(box_name)]
frame_id.append(np.array([frame_index] * num_boxes))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment