Commit b32fbddb authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

bugfixed: inference and generate results on KITTI test set

parent 12b8c131
...@@ -101,7 +101,7 @@ class DatasetTemplate(torch_data.Dataset): ...@@ -101,7 +101,7 @@ class DatasetTemplate(torch_data.Dataset):
data_dict = self.data_processor.forward( data_dict = self.data_processor.forward(
data_dict=data_dict data_dict=data_dict
) )
data_dict.pop('gt_names') data_dict.pop('gt_names', None)
return data_dict return data_dict
......
...@@ -318,11 +318,10 @@ class KittiDataset(DatasetTemplate): ...@@ -318,11 +318,10 @@ class KittiDataset(DatasetTemplate):
return annos return annos
def evaluation(self, det_annos, class_names, **kwargs): def evaluation(self, det_annos, class_names, **kwargs):
assert 'annos' in self.kitti_infos[0].keys() if 'annos' not in self.kitti_infos[0].keys():
from .kitti_object_eval_python import eval as kitti_eval return None, {}
if 'annos' not in self.kitti_infos[0]: from .kitti_object_eval_python import eval as kitti_eval
return 'None', {}
eval_det_annos = copy.deepcopy(det_annos) eval_det_annos = copy.deepcopy(det_annos)
eval_gt_annos = [copy.deepcopy(info['annos']) for info in self.kitti_infos] eval_gt_annos = [copy.deepcopy(info['annos']) for info in self.kitti_infos]
......
...@@ -8,9 +8,9 @@ from pcdet.utils import common_utils ...@@ -8,9 +8,9 @@ from pcdet.utils import common_utils
def statistics_info(cfg, ret_dict, metric, disp_dict): def statistics_info(cfg, ret_dict, metric, disp_dict):
for cur_thresh in cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST: for cur_thresh in cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST:
metric['recall_roi_%s' % str(cur_thresh)] += ret_dict['roi_%s' % str(cur_thresh)] metric['recall_roi_%s' % str(cur_thresh)] += ret_dict.get('roi_%s' % str(cur_thresh), 0)
metric['recall_rcnn_%s' % str(cur_thresh)] += ret_dict['rcnn_%s' % str(cur_thresh)] metric['recall_rcnn_%s' % str(cur_thresh)] += ret_dict.get('rcnn_%s' % str(cur_thresh), 0)
metric['gt_num'] += ret_dict['gt'] metric['gt_num'] += ret_dict.get('gt', 0)
min_thresh = cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST[0] min_thresh = cfg.MODEL.POST_PROCESSING.RECALL_THRESH_LIST[0]
disp_dict['recall_%s' % str(min_thresh)] = \ disp_dict['recall_%s' % str(min_thresh)] = \
'(%d, %d) / %d' % (metric['recall_roi_%s' % str(min_thresh)], metric['recall_rcnn_%s' % str(min_thresh)], metric['gt_num']) '(%d, %d) / %d' % (metric['recall_roi_%s' % str(min_thresh)], metric['recall_rcnn_%s' % str(min_thresh)], metric['gt_num'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment