Commit 0aadb8d3 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Add per-class precision and recall for image classification task.

PiperOrigin-RevId: 452075537
parent 56761bcb
......@@ -82,6 +82,7 @@ class Losses(hyperparams.Config):
class Evaluation(hyperparams.Config):
top_k: int = 5
precision_and_recall_thresholds: Optional[List[float]] = None
report_per_class_precision_and_recall: bool = False
@dataclasses.dataclass
......
......@@ -201,6 +201,27 @@ class ImageClassificationTask(base_task.Task):
name='recall_at_threshold_{}'.format(th),
top_k=1) for th in thresholds
]
# Add per-class precision and recall.
if hasattr(
self.task_config.evaluation,
'report_per_class_precision_and_recall'
) and self.task_config.evaluation.report_per_class_precision_and_recall:
for class_id in range(self.task_config.model.num_classes):
metrics += [
tf.keras.metrics.Precision(
thresholds=th,
class_id=class_id,
name=f'precision_at_threshold_{th}/{class_id}',
top_k=1) for th in thresholds
]
metrics += [
tf.keras.metrics.Recall(
thresholds=th,
class_id=class_id,
name=f'recall_at_threshold_{th}/{class_id}',
top_k=1) for th in thresholds
]
# pylint:enable=g-complex-comprehension
else:
metrics = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment