Unverified Commit 47bab544 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Refactoring for mAP evaluation (#1889)

* refactoring mAP evaluation

* fix logger type
parent 4472c661
...@@ -76,25 +76,9 @@ class DistEvalHook(Hook): ...@@ -76,25 +76,9 @@ class DistEvalHook(Hook):
class DistEvalmAPHook(DistEvalHook): class DistEvalmAPHook(DistEvalHook):
def evaluate(self, runner, results): def evaluate(self, runner, results):
gt_bboxes = [] annotations = [
gt_labels = [] self.dataset.get_ann_info(i) for i in range(len(self.dataset))
gt_ignore = [] ]
for i in range(len(self.dataset)):
ann = self.dataset.get_ann_info(i)
bboxes = ann['bboxes']
labels = ann['labels']
if 'bboxes_ignore' in ann:
ignore = np.concatenate([
np.zeros(bboxes.shape[0], dtype=np.bool),
np.ones(ann['bboxes_ignore'].shape[0], dtype=np.bool)
])
gt_ignore.append(ignore)
bboxes = np.vstack([bboxes, ann['bboxes_ignore']])
labels = np.concatenate([labels, ann['labels_ignore']])
gt_bboxes.append(bboxes)
gt_labels.append(labels)
if not gt_ignore:
gt_ignore = None
# If the dataset is VOC2007, then use 11 points mAP evaluation. # If the dataset is VOC2007, then use 11 points mAP evaluation.
if hasattr(self.dataset, 'year') and self.dataset.year == 2007: if hasattr(self.dataset, 'year') and self.dataset.year == 2007:
ds_name = 'voc07' ds_name = 'voc07'
...@@ -102,13 +86,11 @@ class DistEvalmAPHook(DistEvalHook): ...@@ -102,13 +86,11 @@ class DistEvalmAPHook(DistEvalHook):
ds_name = self.dataset.CLASSES ds_name = self.dataset.CLASSES
mean_ap, eval_results = eval_map( mean_ap, eval_results = eval_map(
results, results,
gt_bboxes, annotations,
gt_labels,
gt_ignore=gt_ignore,
scale_ranges=None, scale_ranges=None,
iou_thr=0.5, iou_thr=0.5,
dataset=ds_name, dataset=ds_name,
print_summary=True) logger=runner.logger)
runner.log_buffer.output['mAP'] = mean_ap runner.log_buffer.output['mAP'] = mean_ap
runner.log_buffer.ready = True runner.log_buffer.ready = True
......
import logging
from multiprocessing import Pool
import mmcv import mmcv
import numpy as np import numpy as np
from terminaltables import AsciiTable from terminaltables import AsciiTable
...@@ -55,21 +58,33 @@ def average_precision(recalls, precisions, mode='area'): ...@@ -55,21 +58,33 @@ def average_precision(recalls, precisions, mode='area'):
def tpfp_imagenet(det_bboxes, def tpfp_imagenet(det_bboxes,
gt_bboxes, gt_bboxes,
gt_ignore, gt_bboxes_ignore=None,
default_iou_thr, default_iou_thr=0.5,
area_ranges=None): area_ranges=None):
"""Check if detected bboxes are true positive or false positive. """Check if detected bboxes are true positive or false positive.
Args: Args:
det_bbox (ndarray): the detected bbox det_bbox (ndarray): Detected bboxes of this image, of shape (m, 5).
gt_bboxes (ndarray): ground truth bboxes of this image gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
gt_ignore (ndarray): indicate if gts are ignored for evaluation or not gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image,
default_iou_thr (float): the iou thresholds for medium and large bboxes of shape (k, 4). Default: None
area_ranges (list or None): gt bbox area ranges default_iou_thr (float): IoU threshold to be considered as matched for
medium and large bboxes (small ones have special rules).
Default: 0.5.
area_ranges (list[tuple] | None): Range of bbox areas to be evaluated,
in the format [(min1, max1), (min2, max2), ...]. Default: None.
Returns: Returns:
tuple: two arrays (tp, fp) whose elements are 0 and 1 tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
each array is (num_scales, m).
""" """
# an indicator of ignored gts
gt_ignore_inds = np.concatenate(
(np.zeros(gt_bboxes.shape[0], dtype=np.bool),
np.ones(gt_bboxes_ignore.shape[0], dtype=np.bool)))
# stack gt_bboxes and gt_bboxes_ignore for convenience
gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
num_dets = det_bboxes.shape[0] num_dets = det_bboxes.shape[0]
num_gts = gt_bboxes.shape[0] num_gts = gt_bboxes.shape[0]
if area_ranges is None: if area_ranges is None:
...@@ -99,7 +114,7 @@ def tpfp_imagenet(det_bboxes, ...@@ -99,7 +114,7 @@ def tpfp_imagenet(det_bboxes,
gt_covered = np.zeros(num_gts, dtype=bool) gt_covered = np.zeros(num_gts, dtype=bool)
# if no area range is specified, gt_area_ignore is all False # if no area range is specified, gt_area_ignore is all False
if min_area is None: if min_area is None:
gt_area_ignore = np.zeros_like(gt_ignore, dtype=bool) gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
else: else:
gt_areas = gt_w * gt_h gt_areas = gt_w * gt_h
gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area) gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
...@@ -122,7 +137,8 @@ def tpfp_imagenet(det_bboxes, ...@@ -122,7 +137,8 @@ def tpfp_imagenet(det_bboxes,
# 4. it matches no gt but is beyond area range, tp = 0, fp = 0 # 4. it matches no gt but is beyond area range, tp = 0, fp = 0
if matched_gt >= 0: if matched_gt >= 0:
gt_covered[matched_gt] = 1 gt_covered[matched_gt] = 1
if not (gt_ignore[matched_gt] or gt_area_ignore[matched_gt]): if not (gt_ignore_inds[matched_gt]
or gt_area_ignore[matched_gt]):
tp[k, i] = 1 tp[k, i] = 1
elif min_area is None: elif min_area is None:
fp[k, i] = 1 fp[k, i] = 1
...@@ -134,18 +150,34 @@ def tpfp_imagenet(det_bboxes, ...@@ -134,18 +150,34 @@ def tpfp_imagenet(det_bboxes,
return tp, fp return tp, fp
def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None): def tpfp_default(det_bboxes,
gt_bboxes,
gt_bboxes_ignore=None,
iou_thr=0.5,
area_ranges=None):
"""Check if detected bboxes are true positive or false positive. """Check if detected bboxes are true positive or false positive.
Args: Args:
det_bbox (ndarray): the detected bbox det_bbox (ndarray): Detected bboxes of this image, of shape (m, 5).
gt_bboxes (ndarray): ground truth bboxes of this image gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
gt_ignore (ndarray): indicate if gts are ignored for evaluation or not gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image,
iou_thr (float): the iou thresholds of shape (k, 4). Default: None
iou_thr (float): IoU threshold to be considered as matched.
Default: 0.5.
area_ranges (list[tuple] | None): Range of bbox areas to be evaluated,
in the format [(min1, max1), (min2, max2), ...]. Default: None.
Returns: Returns:
tuple: (tp, fp), two arrays whose elements are 0 and 1 tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
each array is (num_scales, m).
""" """
# an indicator of ignored gts
gt_ignore_inds = np.concatenate(
(np.zeros(gt_bboxes.shape[0], dtype=np.bool),
np.ones(gt_bboxes_ignore.shape[0], dtype=np.bool)))
# stack gt_bboxes and gt_bboxes_ignore for convenience
gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
num_dets = det_bboxes.shape[0] num_dets = det_bboxes.shape[0]
num_gts = gt_bboxes.shape[0] num_gts = gt_bboxes.shape[0]
if area_ranges is None: if area_ranges is None:
...@@ -155,6 +187,7 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None): ...@@ -155,6 +187,7 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None):
# a certain scale # a certain scale
tp = np.zeros((num_scales, num_dets), dtype=np.float32) tp = np.zeros((num_scales, num_dets), dtype=np.float32)
fp = np.zeros((num_scales, num_dets), dtype=np.float32) fp = np.zeros((num_scales, num_dets), dtype=np.float32)
# if there is no gt bboxes in this image, then all det bboxes # if there is no gt bboxes in this image, then all det bboxes
# within area range are false positives # within area range are false positives
if gt_bboxes.shape[0] == 0: if gt_bboxes.shape[0] == 0:
...@@ -166,15 +199,19 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None): ...@@ -166,15 +199,19 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None):
for i, (min_area, max_area) in enumerate(area_ranges): for i, (min_area, max_area) in enumerate(area_ranges):
fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1 fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
return tp, fp return tp, fp
ious = bbox_overlaps(det_bboxes, gt_bboxes) ious = bbox_overlaps(det_bboxes, gt_bboxes)
# for each det, the max iou with all gts
ious_max = ious.max(axis=1) ious_max = ious.max(axis=1)
# for each det, which gt overlaps most with it
ious_argmax = ious.argmax(axis=1) ious_argmax = ious.argmax(axis=1)
# sort all dets in descending order by scores
sort_inds = np.argsort(-det_bboxes[:, -1]) sort_inds = np.argsort(-det_bboxes[:, -1])
for k, (min_area, max_area) in enumerate(area_ranges): for k, (min_area, max_area) in enumerate(area_ranges):
gt_covered = np.zeros(num_gts, dtype=bool) gt_covered = np.zeros(num_gts, dtype=bool)
# if no area range is specified, gt_area_ignore is all False # if no area range is specified, gt_area_ignore is all False
if min_area is None: if min_area is None:
gt_area_ignore = np.zeros_like(gt_ignore, dtype=bool) gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
else: else:
gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) * ( gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) * (
gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1) gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1)
...@@ -182,7 +219,8 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None): ...@@ -182,7 +219,8 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None):
for i in sort_inds: for i in sort_inds:
if ious_max[i] >= iou_thr: if ious_max[i] >= iou_thr:
matched_gt = ious_argmax[i] matched_gt = ious_argmax[i]
if not (gt_ignore[matched_gt] or gt_area_ignore[matched_gt]): if not (gt_ignore_inds[matched_gt]
or gt_area_ignore[matched_gt]):
if not gt_covered[matched_gt]: if not gt_covered[matched_gt]:
gt_covered[matched_gt] = True gt_covered[matched_gt] = True
tp[k, i] = 1 tp[k, i] = 1
...@@ -199,88 +237,109 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None): ...@@ -199,88 +237,109 @@ def tpfp_default(det_bboxes, gt_bboxes, gt_ignore, iou_thr, area_ranges=None):
return tp, fp return tp, fp
def get_cls_results(det_results, gt_bboxes, gt_labels, gt_ignore, class_id): def get_cls_results(det_results, annotations, class_id):
"""Get det results and gt information of a certain class.""" """Get det results and gt information of a certain class.
cls_dets = [det[class_id]
for det in det_results] # det bboxes of this class Args:
cls_gts = [] # gt bboxes of this class det_results (list[list]): Same as `eval_map()`.
cls_gt_ignore = [] annotations (list[dict]): Same as `eval_map()`.
for j in range(len(gt_bboxes)):
gt_bbox = gt_bboxes[j] Returns:
cls_inds = (gt_labels[j] == class_id + 1) tuple[list[np.ndarray]]: detected bboxes, gt bboxes, ignored gt bboxes
cls_gt = gt_bbox[cls_inds, :] if gt_bbox.shape[0] > 0 else gt_bbox """
cls_gts.append(cls_gt) cls_dets = [img_res[class_id] for img_res in det_results]
if gt_ignore is None: cls_gts = []
cls_gt_ignore.append(np.zeros(cls_gt.shape[0], dtype=np.int32)) cls_gts_ignore = []
for ann in annotations:
gt_inds = ann['labels'] == (class_id + 1)
cls_gts.append(ann['bboxes'][gt_inds, :])
if ann.get('labels_ignore', None) is not None:
ignore_inds = ann['labels_ignore'] == (class_id + 1)
cls_gts_ignore.append(ann['bboxes_ignore'][ignore_inds, :])
else: else:
cls_gt_ignore.append(gt_ignore[j][cls_inds]) cls_gts_ignore.append(np.array((0, 4), dtype=np.float32))
return cls_dets, cls_gts, cls_gt_ignore
return cls_dets, cls_gts, cls_gts_ignore
def eval_map(det_results, def eval_map(det_results,
gt_bboxes, annotations,
gt_labels,
gt_ignore=None,
scale_ranges=None, scale_ranges=None,
iou_thr=0.5, iou_thr=0.5,
dataset=None, dataset=None,
print_summary=True): logger='default',
nproc=4):
"""Evaluate mAP of a dataset. """Evaluate mAP of a dataset.
Args: Args:
det_results (list): a list of list, [[cls1_det, cls2_det, ...], ...] det_results (list[list]): [[cls1_det, cls2_det, ...], ...].
gt_bboxes (list): ground truth bboxes of each image, a list of K*4 The outer list indicates images, and the inner list indicates
array. per-class detected bboxes.
gt_labels (list): ground truth labels of each image, a list of K array annotations (list[dict]): Ground truth annotations where each item of
gt_ignore (list): gt ignore indicators of each image, a list of K array the list indicates an image. Keys of annotations are:
scale_ranges (list, optional): [(min1, max1), (min2, max2), ...] - "bboxes": numpy array of shape (n, 4)
iou_thr (float): IoU threshold - "labels": numpy array of shape (n, )
dataset (None or str or list): dataset name or dataset classes, there - "bboxes_ignore" (optional): numpy array of shape (k, 4)
are minor differences in metrics for different datsets, e.g. - "labels_ignore" (optional): numpy array of shape (k, )
"voc07", "imagenet_det", etc. scale_ranges (list[tuple] | None): Range of scales to be evaluated,
print_summary (bool): whether to print the mAP summary in the format [(min1, max1), (min2, max2), ...]. A range of
(32, 64) means the area range between (32**2, 64**2).
Default: None.
iou_thr (float): IoU threshold to be considered as matched.
Default: 0.5.
dataset (list[str] | str | None): Dataset name or dataset classes,
there are minor differences in metrics for different datsets, e.g.
"voc07", "imagenet_det", etc. Default: None.
logger (logging.Logger | 'print' | None): The way to print the mAP
summary. If a Logger is specified, then the summary will be logged
with `logger.info()`; if set to "print", then it will be simply
printed to stdout; if set to None, then no information will be
printed. Default: 'print'.
nproc (int): Processes used for computing TP and FP.
Default: 4.
Returns: Returns:
tuple: (mAP, [dict, dict, ...]) tuple: (mAP, [dict, dict, ...])
""" """
assert len(det_results) == len(gt_bboxes) == len(gt_labels) assert len(det_results) == len(annotations)
if gt_ignore is not None:
assert len(gt_ignore) == len(gt_labels) num_imgs = len(det_results)
for i in range(len(gt_ignore)): num_scales = len(scale_ranges) if scale_ranges is not None else 1
assert len(gt_labels[i]) == len(gt_ignore[i]) num_classes = len(det_results[0]) # positive class num
area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges] area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges]
if scale_ranges is not None else None) if scale_ranges is not None else None)
num_scales = len(scale_ranges) if scale_ranges is not None else 1
pool = Pool(nproc)
eval_results = [] eval_results = []
num_classes = len(det_results[0]) # positive class num
gt_labels = [
label if label.ndim == 1 else label[:, 0] for label in gt_labels
]
for i in range(num_classes): for i in range(num_classes):
# get gt and det bboxes of this class # get gt and det bboxes of this class
cls_dets, cls_gts, cls_gt_ignore = get_cls_results( cls_dets, cls_gts, cls_gts_ignore = get_cls_results(
det_results, gt_bboxes, gt_labels, gt_ignore, i) det_results, annotations, i)
# calculate tp and fp for each image # choose proper function according to datasets to compute tp and fp
tpfp_func = ( if dataset in ['det', 'vid']:
tpfp_imagenet if dataset in ['det', 'vid'] else tpfp_default) tpfp_func = tpfp_imagenet
tpfp = [ else:
tpfp_func(cls_dets[j], cls_gts[j], cls_gt_ignore[j], iou_thr, tpfp_func = tpfp_default
area_ranges) for j in range(len(cls_dets)) # compute tp and fp for each image with multiple processes
] tpfp = pool.starmap(
tpfp_func,
zip(cls_dets, cls_gts, cls_gts_ignore,
[iou_thr for _ in range(num_imgs)],
[area_ranges for _ in range(num_imgs)]))
tp, fp = tuple(zip(*tpfp)) tp, fp = tuple(zip(*tpfp))
# calculate gt number of each scale, gts ignored or beyond scale # calculate gt number of each scale
# are not counted # ignored gts or gts beyond the specific scale are not counted
num_gts = np.zeros(num_scales, dtype=int) num_gts = np.zeros(num_scales, dtype=int)
for j, bbox in enumerate(cls_gts): for j, bbox in enumerate(cls_gts):
if area_ranges is None: if area_ranges is None:
num_gts[0] += np.sum(np.logical_not(cls_gt_ignore[j])) num_gts[0] += bbox.shape[0]
else: else:
gt_areas = (bbox[:, 2] - bbox[:, 0] + 1) * ( gt_areas = (bbox[:, 2] - bbox[:, 0] + 1) * (
bbox[:, 3] - bbox[:, 1] + 1) bbox[:, 3] - bbox[:, 1] + 1)
for k, (min_area, max_area) in enumerate(area_ranges): for k, (min_area, max_area) in enumerate(area_ranges):
num_gts[k] += np.sum( num_gts[k] += np.sum((gt_areas >= min_area)
np.logical_not(cls_gt_ignore[j]) & (gt_areas < max_area))
& (gt_areas >= min_area) & (gt_areas < max_area))
# sort all det bboxes by score, also sort tp and fp # sort all det bboxes by score, also sort tp and fp
cls_dets = np.vstack(cls_dets) cls_dets = np.vstack(cls_dets)
num_dets = cls_dets.shape[0] num_dets = cls_dets.shape[0]
...@@ -324,37 +383,60 @@ def eval_map(det_results, ...@@ -324,37 +383,60 @@ def eval_map(det_results,
if cls_result['num_gts'] > 0: if cls_result['num_gts'] > 0:
aps.append(cls_result['ap']) aps.append(cls_result['ap'])
mean_ap = np.array(aps).mean().item() if aps else 0.0 mean_ap = np.array(aps).mean().item() if aps else 0.0
if print_summary: if logger is not None:
print_map_summary(mean_ap, eval_results, dataset, area_ranges) print_map_summary(
mean_ap, eval_results, dataset, area_ranges, logger=logger)
return mean_ap, eval_results return mean_ap, eval_results
def print_map_summary(mean_ap, results, dataset=None, ranges=None): def print_map_summary(mean_ap,
results,
dataset=None,
scale_ranges=None,
logger=None):
"""Print mAP and results of each class. """Print mAP and results of each class.
A table will be printed to show the gts/dets/recall/AP of each class and
the mAP.
Args: Args:
mean_ap(float): calculated from `eval_map` mean_ap (float): Calculated from `eval_map()`.
results(list): calculated from `eval_map` results (list[dict]): Calculated from `eval_map()`.
dataset(None or str or list): dataset name or dataset classes. dataset (list[str] | str | None): Dataset name or dataset classes.
ranges(list or Tuple): ranges of areas scale_ranges (list[tuple] | None): Range of scales to be evaluated.
logger (logging.Logger | 'print' | None): The way to print the mAP
summary. If a Logger is specified, then the summary will be logged
with `logger.info()`; if set to "print", then it will be simply
printed to stdout; if set to None, then no information will be
printed. Default: 'print'.
""" """
num_scales = len(results[0]['ap']) if isinstance(results[0]['ap'],
np.ndarray) else 1 def _print(content):
if ranges is not None: if logger == 'print':
assert len(ranges) == num_scales print(content)
elif isinstance(logger, logging.Logger):
logger.info(content)
if isinstance(results[0]['ap'], np.ndarray):
num_scales = len(results[0]['ap'])
else:
num_scales = 1
if scale_ranges is not None:
assert len(scale_ranges) == num_scales
assert logger is None or logger == 'print' or isinstance(
logger, logging.Logger)
num_classes = len(results) num_classes = len(results)
recalls = np.zeros((num_scales, num_classes), dtype=np.float32) recalls = np.zeros((num_scales, num_classes), dtype=np.float32)
precisions = np.zeros((num_scales, num_classes), dtype=np.float32)
aps = np.zeros((num_scales, num_classes), dtype=np.float32) aps = np.zeros((num_scales, num_classes), dtype=np.float32)
num_gts = np.zeros((num_scales, num_classes), dtype=int) num_gts = np.zeros((num_scales, num_classes), dtype=int)
for i, cls_result in enumerate(results): for i, cls_result in enumerate(results):
if cls_result['recall'].size > 0: if cls_result['recall'].size > 0:
recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1] recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1]
precisions[:, i] = np.array(
cls_result['precision'], ndmin=2)[:, -1]
aps[:, i] = cls_result['ap'] aps[:, i] = cls_result['ap']
num_gts[:, i] = cls_result['num_gts'] num_gts[:, i] = cls_result['num_gts']
...@@ -367,19 +449,19 @@ def print_map_summary(mean_ap, results, dataset=None, ranges=None): ...@@ -367,19 +449,19 @@ def print_map_summary(mean_ap, results, dataset=None, ranges=None):
if not isinstance(mean_ap, list): if not isinstance(mean_ap, list):
mean_ap = [mean_ap] mean_ap = [mean_ap]
header = ['class', 'gts', 'dets', 'recall', 'precision', 'ap']
header = ['class', 'gts', 'dets', 'recall', 'ap']
for i in range(num_scales): for i in range(num_scales):
if ranges is not None: if scale_ranges is not None:
print("Area range ", ranges[i]) _print('Scale range ', scale_ranges[i])
table_data = [header] table_data = [header]
for j in range(num_classes): for j in range(num_classes):
row_data = [ row_data = [
label_names[j], num_gts[i, j], results[j]['num_dets'], label_names[j], num_gts[i, j], results[j]['num_dets'],
'{:.3f}'.format(recalls[i, j]), '{:.3f}'.format(recalls[i, j]), '{:.3f}'.format(aps[i, j])
'{:.3f}'.format(precisions[i, j]), '{:.3f}'.format(aps[i, j])
] ]
table_data.append(row_data) table_data.append(row_data)
table_data.append(['mAP', '', '', '', '', '{:.3f}'.format(mean_ap[i])]) table_data.append(['mAP', '', '', '', '{:.3f}'.format(mean_ap[i])])
table = AsciiTable(table_data) table = AsciiTable(table_data)
table.inner_footing_row_border = True table.inner_footing_row_border = True
print(table.table) _print('\n' + table.table)
...@@ -76,41 +76,21 @@ def coco_eval_with_return(result_files, ...@@ -76,41 +76,21 @@ def coco_eval_with_return(result_files,
def voc_eval_with_return(result_file, def voc_eval_with_return(result_file,
dataset, dataset,
iou_thr=0.5, iou_thr=0.5,
print_summary=True, logger='print',
only_ap=True): only_ap=True):
det_results = mmcv.load(result_file) det_results = mmcv.load(result_file)
gt_bboxes = [] annotations = [dataset.get_ann_info(i) for i in range(len(dataset))]
gt_labels = []
gt_ignore = []
for i in range(len(dataset)):
ann = dataset.get_ann_info(i)
bboxes = ann['bboxes']
labels = ann['labels']
if 'bboxes_ignore' in ann:
ignore = np.concatenate([
np.zeros(bboxes.shape[0], dtype=np.bool),
np.ones(ann['bboxes_ignore'].shape[0], dtype=np.bool)
])
gt_ignore.append(ignore)
bboxes = np.vstack([bboxes, ann['bboxes_ignore']])
labels = np.concatenate([labels, ann['labels_ignore']])
gt_bboxes.append(bboxes)
gt_labels.append(labels)
if not gt_ignore:
gt_ignore = gt_ignore
if hasattr(dataset, 'year') and dataset.year == 2007: if hasattr(dataset, 'year') and dataset.year == 2007:
dataset_name = 'voc07' dataset_name = 'voc07'
else: else:
dataset_name = dataset.CLASSES dataset_name = dataset.CLASSES
mean_ap, eval_results = eval_map( mean_ap, eval_results = eval_map(
det_results, det_results,
gt_bboxes, annotations,
gt_labels,
gt_ignore=gt_ignore,
scale_ranges=None, scale_ranges=None,
iou_thr=iou_thr, iou_thr=iou_thr,
dataset=dataset_name, dataset=dataset_name,
print_summary=print_summary) logger=logger)
if only_ap: if only_ap:
eval_results = [{ eval_results = [{
...@@ -411,10 +391,11 @@ def main(): ...@@ -411,10 +391,11 @@ def main():
if eval_type == 'bbox': if eval_type == 'bbox':
test_dataset = mmcv.runner.obj_from_dict( test_dataset = mmcv.runner.obj_from_dict(
cfg.data.test, datasets) cfg.data.test, datasets)
logger = 'print' if args.summaries else None
mean_ap, eval_results = \ mean_ap, eval_results = \
voc_eval_with_return( voc_eval_with_return(
args.out, test_dataset, args.out, test_dataset,
args.iou_thr, args.summaries) args.iou_thr, logger)
aggregated_results[corruption][ aggregated_results[corruption][
corruption_severity] = eval_results corruption_severity] = eval_results
else: else:
......
from argparse import ArgumentParser from argparse import ArgumentParser
import mmcv import mmcv
import numpy as np
from mmdet import datasets from mmdet import datasets
from mmdet.core import eval_map from mmdet.core import eval_map
def voc_eval(result_file, dataset, iou_thr=0.5): def voc_eval(result_file, dataset, iou_thr=0.5, nproc=4):
det_results = mmcv.load(result_file) det_results = mmcv.load(result_file)
gt_bboxes = [] annotations = [dataset.get_ann_info(i) for i in range(len(dataset))]
gt_labels = []
gt_ignore = []
for i in range(len(dataset)):
ann = dataset.get_ann_info(i)
bboxes = ann['bboxes']
labels = ann['labels']
if 'bboxes_ignore' in ann:
ignore = np.concatenate([
np.zeros(bboxes.shape[0], dtype=np.bool),
np.ones(ann['bboxes_ignore'].shape[0], dtype=np.bool)
])
gt_ignore.append(ignore)
bboxes = np.vstack([bboxes, ann['bboxes_ignore']])
labels = np.concatenate([labels, ann['labels_ignore']])
gt_bboxes.append(bboxes)
gt_labels.append(labels)
if not gt_ignore:
gt_ignore = None
if hasattr(dataset, 'year') and dataset.year == 2007: if hasattr(dataset, 'year') and dataset.year == 2007:
dataset_name = 'voc07' dataset_name = 'voc07'
else: else:
dataset_name = dataset.CLASSES dataset_name = dataset.CLASSES
eval_map( eval_map(
det_results, det_results,
gt_bboxes, annotations,
gt_labels,
gt_ignore=gt_ignore,
scale_ranges=None, scale_ranges=None,
iou_thr=iou_thr, iou_thr=iou_thr,
dataset=dataset_name, dataset=dataset_name,
print_summary=True) logger='print',
nproc=nproc)
def main(): def main():
...@@ -52,10 +32,15 @@ def main(): ...@@ -52,10 +32,15 @@ def main():
type=float, type=float,
default=0.5, default=0.5,
help='IoU threshold for evaluation') help='IoU threshold for evaluation')
parser.add_argument(
'--nproc',
type=int,
default=4,
help='Processes to be used for computing mAP')
args = parser.parse_args() args = parser.parse_args()
cfg = mmcv.Config.fromfile(args.config) cfg = mmcv.Config.fromfile(args.config)
test_dataset = mmcv.runner.obj_from_dict(cfg.data.test, datasets) test_dataset = mmcv.runner.obj_from_dict(cfg.data.test, datasets)
voc_eval(args.result, test_dataset, args.iou_thr) voc_eval(args.result, test_dataset, args.iou_thr, args.nproc)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment