Commit 238922e9 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Adds per-class metrics.

PiperOrigin-RevId: 366540340
parent 1eddd748
......@@ -101,6 +101,11 @@ class Losses(hyperparams.Config):
l2_weight_decay: float = 0.0
@dataclasses.dataclass
class Metrics(hyperparams.Config):
use_per_class_recall: bool = False
@dataclasses.dataclass
class VideoClassificationTask(cfg.TaskConfig):
"""The task config."""
......@@ -109,6 +114,7 @@ class VideoClassificationTask(cfg.TaskConfig):
validation_data: DataConfig = DataConfig(
is_training=False, drop_remainder=False)
losses: Losses = Losses()
metrics: Metrics = Metrics()
def add_trainer(experiment: cfg.ExperimentConfig,
......
......@@ -154,6 +154,10 @@ class VideoClassificationTask(base_task.Task):
curve='PR',
multi_label=self.task_config.train_data.is_multilabel,
name='PR-AUC'))
if self.task_config.metrics.use_per_class_recall:
for i in range(self.task_config.train_data.num_classes):
metrics.append(
tf.keras.metrics.Recall(class_id=i, name=f'recall-{i}'))
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment