"cacheflow/vscode:/vscode.git/clone" did not exist on "4f6f4967f6af78534f460d75a9391f9a42b564b0"
Commit 695c9d58 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 334910093
parent 52979660
......@@ -208,6 +208,7 @@ class MaskRCNNTask(cfg.TaskConfig):
init_checkpoint_modules: str = 'all' # all or backbone
annotation_file: Optional[str] = None
gradient_clip_norm: float = 0.0
per_category_metrics = False
COCO_INPUT_PATH_BASE = 'coco'
......
......@@ -129,6 +129,7 @@ class RetinaNetTask(cfg.TaskConfig):
init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone
gradient_clip_norm: float = 0.0
per_category_metrics = False
@exp_factory.register_config_factory('retinanet')
......
......@@ -41,7 +41,11 @@ from official.vision.beta.evaluation import coco_utils
class COCOEvaluator(object):
"""COCO evaluation metric class."""
def __init__(self, annotation_file, include_mask, need_rescale_bboxes=True):
def __init__(self,
annotation_file,
include_mask,
need_rescale_bboxes=True,
per_category_metrics=False):
"""Constructs COCO evaluation class.
The class provides the interface to COCO metrics_fn. The
......@@ -57,6 +61,7 @@ class COCOEvaluator(object):
eval.
need_rescale_bboxes: If true bboxes in `predictions` will be rescaled back
to absolute values (`image_info` is needed in this case).
per_category_metrics: Whether to return per category metrics.
"""
if annotation_file:
if annotation_file.startswith('gs://'):
......@@ -72,6 +77,7 @@ class COCOEvaluator(object):
annotation_file=local_val_json)
self._annotation_file = annotation_file
self._include_mask = include_mask
self._per_category_metrics = per_category_metrics
self._metric_names = [
'AP', 'AP50', 'AP75', 'APs', 'APm', 'APl', 'ARmax1', 'ARmax10',
'ARmax100', 'ARs', 'ARm', 'ARl'
......@@ -156,6 +162,46 @@ class COCOEvaluator(object):
metrics_dict = {}
for i, name in enumerate(self._metric_names):
metrics_dict[name] = metrics[i].astype(np.float32)
# Adds metrics per category.
if self._per_category_metrics and hasattr(coco_eval, 'category_stats'):
for category_index, category_id in enumerate(coco_eval.params.catIds):
metrics_dict['Precision mAP ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[0][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory@50IoU/{}'.format(
category_id)] = coco_eval.category_stats[1][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory@75IoU/{}'.format(
category_id)] = coco_eval.category_stats[2][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory (small) /{}'.format(
category_id)] = coco_eval.category_stats[3][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory (medium) /{}'.format(
category_id)] = coco_eval.category_stats[4][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory (large) /{}'.format(
category_id)] = coco_eval.category_stats[5][category_index].astype(
np.float32)
metrics_dict['Recall AR@1 ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[6][category_index].astype(
np.float32)
metrics_dict['Recall AR@10 ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[7][category_index].astype(
np.float32)
metrics_dict['Recall AR@100 ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[8][category_index].astype(
np.float32)
metrics_dict['Recall AR (small) ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[9][category_index].astype(
np.float32)
metrics_dict['Recall AR (medium) ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[10][category_index].astype(
np.float32)
metrics_dict['Recall AR (large) ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[11][category_index].astype(
np.float32)
return metrics_dict
def _process_predictions(self, predictions):
......
......@@ -204,7 +204,8 @@ class MaskRCNNTask(base_task.Task):
else:
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=self._task_config.annotation_file,
include_mask=self._task_config.model.include_mask)
include_mask=self._task_config.model.include_mask,
per_category_metrics=self._task_config.per_category_metrics)
return metrics
......
......@@ -178,7 +178,9 @@ class RetinaNetTask(base_task.Task):
if not training:
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=None, include_mask=False)
annotation_file=self._task_config.annotation_file,
include_mask=False,
per_category_metrics=self._task_config.per_category_metrics)
return metrics
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment