Commit 8eeb8136 authored by ChaimZhu's avatar ChaimZhu Committed by ZwwWayne
Browse files

[Fix] fix kitti evaluation bugs on test dataset (#2005)

* fix kitti test evaluation bugs

* update
parent 50ac8884
...@@ -113,7 +113,7 @@ class KittiDataset(Det3DDataset): ...@@ -113,7 +113,7 @@ class KittiDataset(Det3DDataset):
info['plane'] = plane_lidar info['plane'] = plane_lidar
if self.task == 'mono_det': if self.task == 'mono_det' and self.load_eval_anns:
info['instances'] = info['cam_instances'][self.default_cam_key] info['instances'] = info['cam_instances'][self.default_cam_key]
info = super().parse_data_info(info) info = super().parse_data_info(info)
......
...@@ -36,6 +36,10 @@ class KittiMetric(BaseMetric): ...@@ -36,6 +36,10 @@ class KittiMetric(BaseMetric):
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
default_cam_key (str, optional): The default camera for lidar to default_cam_key (str, optional): The default camera for lidar to
camear conversion. By default, KITTI: CAM2, Waymo: CAM_FRONT camear conversion. By default, KITTI: CAM2, Waymo: CAM_FRONT
format_only (bool): Format the output results without perform
evaluation. It is useful when you want to format the result
to a specific format and submit it to the test server.
Defaults to False.
submission_prefix (str, optional): The prefix of submission data. submission_prefix (str, optional): The prefix of submission data.
If not specified, the submission data will not be generated. If not specified, the submission data will not be generated.
Default: None. Default: None.
...@@ -52,6 +56,7 @@ class KittiMetric(BaseMetric): ...@@ -52,6 +56,7 @@ class KittiMetric(BaseMetric):
prefix: Optional[str] = None, prefix: Optional[str] = None,
pklfile_prefix: str = None, pklfile_prefix: str = None,
default_cam_key: str = 'CAM2', default_cam_key: str = 'CAM2',
format_only: bool = False,
submission_prefix: str = None, submission_prefix: str = None,
collect_device: str = 'cpu', collect_device: str = 'cpu',
file_client_args: dict = dict(backend='disk')): file_client_args: dict = dict(backend='disk')):
...@@ -61,6 +66,13 @@ class KittiMetric(BaseMetric): ...@@ -61,6 +66,13 @@ class KittiMetric(BaseMetric):
self.pcd_limit_range = pcd_limit_range self.pcd_limit_range = pcd_limit_range
self.ann_file = ann_file self.ann_file = ann_file
self.pklfile_prefix = pklfile_prefix self.pklfile_prefix = pklfile_prefix
self.format_only = format_only
if self.format_only:
assert submission_prefix is not None, 'submission_prefix must be'
'not None when format_only is True, otherwise the result files'
'will be saved to a temp directory which will be cleaned up at'
'the end.'
self.submission_prefix = submission_prefix self.submission_prefix = submission_prefix
self.pred_box_type_3d = pred_box_type_3d self.pred_box_type_3d = pred_box_type_3d
self.default_cam_key = default_cam_key self.default_cam_key = default_cam_key
...@@ -84,49 +96,52 @@ class KittiMetric(BaseMetric): ...@@ -84,49 +96,52 @@ class KittiMetric(BaseMetric):
Returns: Returns:
List[dict]: List of Kitti annotations. List[dict]: List of Kitti annotations.
""" """
cat2label = data_infos['metainfo']['categories']
data_annos = data_infos['data_list'] data_annos = data_infos['data_list']
label2cat = dict((v, k) for (k, v) in cat2label.items()) if not self.format_only:
assert 'instances' in data_annos[0] cat2label = data_infos['metainfo']['categories']
for i, annos in enumerate(data_annos): label2cat = dict((v, k) for (k, v) in cat2label.items())
if len(annos['instances']) == 0: assert 'instances' in data_annos[0]
kitti_annos = { for i, annos in enumerate(data_annos):
'name': np.array([]), if len(annos['instances']) == 0:
'truncated': np.array([]), kitti_annos = {
'occluded': np.array([]), 'name': np.array([]),
'alpha': np.array([]), 'truncated': np.array([]),
'bbox': np.zeros([0, 4]), 'occluded': np.array([]),
'dimensions': np.zeros([0, 3]), 'alpha': np.array([]),
'location': np.zeros([0, 3]), 'bbox': np.zeros([0, 4]),
'rotation_y': np.array([]), 'dimensions': np.zeros([0, 3]),
'score': np.array([]), 'location': np.zeros([0, 3]),
} 'rotation_y': np.array([]),
else: 'score': np.array([]),
kitti_annos = { }
'name': [], else:
'truncated': [], kitti_annos = {
'occluded': [], 'name': [],
'alpha': [], 'truncated': [],
'bbox': [], 'occluded': [],
'location': [], 'alpha': [],
'dimensions': [], 'bbox': [],
'rotation_y': [], 'location': [],
'score': [] 'dimensions': [],
} 'rotation_y': [],
for instance in annos['instances']: 'score': []
label = instance['bbox_label'] }
kitti_annos['name'].append(label2cat[label]) for instance in annos['instances']:
kitti_annos['truncated'].append(instance['truncated']) label = instance['bbox_label']
kitti_annos['occluded'].append(instance['occluded']) kitti_annos['name'].append(label2cat[label])
kitti_annos['alpha'].append(instance['alpha']) kitti_annos['truncated'].append(instance['truncated'])
kitti_annos['bbox'].append(instance['bbox']) kitti_annos['occluded'].append(instance['occluded'])
kitti_annos['location'].append(instance['bbox_3d'][:3]) kitti_annos['alpha'].append(instance['alpha'])
kitti_annos['dimensions'].append(instance['bbox_3d'][3:6]) kitti_annos['bbox'].append(instance['bbox'])
kitti_annos['rotation_y'].append(instance['bbox_3d'][6]) kitti_annos['location'].append(instance['bbox_3d'][:3])
kitti_annos['score'].append(instance['score']) kitti_annos['dimensions'].append(
for name in kitti_annos: instance['bbox_3d'][3:6])
kitti_annos[name] = np.array(kitti_annos[name]) kitti_annos['rotation_y'].append(
data_annos[i]['kitti_annos'] = kitti_annos instance['bbox_3d'][6])
kitti_annos['score'].append(instance['score'])
for name in kitti_annos:
kitti_annos[name] = np.array(kitti_annos[name])
data_annos[i]['kitti_annos'] = kitti_annos
return data_annos return data_annos
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
...@@ -178,12 +193,18 @@ class KittiMetric(BaseMetric): ...@@ -178,12 +193,18 @@ class KittiMetric(BaseMetric):
submission_prefix=self.submission_prefix, submission_prefix=self.submission_prefix,
classes=self.classes) classes=self.classes)
metric_dict = {}
if self.format_only:
logger.info('results are saved in '
f'{osp.dirname(self.submission_prefix)}')
return metric_dict
gt_annos = [ gt_annos = [
self.data_infos[result['sample_idx']]['kitti_annos'] self.data_infos[result['sample_idx']]['kitti_annos']
for result in results for result in results
] ]
metric_dict = {}
for metric in self.metrics: for metric in self.metrics:
ap_dict = self.kitti_evaluate( ap_dict = self.kitti_evaluate(
result_dict, result_dict,
...@@ -321,7 +342,7 @@ class KittiMetric(BaseMetric): ...@@ -321,7 +342,7 @@ class KittiMetric(BaseMetric):
mmengine.mkdir_or_exist(submission_prefix) mmengine.mkdir_or_exist(submission_prefix)
det_annos = [] det_annos = []
print('\nConverting prediction to KITTI format') print('\nConverting 3D prediction to KITTI format')
for idx, pred_dicts in enumerate( for idx, pred_dicts in enumerate(
mmengine.track_iter_progress(net_outputs)): mmengine.track_iter_progress(net_outputs)):
annos = [] annos = []
...@@ -447,7 +468,7 @@ class KittiMetric(BaseMetric): ...@@ -447,7 +468,7 @@ class KittiMetric(BaseMetric):
assert len(net_outputs) == len(self.data_infos), \ assert len(net_outputs) == len(self.data_infos), \
'invalid list length of network outputs' 'invalid list length of network outputs'
det_annos = [] det_annos = []
print('\nConverting prediction to KITTI format') print('\nConverting 2D prediction to KITTI format')
for i, bboxes_per_sample in enumerate( for i, bboxes_per_sample in enumerate(
mmengine.track_iter_progress(net_outputs)): mmengine.track_iter_progress(net_outputs)):
annos = [] annos = []
...@@ -516,7 +537,7 @@ class KittiMetric(BaseMetric): ...@@ -516,7 +537,7 @@ class KittiMetric(BaseMetric):
mmengine.mkdir_or_exist(submission_prefix) mmengine.mkdir_or_exist(submission_prefix)
print(f'Saving KITTI submission to {submission_prefix}') print(f'Saving KITTI submission to {submission_prefix}')
for i, anno in enumerate(det_annos): for i, anno in enumerate(det_annos):
sample_idx = self.data_infos[i]['image']['image_idx'] sample_idx = sample_id_list[i]
cur_det_file = f'{submission_prefix}/{sample_idx:06d}.txt' cur_det_file = f'{submission_prefix}/{sample_idx:06d}.txt'
with open(cur_det_file, 'w') as f: with open(cur_det_file, 'w') as f:
bbox = anno['bbox'] bbox = anno['bbox']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment