Commit 49f06039 authored by zhangwenwei's avatar zhangwenwei
Browse files

Fix bug caused by evalhook API change

parent 148fea12
...@@ -191,7 +191,7 @@ log_config = dict( ...@@ -191,7 +191,7 @@ log_config = dict(
# yapf:enable # yapf:enable
# runtime settings # runtime settings
total_epochs = 80 total_epochs = 80
dist_params = dict(backend='nccl') dist_params = dict(backend='nccl', port=29511)
log_level = 'INFO' log_level = 'INFO'
work_dir = './work_dirs/sec_secfpn_80e' work_dir = './work_dirs/sec_secfpn_80e'
load_from = None load_from = None
......
import copy import copy
import os import os
import pickle import os.path as osp
import tempfile
import mmcv import mmcv
import numpy as np import numpy as np
...@@ -44,7 +45,7 @@ class KittiDataset(torch_data.Dataset): ...@@ -44,7 +45,7 @@ class KittiDataset(torch_data.Dataset):
self.ann_file = ann_file self.ann_file = ann_file
with open(ann_file, 'rb') as f: with open(ann_file, 'rb') as f:
self.kitti_infos = pickle.load(f) self.kitti_infos = mmcv.load(f)
# set group flag for the sampler # set group flag for the sampler
if not self.test_mode: if not self.test_mode:
...@@ -262,34 +263,73 @@ class KittiDataset(torch_data.Dataset): ...@@ -262,34 +263,73 @@ class KittiDataset(torch_data.Dataset):
inds = np.array(inds, dtype=np.int64) inds = np.array(inds, dtype=np.int64)
return inds return inds
def reformat_bbox(self, outputs, out=None): def format_results(self,
outputs,
pklfile_prefix=None,
submission_prefix=None):
if pklfile_prefix is None:
tmp_dir = tempfile.TemporaryDirectory()
pklfile_prefix = osp.join(tmp_dir.name, 'results')
else:
tmp_dir = None
if not isinstance(outputs[0][0], dict): if not isinstance(outputs[0][0], dict):
sample_idx = [ sample_idx = [
info['image']['image_idx'] for info in self.kitti_infos info['image']['image_idx'] for info in self.kitti_infos
] ]
result_files = self.bbox2result_kitti2d(outputs, self.class_names, result_files = self.bbox2result_kitti2d(outputs, self.class_names,
sample_idx, out) sample_idx, pklfile_prefix,
submission_prefix)
else: else:
result_files = self.bbox2result_kitti(outputs, self.class_names, result_files = self.bbox2result_kitti(outputs, self.class_names,
out) pklfile_prefix,
submission_prefix)
return result_files return result_files
def evaluate(self, result_files, eval_types=None): def evaluate(self,
results,
metric=None,
logger=None,
pklfile_prefix=None,
submission_prefix=None,
result_names=['pts_bbox']):
"""Evaluation in KITTI protocol.
Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
pklfile_prefix (str | None): The prefix of pkl files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
submission_prefix (str | None): The prefix of submission datas.
If not specified, the submission data will not be generated.
Returns:
dict[str: float]
"""
result_files, tmp_dir = self.format_results(results, pklfile_prefix)
from mmdet3d.core.evaluation import kitti_eval from mmdet3d.core.evaluation import kitti_eval
gt_annos = [info['annos'] for info in self.kitti_infos] gt_annos = [info['annos'] for info in self.kitti_infos]
if eval_types == 'img_bbox': if metric == 'img_bbox':
ap_result_str, ap_dict = kitti_eval( ap_result_str, ap_dict = kitti_eval(
gt_annos, result_files, self.class_names, eval_types=['bbox']) gt_annos, result_files, self.class_names, eval_types=['bbox'])
else: else:
ap_result_str, ap_dict = kitti_eval(gt_annos, result_files, ap_result_str, ap_dict = kitti_eval(gt_annos, result_files,
self.class_names) self.class_names)
return ap_result_str, ap_dict
def bbox2result_kitti(self, net_outputs, class_names, out=None): if tmp_dir is not None:
if out: tmp_dir.cleanup()
output_dir = out[:-4] if out.endswith(('.pkl', '.pickle')) else out return ap_dict
result_dir = output_dir + '/data'
mmcv.mkdir_or_exist(result_dir) def bbox2result_kitti(self,
net_outputs,
class_names,
pklfile_prefix=None,
submission_prefix=None):
if submission_prefix is not None:
mmcv.mkdir_or_exist(submission_prefix)
det_annos = [] det_annos = []
print('Converting prediction to KITTI format') print('Converting prediction to KITTI format')
...@@ -346,9 +386,9 @@ class KittiDataset(torch_data.Dataset): ...@@ -346,9 +386,9 @@ class KittiDataset(torch_data.Dataset):
anno = {k: np.stack(v) for k, v in anno.items()} anno = {k: np.stack(v) for k, v in anno.items()}
annos.append(anno) annos.append(anno)
if out: if submission_prefix is not None:
cur_det_file = result_dir + '/%06d.txt' % sample_idx curr_file = f'{submission_prefix}/{sample_idx:06d}.txt'
with open(cur_det_file, 'w') as f: with open(curr_file, 'w') as f:
bbox = anno['bbox'] bbox = anno['bbox']
loc = anno['location'] loc = anno['location']
dims = anno['dimensions'] # lhw -> hwl dims = anno['dimensions'] # lhw -> hwl
...@@ -386,9 +426,9 @@ class KittiDataset(torch_data.Dataset): ...@@ -386,9 +426,9 @@ class KittiDataset(torch_data.Dataset):
det_annos += annos det_annos += annos
if out: if pklfile_prefix is not None:
if not out.endswith(('.pkl', '.pickle')): if not pklfile_prefix.endswith(('.pkl', '.pickle')):
out = '{}.pkl'.format(out) out = f'{pklfile_prefix}.pkl'
mmcv.dump(det_annos, out) mmcv.dump(det_annos, out)
print('Result is saved to %s' % out) print('Result is saved to %s' % out)
...@@ -398,7 +438,8 @@ class KittiDataset(torch_data.Dataset): ...@@ -398,7 +438,8 @@ class KittiDataset(torch_data.Dataset):
net_outputs, net_outputs,
class_names, class_names,
sample_ids, sample_ids,
out=None): pklfile_prefix=None,
submission_prefix=None):
"""Convert results to kitti format for evaluation and test submission """Convert results to kitti format for evaluation and test submission
Args: Args:
...@@ -406,6 +447,8 @@ class KittiDataset(torch_data.Dataset): ...@@ -406,6 +447,8 @@ class KittiDataset(torch_data.Dataset):
class_nanes (List[String]): A list of class names class_nanes (List[String]): A list of class names
sample_idx (List[Int]): A list of samples' index, sample_idx (List[Int]): A list of samples' index,
should have the same length as net_outputs. should have the same length as net_outputs.
pklfile_prefix (str | None): The prefix of pkl file.
submission_prefix (str | None): The prefix of submission file.
Return: Return:
List([dict]): A list of dict have the kitti format List([dict]): A list of dict have the kitti format
...@@ -469,17 +512,20 @@ class KittiDataset(torch_data.Dataset): ...@@ -469,17 +512,20 @@ class KittiDataset(torch_data.Dataset):
[sample_idx] * num_example, dtype=np.int64) [sample_idx] * num_example, dtype=np.int64)
det_annos += annos det_annos += annos
if out: if pklfile_prefix is not None:
# save file in pkl format
pklfile_path = (
pklfile_prefix[:-4] if pklfile_prefix.endswith(
('.pkl', '.pickle')) else pklfile_prefix)
mmcv.dump(det_annos, pklfile_path)
if submission_prefix is not None:
# save file in submission format # save file in submission format
output_dir = out[:-4] if out.endswith(('.pkl', '.pickle')) else out mmcv.mkdir_or_exist(submission_prefix)
result_dir = output_dir + '/data' print(f'Saving KITTI submission to {submission_prefix}')
mmcv.mkdir_or_exist(result_dir)
out = '{}.pkl'.format(result_dir)
mmcv.dump(det_annos, out)
print('Result is saved to {}'.format(out))
for i, anno in enumerate(det_annos): for i, anno in enumerate(det_annos):
sample_idx = sample_ids[i] sample_idx = sample_ids[i]
cur_det_file = result_dir + '/%06d.txt' % sample_idx cur_det_file = f'{submission_prefix}/{sample_idx:06d}.txt'
with open(cur_det_file, 'w') as f: with open(cur_det_file, 'w') as f:
bbox = anno['bbox'] bbox = anno['bbox']
loc = anno['location'] loc = anno['location']
...@@ -497,7 +543,7 @@ class KittiDataset(torch_data.Dataset): ...@@ -497,7 +543,7 @@ class KittiDataset(torch_data.Dataset):
anno['score'][idx]), anno['score'][idx]),
file=f, file=f,
) )
print('Result is saved to {}'.format(result_dir)) print('Result is saved to {}'.format(submission_prefix))
return det_annos return det_annos
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment