Commit 60d5bc54 authored by Kai Chen's avatar Kai Chen
Browse files

add eval hooks for VOC dataset

parent 9d38a278
......@@ -6,8 +6,8 @@ import torch
from mmcv.runner import Runner, DistSamplerSeedHook
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmdet.core import (DistOptimizerHook, CocoDistEvalRecallHook,
CocoDistEvalmAPHook)
from mmdet.core import (DistOptimizerHook, DistEvalmAPHook,
CocoDistEvalRecallHook, CocoDistEvalmAPHook)
from mmdet.datasets import build_dataloader
from mmdet.models import RPN
from .env import get_root_logger
......@@ -81,9 +81,13 @@ def _dist_train(model, dataset, cfg, validate=False):
# register eval hooks
if validate:
if isinstance(model.module, RPN):
# TODO: implement recall hooks for other datasets
runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
elif cfg.data.val.type == 'CocoDataset':
else:
if cfg.data.val.type == 'CocoDataset':
runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
else:
runner.register_hook(DistEvalmAPHook(cfg.data.val))
if cfg.resume_from:
runner.resume(cfg.resume_from)
......
......@@ -2,7 +2,7 @@ from .class_names import (voc_classes, imagenet_det_classes,
imagenet_vid_classes, coco_classes, dataset_aliases,
get_classes)
from .coco_utils import coco_eval, fast_eval_recall, results2json
from .eval_hooks import (DistEvalHook, CocoDistEvalRecallHook,
from .eval_hooks import (DistEvalHook, DistEvalmAPHook, CocoDistEvalRecallHook,
CocoDistEvalmAPHook)
from .mean_ap import average_precision, eval_map, print_map_summary
from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
......@@ -11,7 +11,7 @@ from .recall import (eval_recalls, print_recall_summary, plot_num_recall,
__all__ = [
'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes',
'coco_classes', 'dataset_aliases', 'get_classes', 'coco_eval',
'fast_eval_recall', 'results2json', 'DistEvalHook',
'fast_eval_recall', 'results2json', 'DistEvalHook', 'DistEvalmAPHook',
'CocoDistEvalRecallHook', 'CocoDistEvalmAPHook', 'average_precision',
'eval_map', 'print_map_summary', 'eval_recalls', 'print_recall_summary',
'plot_num_recall', 'plot_iou_recall'
......
......@@ -12,6 +12,7 @@ from pycocotools.cocoeval import COCOeval
from torch.utils.data import Dataset
from .coco_utils import results2json, fast_eval_recall
from .mean_ap import eval_map
from mmdet import datasets
......@@ -102,6 +103,44 @@ class DistEvalHook(Hook):
raise NotImplementedError
class DistEvalmAPHook(DistEvalHook):
def evaluate(self, runner, results):
gt_bboxes = []
gt_labels = []
gt_ignore = [] if self.dataset.with_crowd else None
for i in range(len(self.dataset)):
ann = self.dataset.get_ann_info(i)
bboxes = ann['bboxes']
labels = ann['labels']
if gt_ignore is not None:
ignore = np.concatenate([
np.zeros(bboxes.shape[0], dtype=np.bool),
np.ones(ann['bboxes_ignore'].shape[0], dtype=np.bool)
])
gt_ignore.append(ignore)
bboxes = np.vstack([bboxes, ann['bboxes_ignore']])
labels = np.concatenate([labels, ann['labels_ignore']])
gt_bboxes.append(bboxes)
gt_labels.append(labels)
# If the dataset is VOC2007, then use 11 points mAP evaluation.
if hasattr(self.dataset, 'year') and self.dataset.year == 2007:
ds_name = 'voc07'
else:
ds_name = self.dataset.CLASSES
mean_ap, eval_results = eval_map(
results,
gt_bboxes,
gt_labels,
gt_ignore=gt_ignore,
scale_ranges=None,
iou_thr=0.5,
dataset=ds_name,
print_summary=True)
runner.log_buffer.output['mAP'] = mean_ap
runner.log_buffer.ready = True
class CocoDistEvalRecallHook(DistEvalHook):
def __init__(self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment