Commit cbb549aa authored by liyinhao

change indoor eval

parent 729f65c9
@@ -4,45 +4,50 @@ import torch
 from mmdet3d.core.bbox.iou_calculators.iou3d_calculator import bbox_overlaps_3d


-def voc_ap(rec, prec, use_07_metric=False):
-    """ Voc AP
-    Compute VOC AP given precision and recall.
+def average_precision(recalls, precisions, mode='area'):
+    """Calculate average precision (for single or multiple scales).

     Args:
-        rec (ndarray): Recall.
-        prec (ndarray): Precision.
-        use_07_metric (bool): Whether to use 07 metric.
-            Default: False.
+        recalls (ndarray): shape (num_scales, num_dets) or (num_dets, )
+        precisions (ndarray): shape (num_scales, num_dets) or (num_dets, )
+        mode (str): 'area' or '11points', 'area' means calculating the area
+            under precision-recall curve, '11points' means calculating
+            the average precision of recalls at [0, 0.1, ..., 1]

     Returns:
-        ap (float): VOC AP.
+        float or ndarray: calculated average precision
     """
-    if use_07_metric:
-        # 11 point metric
-        ap = 0.
-        for t in np.arange(0., 1.1, 0.1):
-            if np.sum(rec >= t) == 0:
-                p = 0
-            else:
-                p = np.max(prec[rec >= t])
-            ap = ap + p / 11.
+    no_scale = False
+    if recalls.ndim == 1:
+        no_scale = True
+        recalls = recalls[np.newaxis, :]
+        precisions = precisions[np.newaxis, :]
+    assert recalls.shape == precisions.shape and recalls.ndim == 2
+    num_scales = recalls.shape[0]
+    ap = np.zeros(num_scales, dtype=np.float32)
+    if mode == 'area':
+        zeros = np.zeros((num_scales, 1), dtype=recalls.dtype)
+        ones = np.ones((num_scales, 1), dtype=recalls.dtype)
+        mrec = np.hstack((zeros, recalls, ones))
+        mpre = np.hstack((zeros, precisions, zeros))
+        for i in range(mpre.shape[1] - 1, 0, -1):
+            mpre[:, i - 1] = np.maximum(mpre[:, i - 1], mpre[:, i])
+        for i in range(num_scales):
+            ind = np.where(mrec[i, 1:] != mrec[i, :-1])[0]
+            ap[i] = np.sum(
+                (mrec[i, ind + 1] - mrec[i, ind]) * mpre[i, ind + 1])
+    elif mode == '11points':
+        for i in range(num_scales):
+            for thr in np.arange(0, 1 + 1e-3, 0.1):
+                precs = precisions[i, recalls[i, :] >= thr]
+                prec = precs.max() if precs.size > 0 else 0
+                ap[i] += prec
+        ap /= 11
     else:
-        # correct AP calculation
-        # first append sentinel values at the end
-        mrec = np.concatenate(([0.], rec, [1.]))
-        mpre = np.concatenate(([0.], prec, [0.]))
-        # compute the precision envelope
-        for i in range(mpre.size - 1, 0, -1):
-            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
-        # to calculate area under PR curve, look for points
-        # where X axis (recall) changes value
-        i = np.where(mrec[1:] != mrec[:-1])[0]
-        # and sum (\Delta recall) * prec
-        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
+        raise ValueError(
+            'Unrecognized mode, only "area" and "11points" are supported')
+    if no_scale:
+        ap = ap[0]
     return ap
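
For context, a minimal usage sketch of the new `average_precision` helper on a toy precision-recall curve; the arrays are illustrative and not part of the commit:

import numpy as np

# Toy PR curve: recall rises as detections accumulate, precision decays.
recalls = np.array([0.25, 0.5, 0.75, 1.0])
precisions = np.array([1.0, 1.0, 0.67, 0.5])

# 'area' (the default) integrates the precision envelope over recall.
ap_area = average_precision(recalls, precisions, mode='area')
# '11points' averages the best precision at recalls 0, 0.1, ..., 1.0.
ap_11pt = average_precision(recalls, precisions, mode='11points')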
@@ -65,18 +70,19 @@ def get_iou_gpu(bb1, bb2):
     return iou3d.cpu().numpy()


-def eval_det_cls(pred, gt, ovthresh=None, use_07_metric=False):
-    """ Generic functions to compute precision/recall for object detection
+def eval_det_cls(pred, gt, ovthresh=None):
+    """Generic functions to compute precision/recall for object detection
     for a single class.

-    Input:
-        pred: map of {img_id: [(bbox, score)]} where bbox is numpy array
-        gt: map of {img_id: [bbox]}
-        ovthresh: a list, iou threshold
-        use_07_metric: bool, if True use VOC07 11 point method
-    Output:
-        rec: numpy array of length nd
-        prec: numpy array of length nd
-        ap: scalar, average precision
+    Args:
+        pred (dict): map of {img_id: [(bbox, score)]} where bbox is numpy array
+        gt (dict): map of {img_id: [bbox]}
+        ovthresh (List[float]): a list, iou threshold
+
+    Return:
+        rec (ndarray): numpy array of length nd
+        prec (ndarray): numpy array of length nd
+        ap (float): scalar, average precision
     """
     # construct gt objects
@@ -164,23 +170,20 @@ def eval_det_cls(pred, gt, ovthresh=None, use_07_metric=False):
         # avoid divide by zero in case the first detection matches a difficult
         # ground truth
         prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
-        ap = voc_ap(rec, prec, use_07_metric)
+        ap = average_precision(rec, prec)
         ret.append((rec, prec, ap))
     return ret
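
A hypothetical single-class call (toy values; assumes a CUDA-capable environment, since the matching step elided above uses get_iou_gpu); boxes follow the 7-dim (x, y, z, dx, dy, dz, yaw) layout expected by bbox_overlaps_3d:

import numpy as np

# One scan (img_id 0): two scored predictions against a single GT box.
pred = {0: [(np.array([0., 0., 0., 1., 1., 1., 0.]), 0.9),
            (np.array([5., 5., 5., 1., 1., 1., 0.]), 0.4)]}
gt = {0: [np.array([0., 0., 0., 1., 1., 1., 0.])]}
# Returns one (rec, prec, ap) triple per IoU threshold in ovthresh.
ret = eval_det_cls(pred, gt, ovthresh=[0.25, 0.5])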

 def eval_det_cls_wrapper(arguments):
-    pred, gt, ovthresh, use_07_metric = arguments
-    ret = eval_det_cls(pred, gt, ovthresh, use_07_metric)
+    pred, gt, ovthresh = arguments
+    ret = eval_det_cls(pred, gt, ovthresh)
     return ret


-def eval_det_multiprocessing(pred_all,
-                             gt_all,
-                             ovthresh=None,
-                             use_07_metric=False):
-    """Evaluate Detection Multiprocessing.
+def eval_map_rec(det_infos, gt_infos, ovthresh=None):
+    """Evaluate mAP and Recall.

     Generic functions to compute precision/recall for object detection
     for multiple classes.
@@ -190,8 +193,6 @@ def eval_det_multiprocessing(pred_all,
         gt_all (dict): map of {img_id: [(classname, bbox)]}.
         ovthresh (List[float]): iou threshold.
             Default: None.
-        use_07_metric (bool): if true use VOC07 11 point method.
-            Default: False.
         get_iou_func (func): The function to get iou.
             Default: get_iou_gpu.
@@ -200,29 +201,46 @@ def eval_det_multiprocessing(pred_all,
         prec (dict): {classname: prec_all}.
         ap (dict): {classname: scalar}.
     """
+    pred_all = {}
+    gt_all = {}
+    scan_cnt = 0
+    for batch_pred_map_cls in det_infos:
+        for i in range(len(batch_pred_map_cls)):
+            pred_all[scan_cnt] = batch_pred_map_cls[i]
+            scan_cnt += 1
+    # cache gt infos
+    scan_cnt = 0
+    for gt_info in gt_infos:
+        cur_gt = list()
+        for n in range(gt_info['gt_num']):
+            cur_gt.append(
+                (gt_info['class'][n], gt_info['gt_boxes_upright_depth'][n]))
+        gt_all[scan_cnt] = cur_gt
+        scan_cnt += 1
     pred = {}  # map {classname: pred}
     gt = {}  # map {classname: gt}
     for img_id in pred_all.keys():
-        for classname, bbox, score in pred_all[img_id]:
-            if classname not in pred:
-                pred[classname] = {}
-            if img_id not in pred[classname]:
-                pred[classname][img_id] = []
-            if classname not in gt:
-                gt[classname] = {}
-            if img_id not in gt[classname]:
-                gt[classname][img_id] = []
-            pred[classname][img_id].append((bbox, score))
+        for label, bbox, score in pred_all[img_id]:
+            if label not in pred:
+                pred[int(label)] = {}
+            if img_id not in pred[label]:
+                pred[int(label)][img_id] = []
+            if label not in gt:
+                gt[int(label)] = {}
+            if img_id not in gt[label]:
+                gt[int(label)][img_id] = []
+            pred[int(label)][img_id].append((bbox, score))
     for img_id in gt_all.keys():
-        for classname, bbox in gt_all[img_id]:
-            if classname not in gt:
-                gt[classname] = {}
-            if img_id not in gt[classname]:
-                gt[classname][img_id] = []
-            gt[classname][img_id].append(bbox)
+        for label, bbox in gt_all[img_id]:
+            if label not in gt:
+                gt[label] = {}
+            if img_id not in gt[label]:
+                gt[label][img_id] = []
+            gt[label][img_id].append(bbox)
     ret_values = []
-    args = [(pred[classname], gt[classname], ovthresh, use_07_metric)
+    args = [(pred[classname], gt[classname], ovthresh)
             for classname in gt.keys() if classname in pred]
     rec = [{} for i in ovthresh]
     prec = [{} for i in ovthresh]
@@ -230,101 +248,19 @@ def eval_det_multiprocessing(pred_all,
     for arg in args:
         ret_values.append(eval_det_cls_wrapper(arg))
-    for i, classname in enumerate(gt.keys()):
+    for i, label in enumerate(gt.keys()):
         for iou_idx, thresh in enumerate(ovthresh):
-            if classname in pred:
-                rec[iou_idx][classname], prec[iou_idx][classname], ap[iou_idx][
-                    classname] = ret_values[i][iou_idx]
+            if label in pred:
+                rec[iou_idx][label], prec[iou_idx][label], ap[iou_idx][
+                    label] = ret_values[i][iou_idx]
             else:
-                rec[iou_idx][classname] = 0
-                prec[iou_idx][classname] = 0
-                ap[iou_idx][classname] = 0
-    # print(classname, ap[classname])
+                rec[iou_idx][label] = 0
+                prec[iou_idx][label] = 0
+                ap[iou_idx][label] = 0
     return rec, prec, ap
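
To make the restructured interface concrete, a hypothetical call with one scan and one class; det_infos mirrors the batch nesting cached at the top of the function and gt_infos the annotation dicts (all values illustrative):

import numpy as np

# det_infos: list of batches, each a list of per-scan
# [(label, bbox, score), ...] prediction lists.
det_infos = [[[(0, np.array([0., 0., 0., 1., 1., 1., 0.]), 0.9)]]]
# gt_infos: one dict per scan with box count, labels and boxes.
gt_infos = [{
    'gt_num': 1,
    'class': np.array([0]),
    'gt_boxes_upright_depth': np.array([[0., 0., 0., 1., 1., 1., 0.]])
}]
rec, prec, ap = eval_map_rec(det_infos, gt_infos, ovthresh=[0.25, 0.5])
# Each return value is a list with one {label: value} dict per threshold,
# e.g. ap[0][0] is the AP of label 0 at IoU 0.25.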

-class APCalculator(object):
-    """AP Calculator.
-
-    Calculating Average Precision.
-
-    Args:
-        ap_iou_thresh (List[float]): a list,
-            which contains float between 0 and 1.0
-            IoU threshold to judge whether a prediction is positive.
-        class2type_map (dict): {class_int:class_name}.
-    """
-
-    def __init__(self, ap_iou_thresh=None, class2type_map=None):
-        self.ap_iou_thresh = ap_iou_thresh
-        self.class2type_map = class2type_map
-        self.reset()
-
-    def step(self, det_infos, gt_infos):
-        """ Step.
-
-        Accumulate one batch of prediction and groundtruth.
-
-        Args:
-            batch_pred_map_cls (List[List]): a list of lists
-                [[(pred_cls, pred_box_params, score),...],...].
-            batch_gt_map_cls (List[List]): a list of lists
-                [[(gt_cls, gt_box_params),...],...].
-        """
-        # cache pred infos
-        for batch_pred_map_cls in det_infos:
-            for i in range(len(batch_pred_map_cls)):
-                self.pred_map_cls[self.scan_cnt] = batch_pred_map_cls[i]
-                self.scan_cnt += 1
-        # cache gt infos
-        self.scan_cnt = 0
-        for gt_info in gt_infos:
-            cur_gt = list()
-            for n in range(gt_info['gt_num']):
-                cur_gt.append((gt_info['class'][n],
-                               gt_info['gt_boxes_upright_depth'][n]))
-            self.gt_map_cls[self.scan_cnt] = cur_gt
-            self.scan_cnt += 1
-
-    def compute_metrics(self):
-        recs, precs, aps = eval_det_multiprocessing(
-            self.pred_map_cls, self.gt_map_cls, ovthresh=self.ap_iou_thresh)
-        ret = []
-        for i, iou_thresh in enumerate(self.ap_iou_thresh):
-            ret_dict = {}
-            rec, _, ap = recs[i], precs[i], aps[i]
-            for key in sorted(ap.keys()):
-                clsname = self.class2type_map[
-                    key] if self.class2type_map else str(key)
-                ret_dict[f'{clsname}_AP_{int(iou_thresh * 100)}'] = ap[key]
-            ret_dict[f'mAP_{int(iou_thresh * 100)}'] = np.mean(
-                list(ap.values()))
-            rec_list = []
-            for key in sorted(ap.keys()):
-                clsname = self.class2type_map[
-                    key] if self.class2type_map else str(key)
-                try:
-                    ret_dict[
-                        f'{clsname}_recall_{int(iou_thresh * 100)}'] = rec[
-                            key][-1]
-                    rec_list.append(rec[key][-1])
-                except TypeError:
-                    ret_dict[f'{clsname}_recall_{int(iou_thresh * 100)}'] = 0
-                    rec_list.append(0)
-            ret_dict[f'AR_{int(iou_thresh * 100)}'] = np.mean(rec_list)
-            ret.append(ret_dict)
-        return ret
-
-    def reset(self):
-        self.gt_map_cls = {}  # {scan_id: [(classname, bbox)]}
-        self.pred_map_cls = {}  # {scan_id: [(classname, bbox, score)]}
-        self.scan_cnt = 0


 def boxes3d_depth_to_lidar(boxes3d, mid_to_bottom=True):
     """ Boxes3d Depth to Lidar.
@@ -345,7 +281,7 @@ def boxes3d_depth_to_lidar(boxes3d, mid_to_bottom=True):
     return boxes3d_lidar

-def indoor_eval(gt_annos, dt_annos, metric, class2type):
+def indoor_eval(gt_annos, dt_annos, metric, label2cat):
     """Scannet Evaluation.

     Evaluate the result of the detection.
@@ -354,11 +290,10 @@ def indoor_eval(gt_annos, dt_annos, metric, class2type):
         gt_annos (List): GT annotations.
         dt_annos (List): Detection annotations.
         metric (List[float]): AP IoU thresholds.
-        class2type (dict): {class: type}.
+        label2cat (dict): {label: cat}.

     Return:
-        result_str (str): Result string.
-        metrics_dict (dict): Result.
+        ret_dict (dict): Dict of results.
     """
     for gt_anno in gt_annos:
@@ -369,15 +304,20 @@ def indoor_eval(gt_annos, dt_annos, metric, class2type):
         if gt_anno['gt_boxes_upright_depth'].shape[-1] == 6:
             gt_anno['gt_boxes_upright_depth'] = np.pad(
                 bbox_lidar_bottom, ((0, 0), (0, 1)), 'constant')
-    ap_calculator = APCalculator(metric, class2type)
-    ap_calculator.step(dt_annos, gt_annos)
-    result_str = str()
-    result_str += 'mAP'
-    metrics_dict = {}
-    metrics = ap_calculator.compute_metrics()
+    rec, prec, ap = eval_map_rec(dt_annos, gt_annos, metric)
+    ret_dict = {}
     for i, iou_thresh in enumerate(metric):
-        metrics_tmp = metrics[i]
-        metrics_dict.update(metrics_tmp)
-        metric_result = metrics_dict[f'mAP_{int(iou_thresh * 100)}']
-        result_str += f'({iou_thresh:.2f}:{metric_result}'
-    return result_str, metrics_dict
+        for label in ap[i].keys():
+            ret_dict[f'{label2cat[label]}_AP_{int(iou_thresh * 100)}'] = ap[i][
+                label]
+        ret_dict[f'mAP_{int(iou_thresh * 100)}'] = sum(ap[i].values()) / len(
+            ap[i])
+        for label in rec[i].keys():
+            ret_dict[f'{label2cat[label]}_rec_{int(iou_thresh * 100)}'] = rec[
+                i][label]
+        ret_dict[f'mAR_{int(iou_thresh * 100)}'] = sum(rec[i].values()) / len(
+            rec[i])
+    return ret_dict
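
And a hypothetical end-to-end call showing the key scheme of the returned dict, reusing the toy det_infos/gt_infos from the previous sketch (label2cat is made up):

ret_dict = indoor_eval(gt_infos, det_infos, metric=[0.25, 0.5],
                       label2cat={0: 'table'})
# Keys per threshold: 'table_AP_25', 'mAP_25', 'table_rec_25', 'mAR_25',
# and likewise with the _50 suffix for IoU 0.5.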
@@ -133,9 +133,8 @@ class IndoorBaseDataset(torch_data.Dataset):
         from mmdet3d.core.evaluation import indoor_eval
         assert len(metric) > 0
         gt_annos = [copy.deepcopy(info['annos']) for info in self.infos]
-        ap_result_str, ap_dict = indoor_eval(gt_annos, results, metric,
-                                             self.label2cat)
-        return ap_dict
+        ret_dict = indoor_eval(gt_annos, results, metric, self.label2cat)
+        return ret_dict

     def __len__(self):
         return len(self.infos)
@@ -108,11 +108,11 @@ def test_evaluate():
     pred_boxes['scores'] = torch.Tensor([0.5, 1.0, 1.0, 1.0, 1.0]).cuda()
     results.append([pred_boxes])
     metric = [0.25, 0.5]
-    ap_dict = scannet_dataset.evaluate(results, metric)
-    table_average_precision_25 = ap_dict['table_AP_25']
-    window_average_precision_25 = ap_dict['window_AP_25']
-    counter_average_precision_25 = ap_dict['counter_AP_25']
-    curtain_average_precision_25 = ap_dict['curtain_AP_25']
+    ret_dict = scannet_dataset.evaluate(results, metric)
+    table_average_precision_25 = ret_dict['table_AP_25']
+    window_average_precision_25 = ret_dict['window_AP_25']
+    counter_average_precision_25 = ret_dict['counter_AP_25']
+    curtain_average_precision_25 = ret_dict['curtain_AP_25']
     assert abs(table_average_precision_25 - 0.3333) < 0.01
     assert abs(window_average_precision_25 - 1) < 0.01
     assert abs(counter_average_precision_25 - 1) < 0.01