"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "026869915febc078a4ff069b908b1b95aec1275c"
Commit 695c9d58 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 334910093
parent 52979660
@@ -208,6 +208,7 @@ class MaskRCNNTask(cfg.TaskConfig):
  init_checkpoint_modules: str = 'all'  # all or backbone
  annotation_file: Optional[str] = None
  gradient_clip_norm: float = 0.0
  per_category_metrics = False
COCO_INPUT_PATH_BASE = 'coco'
...
@@ -129,6 +129,7 @@ class RetinaNetTask(cfg.TaskConfig):
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: str = 'all'  # all or backbone
  gradient_clip_norm: float = 0.0
  per_category_metrics = False
@exp_factory.register_config_factory('retinanet')
...
@@ -41,7 +41,11 @@ from official.vision.beta.evaluation import coco_utils
class COCOEvaluator(object):
  """COCO evaluation metric class."""

  def __init__(self,
               annotation_file,
               include_mask,
               need_rescale_bboxes=True,
               per_category_metrics=False):
    """Constructs COCO evaluation class.

    The class provides the interface to COCO metrics_fn. The
@@ -57,6 +61,7 @@ class COCOEvaluator(object):
        eval.
      need_rescale_bboxes: If true bboxes in `predictions` will be rescaled back
        to absolute values (`image_info` is needed in this case).
      per_category_metrics: Whether to return per category metrics.
    """
    if annotation_file:
      if annotation_file.startswith('gs://'):
@@ -72,6 +77,7 @@ class COCOEvaluator(object):
          annotation_file=local_val_json)
    self._annotation_file = annotation_file
    self._include_mask = include_mask
    self._per_category_metrics = per_category_metrics
    self._metric_names = [
        'AP', 'AP50', 'AP75', 'APs', 'APm', 'APl', 'ARmax1', 'ARmax10',
        'ARmax100', 'ARs', 'ARm', 'ARl'
@@ -156,6 +162,46 @@ class COCOEvaluator(object):
    metrics_dict = {}
    for i, name in enumerate(self._metric_names):
      metrics_dict[name] = metrics[i].astype(np.float32)
# Adds metrics per category.
if self._per_category_metrics and hasattr(coco_eval, 'category_stats'):
for category_index, category_id in enumerate(coco_eval.params.catIds):
metrics_dict['Precision mAP ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[0][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory@50IoU/{}'.format(
category_id)] = coco_eval.category_stats[1][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory@75IoU/{}'.format(
category_id)] = coco_eval.category_stats[2][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory (small) /{}'.format(
category_id)] = coco_eval.category_stats[3][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory (medium) /{}'.format(
category_id)] = coco_eval.category_stats[4][category_index].astype(
np.float32)
metrics_dict['Precision mAP ByCategory (large) /{}'.format(
category_id)] = coco_eval.category_stats[5][category_index].astype(
np.float32)
metrics_dict['Recall AR@1 ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[6][category_index].astype(
np.float32)
metrics_dict['Recall AR@10 ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[7][category_index].astype(
np.float32)
metrics_dict['Recall AR@100 ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[8][category_index].astype(
np.float32)
metrics_dict['Recall AR (small) ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[9][category_index].astype(
np.float32)
metrics_dict['Recall AR (medium) ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[10][category_index].astype(
np.float32)
metrics_dict['Recall AR (large) ByCategory/{}'.format(
category_id)] = coco_eval.category_stats[11][category_index].astype(
np.float32)
    return metrics_dict

  def _process_predictions(self, predictions):
...
@@ -204,7 +204,8 @@ class MaskRCNNTask(base_task.Task):
    else:
      self.coco_metric = coco_evaluator.COCOEvaluator(
          annotation_file=self._task_config.annotation_file,
          include_mask=self._task_config.model.include_mask,
          per_category_metrics=self._task_config.per_category_metrics)
    return metrics
...
@@ -178,7 +178,9 @@ class RetinaNetTask(base_task.Task):
    if not training:
      self.coco_metric = coco_evaluator.COCOEvaluator(
          annotation_file=self._task_config.annotation_file,
          include_mask=False,
          per_category_metrics=self._task_config.per_category_metrics)
    return metrics
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment