Commit 2bde2485 authored by Vighnesh Birodkar, committed by TF Object Detection Team

Clean up COCO eval code and enable reporting all summary metrics per category.

PiperOrigin-RevId: 341188685
parent 30cc8e74
@@ -1177,6 +1177,12 @@ def evaluator_options_from_eval_config(eval_config):
         'include_metrics_per_category': (
             eval_config.include_metrics_per_category)
     }
+    if (hasattr(eval_config, 'all_metrics_per_category') and
+        eval_config.all_metrics_per_category):
+      evaluator_options[eval_metric_fn_key].update({
+          'all_metrics_per_category': eval_config.all_metrics_per_category
+      })
     # For coco detection eval, if the eval_config proto contains the
     # "skip_predictions_for_unlabeled_class" field, include this field in
     # evaluator_options.
...
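For readers trying the flag out, here is a minimal sketch (not part of this commit) of how the new option flows from an EvalConfig into evaluator options. It assumes the usual object_detection layout, where this function lives in eval_util.py and the config proto compiles to object_detection/protos/eval_pb2.py:

# Illustrative sketch; assumes object_detection is on PYTHONPATH with
# compiled protos. Field values are made up.
from google.protobuf import text_format
from object_detection import eval_util
from object_detection.protos import eval_pb2

eval_config = text_format.Parse(
    """
    metrics_set: 'coco_detection_metrics'
    include_metrics_per_category: true
    all_metrics_per_category: true
    """, eval_pb2.EvalConfig())

options = eval_util.evaluator_options_from_eval_config(eval_config)
# With both flags set, the coco metrics entry should now carry both
# 'include_metrics_per_category': True and 'all_metrics_per_category': True.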
@@ -961,6 +961,7 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
   def __init__(self, categories,
                include_metrics_per_category=False,
+               all_metrics_per_category=False,
                super_categories=None):
     """Constructor.
@@ -969,6 +970,10 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
         'id': (required) an integer id uniquely identifying this category.
         'name': (required) string representing category name e.g., 'cat', 'dog'.
       include_metrics_per_category: If True, include metrics for each category.
+      all_metrics_per_category: Whether to include all the summary metrics for
+        each category in per_category_ap. Be careful setting it to true if you
+        have more than a handful of categories, because it will pollute
+        your mldash.
       super_categories: None or a python dict mapping super-category names
         (strings) to lists of categories (corresponding to category names
         in the label_map). Metrics are aggregated along these super-categories
@@ -984,6 +989,7 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
     self._annotation_id = 1
     self._include_metrics_per_category = include_metrics_per_category
     self._super_categories = super_categories
+    self._all_metrics_per_category = all_metrics_per_category

   def clear(self):
     """Clears the state to prepare for a fresh evaluation."""
@@ -1177,7 +1183,8 @@ class CocoMaskEvaluator(object_detection_evaluation.DetectionEvaluator):
         agnostic_mode=False, iou_type='segm')
     mask_metrics, mask_per_category_ap = mask_evaluator.ComputeMetrics(
         include_metrics_per_category=self._include_metrics_per_category,
-        super_categories=self._super_categories)
+        super_categories=self._super_categories,
+        all_metrics_per_category=self._all_metrics_per_category)
     mask_metrics.update(mask_per_category_ap)
     mask_metrics = {'DetectionMasks_'+ key: value
                     for key, value in mask_metrics.items()}
...
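A hedged usage sketch of the extended constructor; the category list and super-category mapping below are made up for illustration, and the module path is the one CocoMaskEvaluator conventionally lives at:

from object_detection.metrics import coco_evaluation  # assumed module path

categories = [{'id': 1, 'name': 'cat'}, {'id': 2, 'name': 'dog'}]
evaluator = coco_evaluation.CocoMaskEvaluator(
    categories,
    include_metrics_per_category=True,
    all_metrics_per_category=True,  # the flag added by this commit
    super_categories={'animal': ['cat', 'dog']})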
@@ -142,6 +142,35 @@ class COCOWrapper(coco.COCO):
     return results


+COCO_METRIC_NAMES_AND_INDEX = (
+    ('Precision/mAP', 0),
+    ('Precision/mAP@.50IOU', 1),
+    ('Precision/mAP@.75IOU', 2),
+    ('Precision/mAP (small)', 3),
+    ('Precision/mAP (medium)', 4),
+    ('Precision/mAP (large)', 5),
+    ('Recall/AR@1', 6),
+    ('Recall/AR@10', 7),
+    ('Recall/AR@100', 8),
+    ('Recall/AR@100 (small)', 9),
+    ('Recall/AR@100 (medium)', 10),
+    ('Recall/AR@100 (large)', 11)
+)
+
+COCO_KEYPOINT_METRIC_NAMES_AND_INDEX = (
+    ('Precision/mAP', 0),
+    ('Precision/mAP@.50IOU', 1),
+    ('Precision/mAP@.75IOU', 2),
+    ('Precision/mAP (medium)', 3),
+    ('Precision/mAP (large)', 4),
+    ('Recall/AR@1', 5),
+    ('Recall/AR@10', 6),
+    ('Recall/AR@100', 7),
+    ('Recall/AR@100 (medium)', 8),
+    ('Recall/AR@100 (large)', 9)
+)
+
+
 class COCOEvalWrapper(cocoeval.COCOeval):
   """Wrapper for the pycocotools COCOeval class.
@@ -259,42 +288,17 @@ class COCOEvalWrapper(cocoeval.COCOeval):
     summary_metrics = {}
     if self._iou_type in ['bbox', 'segm']:
-      summary_metrics = OrderedDict([('Precision/mAP', self.stats[0]),
-                                     ('Precision/mAP@.50IOU', self.stats[1]),
-                                     ('Precision/mAP@.75IOU', self.stats[2]),
-                                     ('Precision/mAP (small)', self.stats[3]),
-                                     ('Precision/mAP (medium)', self.stats[4]),
-                                     ('Precision/mAP (large)', self.stats[5]),
-                                     ('Recall/AR@1', self.stats[6]),
-                                     ('Recall/AR@10', self.stats[7]),
-                                     ('Recall/AR@100', self.stats[8]),
-                                     ('Recall/AR@100 (small)', self.stats[9]),
-                                     ('Recall/AR@100 (medium)', self.stats[10]),
-                                     ('Recall/AR@100 (large)', self.stats[11])])
+      summary_metrics = OrderedDict(
+          [(name, self.stats[index]) for name, index in
+           COCO_METRIC_NAMES_AND_INDEX])
     elif self._iou_type == 'keypoints':
       category_id = self.GetCategoryIdList()[0]
       category_name = self.GetCategory(category_id)['name']
       summary_metrics = OrderedDict([])
-      summary_metrics['Precision/mAP ByCategory/{}'.format(
-          category_name)] = self.stats[0]
-      summary_metrics['Precision/mAP@.50IOU ByCategory/{}'.format(
-          category_name)] = self.stats[1]
-      summary_metrics['Precision/mAP@.75IOU ByCategory/{}'.format(
-          category_name)] = self.stats[2]
-      summary_metrics['Precision/mAP (medium) ByCategory/{}'.format(
-          category_name)] = self.stats[3]
-      summary_metrics['Precision/mAP (large) ByCategory/{}'.format(
-          category_name)] = self.stats[4]
-      summary_metrics['Recall/AR@1 ByCategory/{}'.format(
-          category_name)] = self.stats[5]
-      summary_metrics['Recall/AR@10 ByCategory/{}'.format(
-          category_name)] = self.stats[6]
-      summary_metrics['Recall/AR@100 ByCategory/{}'.format(
-          category_name)] = self.stats[7]
-      summary_metrics['Recall/AR@100 (medium) ByCategory/{}'.format(
-          category_name)] = self.stats[8]
-      summary_metrics['Recall/AR@100 (large) ByCategory/{}'.format(
-          category_name)] = self.stats[9]
+      for metric_name, index in COCO_KEYPOINT_METRIC_NAMES_AND_INDEX:
+        value = self.stats[index]
+        summary_metrics['{} ByCategory/{}'.format(
+            metric_name, category_name)] = value
     if not include_metrics_per_category:
       return summary_metrics, {}
     if not hasattr(self, 'category_stats'):
@@ -303,48 +307,51 @@ class COCOEvalWrapper(cocoeval.COCOeval):
     super_category_ap = OrderedDict([])
     if self.GetAgnosticMode():
       return summary_metrics, per_category_ap
+    if super_categories:
+      for key in super_categories:
+        super_category_ap['PerformanceBySuperCategory/{}'.format(key)] = 0
+        if all_metrics_per_category:
+          for metric_name, _ in COCO_METRIC_NAMES_AND_INDEX:
+            metric_key = '{} BySuperCategory/{}'.format(metric_name, key)
+            super_category_ap[metric_key] = 0
     for category_index, category_id in enumerate(self.GetCategoryIdList()):
       category = self.GetCategory(category_id)['name']
       # Kept for backward compatibility
       per_category_ap['PerformanceByCategory/mAP/{}'.format(
           category)] = self.category_stats[0][category_index]
+      if all_metrics_per_category:
+        for metric_name, index in COCO_METRIC_NAMES_AND_INDEX:
+          metric_key = '{} ByCategory/{}'.format(metric_name, category)
+          per_category_ap[metric_key] = self.category_stats[index][
+              category_index]
       if super_categories:
         for key in super_categories:
           if category in super_categories[key]:
-            metric_name = 'PerformanceBySuperCategory/{}'.format(key)
-            if metric_name not in super_category_ap:
-              super_category_ap[metric_name] = 0
-            super_category_ap[metric_name] += self.category_stats[0][
+            metric_key = 'PerformanceBySuperCategory/{}'.format(key)
+            super_category_ap[metric_key] += self.category_stats[0][
                 category_index]
-      if all_metrics_per_category:
-        per_category_ap['Precision mAP ByCategory/{}'.format(
-            category)] = self.category_stats[0][category_index]
-        per_category_ap['Precision mAP@.50IOU ByCategory/{}'.format(
-            category)] = self.category_stats[1][category_index]
-        per_category_ap['Precision mAP@.75IOU ByCategory/{}'.format(
-            category)] = self.category_stats[2][category_index]
-        per_category_ap['Precision mAP (small) ByCategory/{}'.format(
-            category)] = self.category_stats[3][category_index]
-        per_category_ap['Precision mAP (medium) ByCategory/{}'.format(
-            category)] = self.category_stats[4][category_index]
-        per_category_ap['Precision mAP (large) ByCategory/{}'.format(
-            category)] = self.category_stats[5][category_index]
-        per_category_ap['Recall AR@1 ByCategory/{}'.format(
-            category)] = self.category_stats[6][category_index]
-        per_category_ap['Recall AR@10 ByCategory/{}'.format(
-            category)] = self.category_stats[7][category_index]
-        per_category_ap['Recall AR@100 ByCategory/{}'.format(
-            category)] = self.category_stats[8][category_index]
-        per_category_ap['Recall AR@100 (small) ByCategory/{}'.format(
-            category)] = self.category_stats[9][category_index]
-        per_category_ap['Recall AR@100 (medium) ByCategory/{}'.format(
-            category)] = self.category_stats[10][category_index]
-        per_category_ap['Recall AR@100 (large) ByCategory/{}'.format(
-            category)] = self.category_stats[11][category_index]
+            if all_metrics_per_category:
+              for metric_name, index in COCO_METRIC_NAMES_AND_INDEX:
+                metric_key = '{} BySuperCategory/{}'.format(metric_name, key)
+                super_category_ap[metric_key] += (
+                    self.category_stats[index][category_index])
     if super_categories:
       for key in super_categories:
-        metric_name = 'PerformanceBySuperCategory/{}'.format(key)
-        super_category_ap[metric_name] /= len(super_categories[key])
+        length = len(super_categories[key])
+        super_category_ap['PerformanceBySuperCategory/{}'.format(
+            key)] /= length
+        if all_metrics_per_category:
+          for metric_name, _ in COCO_METRIC_NAMES_AND_INDEX:
+            super_category_ap['{} BySuperCategory/{}'.format(
+                metric_name, key)] /= length
     per_category_ap.update(super_category_ap)
     return summary_metrics, per_category_ap
...
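To make the naming and averaging scheme concrete: every (metric, category) pair becomes one '{metric} ByCategory/{category}' key, and each super-category value is the sum of its member categories' stats divided by the member count. A self-contained sketch of that logic with fabricated stats (not pycocotools output; the metric list is abridged):

# Standalone illustration of the key naming and super-category averaging.
from collections import OrderedDict
import numpy as np

metric_names_and_index = (('Precision/mAP', 0), ('Recall/AR@1', 6))  # abridged
categories = ['cat', 'dog']
super_categories = {'animal': ['cat', 'dog']}
category_stats = np.random.rand(12, len(categories))  # [metric][category]

per_category = OrderedDict()
super_category = OrderedDict()
for key in super_categories:
  for name, _ in metric_names_and_index:
    super_category['{} BySuperCategory/{}'.format(name, key)] = 0.0

for ci, cat in enumerate(categories):
  for name, idx in metric_names_and_index:
    per_category['{} ByCategory/{}'.format(name, cat)] = category_stats[idx][ci]
    for key, members in super_categories.items():
      if cat in members:
        super_category['{} BySuperCategory/{}'.format(name, key)] += (
            category_stats[idx][ci])

# Each super-category value is the mean over its member categories.
for key, members in super_categories.items():
  for name, _ in metric_names_and_index:
    super_category['{} BySuperCategory/{}'.format(name, key)] /= len(members)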
@@ -3,7 +3,7 @@ syntax = "proto2";

 package object_detection.protos;

 // Message for configuring DetectionModel evaluation jobs (eval.py).
-// Next id - 35
+// Next id - 36
 message EvalConfig {
   optional uint32 batch_size = 25 [default = 1];
   // Number of visualization images to generate.
@@ -82,6 +82,9 @@ message EvalConfig {
   // If True, additionally include per-category metrics.
   optional bool include_metrics_per_category = 24 [default = false];

+  // If true, additionally include all summary metrics for each category.
+  optional bool all_metrics_per_category = 35 [default = false];
+
   // Optional super-category definitions: keys are super-category names;
   // values are comma-separated categories (assumed to correspond to category
   // names (`display_name`) in the label map).
...
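A quick, illustrative way to confirm the new field landed as declared (assuming the protos are compiled and importable):

from object_detection.protos import eval_pb2

config = eval_pb2.EvalConfig()
field = config.DESCRIPTOR.fields_by_name['all_metrics_per_category']
print(field.number)                     # 35, matching the "Next id - 36" bump
print(config.all_metrics_per_category)  # False, per [default = false]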