Commit 8eeb8136 authored by ChaimZhu's avatar ChaimZhu Committed by ZwwWayne
Browse files

[Fix] fix kitti evaluation bugs on test dataset (#2005)

* fix kitti test evaluation bugs

* update
parent 50ac8884
......@@ -113,7 +113,7 @@ class KittiDataset(Det3DDataset):
info['plane'] = plane_lidar
if self.task == 'mono_det':
if self.task == 'mono_det' and self.load_eval_anns:
info['instances'] = info['cam_instances'][self.default_cam_key]
info = super().parse_data_info(info)
......
......@@ -36,6 +36,10 @@ class KittiMetric(BaseMetric):
If not specified, a temp file will be created. Default: None.
default_cam_key (str, optional): The default camera for lidar to
camear conversion. By default, KITTI: CAM2, Waymo: CAM_FRONT
format_only (bool): Format the output results without perform
evaluation. It is useful when you want to format the result
to a specific format and submit it to the test server.
Defaults to False.
submission_prefix (str, optional): The prefix of submission data.
If not specified, the submission data will not be generated.
Default: None.
......@@ -52,6 +56,7 @@ class KittiMetric(BaseMetric):
prefix: Optional[str] = None,
pklfile_prefix: str = None,
default_cam_key: str = 'CAM2',
format_only: bool = False,
submission_prefix: str = None,
collect_device: str = 'cpu',
file_client_args: dict = dict(backend='disk')):
......@@ -61,6 +66,13 @@ class KittiMetric(BaseMetric):
self.pcd_limit_range = pcd_limit_range
self.ann_file = ann_file
self.pklfile_prefix = pklfile_prefix
self.format_only = format_only
if self.format_only:
assert submission_prefix is not None, 'submission_prefix must be'
'not None when format_only is True, otherwise the result files'
'will be saved to a temp directory which will be cleaned up at'
'the end.'
self.submission_prefix = submission_prefix
self.pred_box_type_3d = pred_box_type_3d
self.default_cam_key = default_cam_key
......@@ -84,49 +96,52 @@ class KittiMetric(BaseMetric):
Returns:
List[dict]: List of Kitti annotations.
"""
cat2label = data_infos['metainfo']['categories']
data_annos = data_infos['data_list']
label2cat = dict((v, k) for (k, v) in cat2label.items())
assert 'instances' in data_annos[0]
for i, annos in enumerate(data_annos):
if len(annos['instances']) == 0:
kitti_annos = {
'name': np.array([]),
'truncated': np.array([]),
'occluded': np.array([]),
'alpha': np.array([]),
'bbox': np.zeros([0, 4]),
'dimensions': np.zeros([0, 3]),
'location': np.zeros([0, 3]),
'rotation_y': np.array([]),
'score': np.array([]),
}
else:
kitti_annos = {
'name': [],
'truncated': [],
'occluded': [],
'alpha': [],
'bbox': [],
'location': [],
'dimensions': [],
'rotation_y': [],
'score': []
}
for instance in annos['instances']:
label = instance['bbox_label']
kitti_annos['name'].append(label2cat[label])
kitti_annos['truncated'].append(instance['truncated'])
kitti_annos['occluded'].append(instance['occluded'])
kitti_annos['alpha'].append(instance['alpha'])
kitti_annos['bbox'].append(instance['bbox'])
kitti_annos['location'].append(instance['bbox_3d'][:3])
kitti_annos['dimensions'].append(instance['bbox_3d'][3:6])
kitti_annos['rotation_y'].append(instance['bbox_3d'][6])
kitti_annos['score'].append(instance['score'])
for name in kitti_annos:
kitti_annos[name] = np.array(kitti_annos[name])
data_annos[i]['kitti_annos'] = kitti_annos
if not self.format_only:
cat2label = data_infos['metainfo']['categories']
label2cat = dict((v, k) for (k, v) in cat2label.items())
assert 'instances' in data_annos[0]
for i, annos in enumerate(data_annos):
if len(annos['instances']) == 0:
kitti_annos = {
'name': np.array([]),
'truncated': np.array([]),
'occluded': np.array([]),
'alpha': np.array([]),
'bbox': np.zeros([0, 4]),
'dimensions': np.zeros([0, 3]),
'location': np.zeros([0, 3]),
'rotation_y': np.array([]),
'score': np.array([]),
}
else:
kitti_annos = {
'name': [],
'truncated': [],
'occluded': [],
'alpha': [],
'bbox': [],
'location': [],
'dimensions': [],
'rotation_y': [],
'score': []
}
for instance in annos['instances']:
label = instance['bbox_label']
kitti_annos['name'].append(label2cat[label])
kitti_annos['truncated'].append(instance['truncated'])
kitti_annos['occluded'].append(instance['occluded'])
kitti_annos['alpha'].append(instance['alpha'])
kitti_annos['bbox'].append(instance['bbox'])
kitti_annos['location'].append(instance['bbox_3d'][:3])
kitti_annos['dimensions'].append(
instance['bbox_3d'][3:6])
kitti_annos['rotation_y'].append(
instance['bbox_3d'][6])
kitti_annos['score'].append(instance['score'])
for name in kitti_annos:
kitti_annos[name] = np.array(kitti_annos[name])
data_annos[i]['kitti_annos'] = kitti_annos
return data_annos
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
......@@ -178,12 +193,18 @@ class KittiMetric(BaseMetric):
submission_prefix=self.submission_prefix,
classes=self.classes)
metric_dict = {}
if self.format_only:
logger.info('results are saved in '
f'{osp.dirname(self.submission_prefix)}')
return metric_dict
gt_annos = [
self.data_infos[result['sample_idx']]['kitti_annos']
for result in results
]
metric_dict = {}
for metric in self.metrics:
ap_dict = self.kitti_evaluate(
result_dict,
......@@ -321,7 +342,7 @@ class KittiMetric(BaseMetric):
mmengine.mkdir_or_exist(submission_prefix)
det_annos = []
print('\nConverting prediction to KITTI format')
print('\nConverting 3D prediction to KITTI format')
for idx, pred_dicts in enumerate(
mmengine.track_iter_progress(net_outputs)):
annos = []
......@@ -447,7 +468,7 @@ class KittiMetric(BaseMetric):
assert len(net_outputs) == len(self.data_infos), \
'invalid list length of network outputs'
det_annos = []
print('\nConverting prediction to KITTI format')
print('\nConverting 2D prediction to KITTI format')
for i, bboxes_per_sample in enumerate(
mmengine.track_iter_progress(net_outputs)):
annos = []
......@@ -516,7 +537,7 @@ class KittiMetric(BaseMetric):
mmengine.mkdir_or_exist(submission_prefix)
print(f'Saving KITTI submission to {submission_prefix}')
for i, anno in enumerate(det_annos):
sample_idx = self.data_infos[i]['image']['image_idx']
sample_idx = sample_id_list[i]
cur_det_file = f'{submission_prefix}/{sample_idx:06d}.txt'
with open(cur_det_file, 'w') as f:
bbox = anno['bbox']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment