Commit 49f06039 authored by zhangwenwei's avatar zhangwenwei
Browse files

Fix bug caused by evalhook API change

parent 148fea12
......@@ -191,7 +191,7 @@ log_config = dict(
# yapf:enable
# runtime settings
total_epochs = 80
dist_params = dict(backend='nccl')
dist_params = dict(backend='nccl', port=29511)
log_level = 'INFO'
work_dir = './work_dirs/sec_secfpn_80e'
load_from = None
......
import copy
import os
import pickle
import os.path as osp
import tempfile
import mmcv
import numpy as np
......@@ -44,7 +45,7 @@ class KittiDataset(torch_data.Dataset):
self.ann_file = ann_file
with open(ann_file, 'rb') as f:
self.kitti_infos = pickle.load(f)
self.kitti_infos = mmcv.load(f)
# set group flag for the sampler
if not self.test_mode:
......@@ -262,34 +263,73 @@ class KittiDataset(torch_data.Dataset):
inds = np.array(inds, dtype=np.int64)
return inds
def reformat_bbox(self, outputs, out=None):
def format_results(self,
outputs,
pklfile_prefix=None,
submission_prefix=None):
if pklfile_prefix is None:
tmp_dir = tempfile.TemporaryDirectory()
pklfile_prefix = osp.join(tmp_dir.name, 'results')
else:
tmp_dir = None
if not isinstance(outputs[0][0], dict):
sample_idx = [
info['image']['image_idx'] for info in self.kitti_infos
]
result_files = self.bbox2result_kitti2d(outputs, self.class_names,
sample_idx, out)
sample_idx, pklfile_prefix,
submission_prefix)
else:
result_files = self.bbox2result_kitti(outputs, self.class_names,
out)
pklfile_prefix,
submission_prefix)
return result_files
def evaluate(self, result_files, eval_types=None):
def evaluate(self,
results,
metric=None,
logger=None,
pklfile_prefix=None,
submission_prefix=None,
result_names=['pts_bbox']):
"""Evaluation in KITTI protocol.
Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
pklfile_prefix (str | None): The prefix of pkl files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
submission_prefix (str | None): The prefix of submission datas.
If not specified, the submission data will not be generated.
Returns:
dict[str: float]
"""
result_files, tmp_dir = self.format_results(results, pklfile_prefix)
from mmdet3d.core.evaluation import kitti_eval
gt_annos = [info['annos'] for info in self.kitti_infos]
if eval_types == 'img_bbox':
if metric == 'img_bbox':
ap_result_str, ap_dict = kitti_eval(
gt_annos, result_files, self.class_names, eval_types=['bbox'])
else:
ap_result_str, ap_dict = kitti_eval(gt_annos, result_files,
self.class_names)
return ap_result_str, ap_dict
def bbox2result_kitti(self, net_outputs, class_names, out=None):
if out:
output_dir = out[:-4] if out.endswith(('.pkl', '.pickle')) else out
result_dir = output_dir + '/data'
mmcv.mkdir_or_exist(result_dir)
if tmp_dir is not None:
tmp_dir.cleanup()
return ap_dict
def bbox2result_kitti(self,
net_outputs,
class_names,
pklfile_prefix=None,
submission_prefix=None):
if submission_prefix is not None:
mmcv.mkdir_or_exist(submission_prefix)
det_annos = []
print('Converting prediction to KITTI format')
......@@ -346,9 +386,9 @@ class KittiDataset(torch_data.Dataset):
anno = {k: np.stack(v) for k, v in anno.items()}
annos.append(anno)
if out:
cur_det_file = result_dir + '/%06d.txt' % sample_idx
with open(cur_det_file, 'w') as f:
if submission_prefix is not None:
curr_file = f'{submission_prefix}/{sample_idx:06d}.txt'
with open(curr_file, 'w') as f:
bbox = anno['bbox']
loc = anno['location']
dims = anno['dimensions'] # lhw -> hwl
......@@ -386,9 +426,9 @@ class KittiDataset(torch_data.Dataset):
det_annos += annos
if out:
if not out.endswith(('.pkl', '.pickle')):
out = '{}.pkl'.format(out)
if pklfile_prefix is not None:
if not pklfile_prefix.endswith(('.pkl', '.pickle')):
out = f'{pklfile_prefix}.pkl'
mmcv.dump(det_annos, out)
print('Result is saved to %s' % out)
......@@ -398,7 +438,8 @@ class KittiDataset(torch_data.Dataset):
net_outputs,
class_names,
sample_ids,
out=None):
pklfile_prefix=None,
submission_prefix=None):
"""Convert results to kitti format for evaluation and test submission
Args:
......@@ -406,6 +447,8 @@ class KittiDataset(torch_data.Dataset):
class_nanes (List[String]): A list of class names
sample_idx (List[Int]): A list of samples' index,
should have the same length as net_outputs.
pklfile_prefix (str | None): The prefix of pkl file.
submission_prefix (str | None): The prefix of submission file.
Return:
List([dict]): A list of dict have the kitti format
......@@ -469,17 +512,20 @@ class KittiDataset(torch_data.Dataset):
[sample_idx] * num_example, dtype=np.int64)
det_annos += annos
if out:
if pklfile_prefix is not None:
# save file in pkl format
pklfile_path = (
pklfile_prefix[:-4] if pklfile_prefix.endswith(
('.pkl', '.pickle')) else pklfile_prefix)
mmcv.dump(det_annos, pklfile_path)
if submission_prefix is not None:
# save file in submission format
output_dir = out[:-4] if out.endswith(('.pkl', '.pickle')) else out
result_dir = output_dir + '/data'
mmcv.mkdir_or_exist(result_dir)
out = '{}.pkl'.format(result_dir)
mmcv.dump(det_annos, out)
print('Result is saved to {}'.format(out))
mmcv.mkdir_or_exist(submission_prefix)
print(f'Saving KITTI submission to {submission_prefix}')
for i, anno in enumerate(det_annos):
sample_idx = sample_ids[i]
cur_det_file = result_dir + '/%06d.txt' % sample_idx
cur_det_file = f'{submission_prefix}/{sample_idx:06d}.txt'
with open(cur_det_file, 'w') as f:
bbox = anno['bbox']
loc = anno['location']
......@@ -497,7 +543,7 @@ class KittiDataset(torch_data.Dataset):
anno['score'][idx]),
file=f,
)
print('Result is saved to {}'.format(result_dir))
print('Result is saved to {}'.format(submission_prefix))
return det_annos
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment