Commit df14830a authored by Kai Chen

add a choice 'proposal_fast' to eval script

parent d0fb2a8d
mmdet/core/evaluation/__init__.py

 from .class_names import (voc_classes, imagenet_det_classes,
                           imagenet_vid_classes, coco_classes, dataset_aliases,
                           get_classes)
-from .coco_utils import coco_eval, results2json
-from .eval_hooks import DistEvalHook, DistEvalRecallHook, CocoDistEvalmAPHook
+from .coco_utils import coco_eval, fast_eval_recall, results2json
+from .eval_hooks import (DistEvalHook, CocoDistEvalRecallHook,
+                         CocoDistEvalmAPHook)
 from .mean_ap import average_precision, eval_map, print_map_summary
 from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
                      plot_iou_recall)
@@ -10,8 +11,8 @@ from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
 __all__ = [
     'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes',
     'coco_classes', 'dataset_aliases', 'get_classes', 'coco_eval',
-    'results2json', 'DistEvalHook', 'DistEvalRecallHook',
-    'CocoDistEvalmAPHook', 'average_precision', 'eval_map',
-    'print_map_summary', 'eval_recalls', 'print_recall_summary',
+    'fast_eval_recall', 'results2json', 'DistEvalHook',
+    'CocoDistEvalRecallHook', 'CocoDistEvalmAPHook', 'average_precision',
+    'eval_map', 'print_map_summary', 'eval_recalls', 'print_recall_summary',
     'plot_num_recall', 'plot_iou_recall'
 ]
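The visible API change in this file is the new fast_eval_recall export and the rename of DistEvalRecallHook to CocoDistEvalRecallHook. A two-line sketch of the downstream imports this enables (the mmdet.core.evaluation path is inferred from the relative imports above; mmdet.core re-exports these names, as the train script hunk below shows):

    # hypothetical downstream imports, not part of this commit
    from mmdet.core import CocoDistEvalRecallHook  # renamed from DistEvalRecallHook
    from mmdet.core.evaluation import fast_eval_recall  # newly exported helper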
mmdet/core/evaluation/coco_utils.py

@@ -3,17 +3,28 @@ import numpy as np
 from pycocotools.coco import COCO
 from pycocotools.cocoeval import COCOeval

+from .recall import eval_recalls
+

 def coco_eval(result_file, result_types, coco, max_dets=(100, 300, 1000)):
-    assert result_file.endswith('.json')
     for res_type in result_types:
-        assert res_type in ['proposal', 'bbox', 'segm', 'keypoints']
+        assert res_type in [
+            'proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'
+        ]

     if mmcv.is_str(coco):
         coco = COCO(coco)
     assert isinstance(coco, COCO)

+    if res_type == 'proposal_fast':
+        ar = fast_eval_recall(result_file, coco, max_dets)
+        for i, num in enumerate(max_dets):
+            print('AR@{}\t= {:.4f}'.format(num, ar[i]))
+        return
+
+    assert result_file.endswith('.json')
     coco_dets = coco.loadRes(result_file)

     img_ids = coco.getImgIds()
     for res_type in result_types:
         iou_type = 'bbox' if res_type == 'proposal' else res_type
@@ -27,6 +38,43 @@ def coco_eval(result_file, result_types, coco, max_dets=(100, 300, 1000)):
         cocoEval.summarize()


+def fast_eval_recall(results,
+                     coco,
+                     max_dets,
+                     iou_thrs=np.arange(0.5, 0.96, 0.05)):
+    if mmcv.is_str(results):
+        assert results.endswith('.pkl')
+        results = mmcv.load(results)
+    elif not isinstance(results, list):
+        raise TypeError(
+            'results must be a list of numpy arrays or a filename, not {}'.
+            format(type(results)))
+
+    gt_bboxes = []
+    img_ids = coco.getImgIds()
+    for i in range(len(img_ids)):
+        ann_ids = coco.getAnnIds(imgIds=img_ids[i])
+        ann_info = coco.loadAnns(ann_ids)
+        if len(ann_info) == 0:
+            gt_bboxes.append(np.zeros((0, 4)))
+            continue
+        bboxes = []
+        for ann in ann_info:
+            if ann.get('ignore', False) or ann['iscrowd']:
+                continue
+            x1, y1, w, h = ann['bbox']
+            bboxes.append([x1, y1, x1 + w - 1, y1 + h - 1])
+        bboxes = np.array(bboxes, dtype=np.float32)
+        if bboxes.shape[0] == 0:
+            bboxes = np.zeros((0, 4))
+        gt_bboxes.append(bboxes)
+
+    recalls = eval_recalls(
+        gt_bboxes, results, max_dets, iou_thrs, print_summary=False)
+    ar = recalls.mean(axis=1)
+    return ar
+
+
 def xyxy2xywh(bbox):
     _bbox = bbox.tolist()
     return [
......
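Taken together, coco_eval now short-circuits into fast_eval_recall when 'proposal_fast' is requested; note that the branch reuses res_type left over from the assertion loop, so the fast path is effectively meant to be the only (or last) requested type. fast_eval_recall accepts either a .pkl filename or an in-memory list of per-image proposal arrays. A hedged usage sketch with placeholder file paths:

    # usage sketch, not part of this commit; paths are placeholders
    from pycocotools.coco import COCO
    from mmdet.core.evaluation import coco_eval, fast_eval_recall

    coco = COCO('data/coco/annotations/instances_val2017.json')

    # 'proposal_fast' consumes the pickled per-image proposal arrays directly,
    # so the .json requirement now applies only to the other result types
    coco_eval('proposals.pkl', ['proposal_fast'], coco)

    # or call the helper directly to keep the AR value for each budget
    ar = fast_eval_recall('proposals.pkl', coco, max_dets=(100, 300, 1000))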
mmdet/core/evaluation/eval_hooks.py

@@ -10,7 +10,7 @@ from mmcv.torchpack import Hook, obj_from_dict
 from pycocotools.cocoeval import COCOeval
 from torch.utils.data import Dataset

-from .coco_utils import results2json
+from .coco_utils import results2json, fast_eval_recall
 from .recall import eval_recalls
 from ..parallel import scatter
 from mmdet import datasets
@@ -100,45 +100,21 @@ class DistEvalHook(Hook):
         raise NotImplementedError


-class DistEvalRecallHook(DistEvalHook):
+class CocoDistEvalRecallHook(DistEvalHook):

     def __init__(self,
                  dataset,
                  proposal_nums=(100, 300, 1000),
                  iou_thrs=np.arange(0.5, 0.96, 0.05)):
-        super(DistEvalRecallHook, self).__init__(dataset)
+        super(CocoDistEvalRecallHook, self).__init__(dataset)
         self.proposal_nums = np.array(proposal_nums, dtype=np.int32)
         self.iou_thrs = np.array(iou_thrs, dtype=np.float32)

     def evaluate(self, runner, results):
         # the official coco evaluation is too slow, here we use our own
         # implementation instead, which may get slightly different results
-        gt_bboxes = []
-        for i in range(len(self.dataset)):
-            img_id = self.dataset.img_ids[i]
-            ann_ids = self.dataset.coco.getAnnIds(imgIds=img_id)
-            ann_info = self.dataset.coco.loadAnns(ann_ids)
-            if len(ann_info) == 0:
-                gt_bboxes.append(np.zeros((0, 4)))
-                continue
-            bboxes = []
-            for ann in ann_info:
-                if ann.get('ignore', False) or ann['iscrowd']:
-                    continue
-                x1, y1, w, h = ann['bbox']
-                bboxes.append([x1, y1, x1 + w - 1, y1 + h - 1])
-            bboxes = np.array(bboxes, dtype=np.float32)
-            if bboxes.shape[0] == 0:
-                bboxes = np.zeros((0, 4))
-            gt_bboxes.append(bboxes)
-
-        recalls = eval_recalls(
-            gt_bboxes,
-            results,
-            self.proposal_nums,
-            self.iou_thrs,
-            print_summary=False)
-        ar = recalls.mean(axis=1)
+        ar = fast_eval_recall(results, self.dataset.coco, self.proposal_nums,
+                              self.iou_thrs)
         for i, num in enumerate(self.proposal_nums):
             runner.log_buffer.output['AR@{}'.format(num)] = ar[i]
         runner.log_buffer.ready = True
......
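For reference, the quantity this hook now logs is average recall (AR): for a given proposal budget, the fraction of ground-truth boxes whose best-overlapping proposal reaches each IoU threshold in 0.50:0.05:0.95, averaged over the ten thresholds. A self-contained toy illustration of that computation (a simplified sketch, not the commit's code; the real matching lives in eval_recalls in recall.py), using the same inclusive x2 = x1 + w - 1 box convention as the conversion loop above:

    # toy AR computation, illustrative only
    import numpy as np

    def bbox_ious(gts, dets):
        # pairwise IoU between [x1, y1, x2, y2] boxes (inclusive pixel coords)
        ious = np.zeros((len(gts), len(dets)), dtype=np.float32)
        for i, g in enumerate(gts):
            g_area = (g[2] - g[0] + 1) * (g[3] - g[1] + 1)
            for j, d in enumerate(dets):
                iw = min(g[2], d[2]) - max(g[0], d[0]) + 1
                ih = min(g[3], d[3]) - max(g[1], d[1]) + 1
                if iw > 0 and ih > 0:
                    d_area = (d[2] - d[0] + 1) * (d[3] - d[1] + 1)
                    ious[i, j] = iw * ih / (g_area + d_area - iw * ih)
        return ious

    gts = np.array([[0, 0, 9, 9], [20, 20, 29, 29]], dtype=np.float32)
    proposals = np.array([[0, 0, 9, 9], [19, 19, 30, 30]], dtype=np.float32)

    best_ious = bbox_ious(gts, proposals).max(axis=1)  # best proposal per gt
    iou_thrs = np.arange(0.5, 0.96, 0.05)
    recalls = np.array([(best_ious >= thr).mean() for thr in iou_thrs])
    print('AR = {:.4f}'.format(recalls.mean()))  # -> AR = 0.7000 for these boxes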
tools/coco_eval.py

@@ -8,13 +8,18 @@ def main():
     parser.add_argument('result', help='result file path')
     parser.add_argument('--ann', help='annotation file path')
     parser.add_argument(
-        '--types', type=str, nargs='+', default=['bbox'], help='result types')
+        '--types',
+        type=str,
+        nargs='+',
+        choices=['proposal_fast', 'proposal', 'bbox', 'segm', 'keypoint'],
+        default=['bbox'],
+        help='result types')
     parser.add_argument(
         '--max-dets',
         type=int,
         nargs='+',
         default=[100, 300, 1000],
-        help='result types')
+        help='proposal numbers, only used for recall evaluation')
     args = parser.parse_args()
     coco_eval(args.result, args.types, args.ann, args.max_dets)
......
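With the new choice wired into argparse, a recall-only run of the eval script might look like the following (the script path and file names are assumptions; per the coco_utils.py hunks above, 'proposal_fast' expects the pickled proposal file rather than a .json results file). Note also that the singular 'keypoint' choice added here does not match the 'keypoints' string that coco_eval asserts against, so that combination would fail the assertion:

    python tools/coco_eval.py proposals.pkl \
        --ann data/coco/annotations/instances_val2017.json \
        --types proposal_fast --max-dets 100 300 1000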
tools/train.py

@@ -11,7 +11,7 @@ from mmcv.torchpack import Runner, obj_from_dict
 from mmdet import datasets
 from mmdet.core import (init_dist, DistOptimizerHook, DistSamplerSeedHook,
                         MMDataParallel, MMDistributedDataParallel,
-                        DistEvalRecallHook, CocoDistEvalmAPHook)
+                        CocoDistEvalRecallHook, CocoDistEvalmAPHook)
 from mmdet.datasets.loader import build_dataloader
 from mmdet.models import build_detector, RPN
@@ -127,7 +127,7 @@ def main():
     runner.register_hook(DistSamplerSeedHook())
     # register eval hooks
     if isinstance(model.module, RPN):
-        runner.register_hook(DistEvalRecallHook(cfg.data.val))
+        runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
     elif cfg.data.val.type == 'CocoDataset':
         runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
......