"...text-generation-inference.git" did not exist on "b70ae0969f11bae03a3c6194fc8c592a1d8a65b3"
Commit b3b20620 authored by zhangwenwei's avatar zhangwenwei
Browse files

Merge branch 'fix-kitti-eval' into 'master'

Fix bug caused by evalhook API change

See merge request open-mmlab/mmdet.3d!6
parents 4eca6606 2a14897e
...@@ -194,7 +194,7 @@ log_config = dict( ...@@ -194,7 +194,7 @@ log_config = dict(
# yapf:enable # yapf:enable
# runtime settings # runtime settings
total_epochs = 80 total_epochs = 80
dist_params = dict(backend='nccl') dist_params = dict(backend='nccl', port=29511)
log_level = 'INFO' log_level = 'INFO'
work_dir = './work_dirs/sec_secfpn_80e' work_dir = './work_dirs/sec_secfpn_80e'
load_from = None load_from = None
......
...@@ -21,9 +21,9 @@ def corners_nd(dims, origin=0.5): ...@@ -21,9 +21,9 @@ def corners_nd(dims, origin=0.5):
where x0 < x1, y0 < y1, z0 < z1 where x0 < x1, y0 < y1, z0 < z1
""" """
ndim = int(dims.shape[1]) ndim = int(dims.shape[1])
corners_norm = np.stack( corners_norm = torch.from_numpy(
np.unravel_index(np.arange(2**ndim), [2] * ndim), np.stack(np.unravel_index(np.arange(2**ndim), [2] * ndim), axis=1)).to(
axis=1).astype(dims.dtype) device=dims.device, dtype=dims.dtype)
# now corners_norm has format: (2d) x0y0, x0y1, x1y0, x1y1 # now corners_norm has format: (2d) x0y0, x0y1, x1y0, x1y1
# (3d) x0y0z0, x0y0z1, x0y1z0, x0y1z1, x1y0z0, x1y0z1, x1y1z0, x1y1z1 # (3d) x0y0z0, x0y0z1, x0y1z0, x0y1z1, x1y0z0, x1y0z1, x1y1z0, x1y1z1
# so need to convert to a format which is convenient to do other computing. # so need to convert to a format which is convenient to do other computing.
...@@ -34,7 +34,7 @@ def corners_nd(dims, origin=0.5): ...@@ -34,7 +34,7 @@ def corners_nd(dims, origin=0.5):
corners_norm = corners_norm[[0, 1, 3, 2]] corners_norm = corners_norm[[0, 1, 3, 2]]
elif ndim == 3: elif ndim == 3:
corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]] corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]]
corners_norm = corners_norm - np.array(origin, dtype=dims.dtype) corners_norm = corners_norm - dims.new_tensor(origin)
corners = dims.reshape([-1, 1, ndim]) * corners_norm.reshape( corners = dims.reshape([-1, 1, ndim]) * corners_norm.reshape(
[1, 2**ndim, ndim]) [1, 2**ndim, ndim])
return corners return corners
......
import copy import copy
import os import os
import pickle import os.path as osp
import tempfile
import mmcv import mmcv
import numpy as np import numpy as np
...@@ -43,8 +44,7 @@ class KittiDataset(torch_data.Dataset): ...@@ -43,8 +44,7 @@ class KittiDataset(torch_data.Dataset):
self.pcd_limit_range = [0, -40, -3, 70.4, 40, 0.0] self.pcd_limit_range = [0, -40, -3, 70.4, 40, 0.0]
self.ann_file = ann_file self.ann_file = ann_file
with open(ann_file, 'rb') as f: self.kitti_infos = mmcv.load(ann_file)
self.kitti_infos = pickle.load(f)
# set group flag for the sampler # set group flag for the sampler
if not self.test_mode: if not self.test_mode:
...@@ -262,37 +262,76 @@ class KittiDataset(torch_data.Dataset): ...@@ -262,37 +262,76 @@ class KittiDataset(torch_data.Dataset):
inds = np.array(inds, dtype=np.int64) inds = np.array(inds, dtype=np.int64)
return inds return inds
def reformat_bbox(self, outputs, out=None): def format_results(self,
outputs,
pklfile_prefix=None,
submission_prefix=None):
if pklfile_prefix is None:
tmp_dir = tempfile.TemporaryDirectory()
pklfile_prefix = osp.join(tmp_dir.name, 'results')
else:
tmp_dir = None
if not isinstance(outputs[0][0], dict): if not isinstance(outputs[0][0], dict):
sample_idx = [ sample_idx = [
info['image']['image_idx'] for info in self.kitti_infos info['image']['image_idx'] for info in self.kitti_infos
] ]
result_files = self.bbox2result_kitti2d(outputs, self.class_names, result_files = self.bbox2result_kitti2d(outputs, self.class_names,
sample_idx, out) sample_idx, pklfile_prefix,
submission_prefix)
else: else:
result_files = self.bbox2result_kitti(outputs, self.class_names, result_files = self.bbox2result_kitti(outputs, self.class_names,
out) pklfile_prefix,
return result_files submission_prefix)
return result_files, tmp_dir
def evaluate(self,
results,
metric=None,
logger=None,
pklfile_prefix=None,
submission_prefix=None,
result_names=['pts_bbox']):
"""Evaluation in KITTI protocol.
def evaluate(self, result_files, logger=None, eval_types=None): Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
pklfile_prefix (str | None): The prefix of pkl files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
submission_prefix (str | None): The prefix of submission datas.
If not specified, the submission data will not be generated.
Returns:
dict[str: float]
"""
result_files, tmp_dir = self.format_results(results, pklfile_prefix)
from mmdet3d.core.evaluation import kitti_eval from mmdet3d.core.evaluation import kitti_eval
gt_annos = [info['annos'] for info in self.kitti_infos] gt_annos = [info['annos'] for info in self.kitti_infos]
if eval_types == 'img_bbox': if metric == 'img_bbox':
ap_result_str, ap_dict = kitti_eval( ap_result_str, ap_dict = kitti_eval(
gt_annos, result_files, self.class_names, eval_types=['bbox']) gt_annos, result_files, self.class_names, eval_types=['bbox'])
else: else:
ap_result_str, ap_dict = kitti_eval(gt_annos, result_files, ap_result_str, ap_dict = kitti_eval(gt_annos, result_files,
self.class_names) self.class_names)
return ap_dict
def bbox2result_kitti(self, net_outputs, class_names, out=None): if tmp_dir is not None:
if out: tmp_dir.cleanup()
output_dir = out[:-4] if out.endswith(('.pkl', '.pickle')) else out return ap_dict, tmp_dir
result_dir = output_dir + '/data'
mmcv.mkdir_or_exist(result_dir) def bbox2result_kitti(self,
net_outputs,
class_names,
pklfile_prefix=None,
submission_prefix=None):
if submission_prefix is not None:
mmcv.mkdir_or_exist(submission_prefix)
det_annos = [] det_annos = []
print('Converting prediction to KITTI format') print('\nConverting prediction to KITTI format')
for idx, pred_dicts in enumerate( for idx, pred_dicts in enumerate(
mmcv.track_iter_progress(net_outputs)): mmcv.track_iter_progress(net_outputs)):
annos = [] annos = []
...@@ -346,9 +385,9 @@ class KittiDataset(torch_data.Dataset): ...@@ -346,9 +385,9 @@ class KittiDataset(torch_data.Dataset):
anno = {k: np.stack(v) for k, v in anno.items()} anno = {k: np.stack(v) for k, v in anno.items()}
annos.append(anno) annos.append(anno)
if out: if submission_prefix is not None:
cur_det_file = result_dir + '/%06d.txt' % sample_idx curr_file = f'{submission_prefix}/{sample_idx:06d}.txt'
with open(cur_det_file, 'w') as f: with open(curr_file, 'w') as f:
bbox = anno['bbox'] bbox = anno['bbox']
loc = anno['location'] loc = anno['location']
dims = anno['dimensions'] # lhw -> hwl dims = anno['dimensions'] # lhw -> hwl
...@@ -386,9 +425,9 @@ class KittiDataset(torch_data.Dataset): ...@@ -386,9 +425,9 @@ class KittiDataset(torch_data.Dataset):
det_annos += annos det_annos += annos
if out: if pklfile_prefix is not None:
if not out.endswith(('.pkl', '.pickle')): if not pklfile_prefix.endswith(('.pkl', '.pickle')):
out = '{}.pkl'.format(out) out = f'{pklfile_prefix}.pkl'
mmcv.dump(det_annos, out) mmcv.dump(det_annos, out)
print('Result is saved to %s' % out) print('Result is saved to %s' % out)
...@@ -398,7 +437,8 @@ class KittiDataset(torch_data.Dataset): ...@@ -398,7 +437,8 @@ class KittiDataset(torch_data.Dataset):
net_outputs, net_outputs,
class_names, class_names,
sample_ids, sample_ids,
out=None): pklfile_prefix=None,
submission_prefix=None):
"""Convert results to kitti format for evaluation and test submission """Convert results to kitti format for evaluation and test submission
Args: Args:
...@@ -406,6 +446,8 @@ class KittiDataset(torch_data.Dataset): ...@@ -406,6 +446,8 @@ class KittiDataset(torch_data.Dataset):
class_nanes (List[String]): A list of class names class_nanes (List[String]): A list of class names
sample_idx (List[Int]): A list of samples' index, sample_idx (List[Int]): A list of samples' index,
should have the same length as net_outputs. should have the same length as net_outputs.
pklfile_prefix (str | None): The prefix of pkl file.
submission_prefix (str | None): The prefix of submission file.
Return: Return:
List([dict]): A list of dict have the kitti format List([dict]): A list of dict have the kitti format
...@@ -469,17 +511,20 @@ class KittiDataset(torch_data.Dataset): ...@@ -469,17 +511,20 @@ class KittiDataset(torch_data.Dataset):
[sample_idx] * num_example, dtype=np.int64) [sample_idx] * num_example, dtype=np.int64)
det_annos += annos det_annos += annos
if out: if pklfile_prefix is not None:
# save file in pkl format
pklfile_path = (
pklfile_prefix[:-4] if pklfile_prefix.endswith(
('.pkl', '.pickle')) else pklfile_prefix)
mmcv.dump(det_annos, pklfile_path)
if submission_prefix is not None:
# save file in submission format # save file in submission format
output_dir = out[:-4] if out.endswith(('.pkl', '.pickle')) else out mmcv.mkdir_or_exist(submission_prefix)
result_dir = output_dir + '/data' print(f'Saving KITTI submission to {submission_prefix}')
mmcv.mkdir_or_exist(result_dir)
out = '{}.pkl'.format(result_dir)
mmcv.dump(det_annos, out)
print('Result is saved to {}'.format(out))
for i, anno in enumerate(det_annos): for i, anno in enumerate(det_annos):
sample_idx = sample_ids[i] sample_idx = sample_ids[i]
cur_det_file = result_dir + '/%06d.txt' % sample_idx cur_det_file = f'{submission_prefix}/{sample_idx:06d}.txt'
with open(cur_det_file, 'w') as f: with open(cur_det_file, 'w') as f:
bbox = anno['bbox'] bbox = anno['bbox']
loc = anno['location'] loc = anno['location']
...@@ -497,7 +542,7 @@ class KittiDataset(torch_data.Dataset): ...@@ -497,7 +542,7 @@ class KittiDataset(torch_data.Dataset):
anno['score'][idx]), anno['score'][idx]),
file=f, file=f,
) )
print('Result is saved to {}'.format(result_dir)) print('Result is saved to {}'.format(submission_prefix))
return det_annos return det_annos
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment