Commit 238922e9 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Adds per-class metrics.

PiperOrigin-RevId: 366540340
parent 1eddd748
...@@ -101,6 +101,11 @@ class Losses(hyperparams.Config): ...@@ -101,6 +101,11 @@ class Losses(hyperparams.Config):
l2_weight_decay: float = 0.0 l2_weight_decay: float = 0.0
@dataclasses.dataclass
class Metrics(hyperparams.Config):
use_per_class_recall: bool = False
@dataclasses.dataclass @dataclasses.dataclass
class VideoClassificationTask(cfg.TaskConfig): class VideoClassificationTask(cfg.TaskConfig):
"""The task config.""" """The task config."""
...@@ -109,6 +114,7 @@ class VideoClassificationTask(cfg.TaskConfig): ...@@ -109,6 +114,7 @@ class VideoClassificationTask(cfg.TaskConfig):
validation_data: DataConfig = DataConfig( validation_data: DataConfig = DataConfig(
is_training=False, drop_remainder=False) is_training=False, drop_remainder=False)
losses: Losses = Losses() losses: Losses = Losses()
metrics: Metrics = Metrics()
def add_trainer(experiment: cfg.ExperimentConfig, def add_trainer(experiment: cfg.ExperimentConfig,
......
...@@ -154,6 +154,10 @@ class VideoClassificationTask(base_task.Task): ...@@ -154,6 +154,10 @@ class VideoClassificationTask(base_task.Task):
curve='PR', curve='PR',
multi_label=self.task_config.train_data.is_multilabel, multi_label=self.task_config.train_data.is_multilabel,
name='PR-AUC')) name='PR-AUC'))
if self.task_config.metrics.use_per_class_recall:
for i in range(self.task_config.train_data.num_classes):
metrics.append(
tf.keras.metrics.Recall(class_id=i, name=f'recall-{i}'))
else: else:
metrics = [ metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'), tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment