"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "83a9b5271a68c7d1f8443f91c8d8b7d24ab581a9"
Unverified Commit 3d174546 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

report per class metrics separately

parent 1edd6e86
...@@ -121,6 +121,7 @@ class PanopticQualityEvaluator(hyperparams.Config): ...@@ -121,6 +121,7 @@ class PanopticQualityEvaluator(hyperparams.Config):
offset: int = 256 * 256 * 256 offset: int = 256 * 256 * 256
is_thing: List[float] = dataclasses.field( is_thing: List[float] = dataclasses.field(
default_factory=list) default_factory=list)
report_per_class_metrics: bool = False
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -427,7 +427,17 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -427,7 +427,17 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
result.update({'segmentation_mean_iou': tf.reduce_mean(ious).numpy()}) result.update({'segmentation_mean_iou': tf.reduce_mean(ious).numpy()})
if self.task_config.model.generate_panoptic_masks: if self.task_config.model.generate_panoptic_masks:
for k, value in self.panoptic_quality_metric.result().items(): report_per_class_metrics = self.task_config.panoptic_quality_evaluator.report_per_class_metrics
result['panoptic_quality/' + k] = value panoptic_quality_results = self.panoptic_quality_metric.result()
for k, value in panoptic_quality_results.items():
if k.endswith('per_class'):
if report_per_class_metrics:
for i, per_class_value in enumerate(value):
metric_key = 'panoptic_quality/{}/class_{}'.format(k, i)
result[metric_key] = per_class_value
else:
continue
else:
result['panoptic_quality/{}'.format(k)] = value
return result return result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment