Commit c14a04ab authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Add precision and recall metrics at predefined thresholds for image classification task.

PiperOrigin-RevId: 448320923
parent 13642b0f
...@@ -65,6 +65,8 @@ class ImageClassificationModel(hyperparams.Config): ...@@ -65,6 +65,8 @@ class ImageClassificationModel(hyperparams.Config):
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm: bool = False add_head_batch_norm: bool = False
kernel_initializer: str = 'random_uniform' kernel_initializer: str = 'random_uniform'
# Whether to output softmax results instead of logits.
output_softmax: bool = False
@dataclasses.dataclass @dataclasses.dataclass
...@@ -79,6 +81,7 @@ class Losses(hyperparams.Config): ...@@ -79,6 +81,7 @@ class Losses(hyperparams.Config):
@dataclasses.dataclass @dataclasses.dataclass
class Evaluation(hyperparams.Config): class Evaluation(hyperparams.Config):
top_k: int = 5 top_k: int = 5
precision_and_recall_thresholds: Optional[List[float]] = None
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -184,6 +184,24 @@ class ImageClassificationTask(base_task.Task): ...@@ -184,6 +184,24 @@ class ImageClassificationTask(base_task.Task):
tf.keras.metrics.CategoricalAccuracy(name='accuracy'), tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy( tf.keras.metrics.TopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))] k=k, name='top_{}_accuracy'.format(k))]
if hasattr(
self.task_config.evaluation, 'precision_and_recall_thresholds'
) and self.task_config.evaluation.precision_and_recall_thresholds:
thresholds = self.task_config.evaluation.precision_and_recall_thresholds
# pylint:disable=g-complex-comprehension
metrics += [
tf.keras.metrics.Precision(
thresholds=th,
name='precision_at_threshold_{}'.format(th),
top_k=1) for th in thresholds
]
metrics += [
tf.keras.metrics.Recall(
thresholds=th,
name='recall_at_threshold_{}'.format(th),
top_k=1) for th in thresholds
]
# pylint:enable=g-complex-comprehension
else: else:
metrics = [ metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'), tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
...@@ -234,6 +252,7 @@ class ImageClassificationTask(base_task.Task): ...@@ -234,6 +252,7 @@ class ImageClassificationTask(base_task.Task):
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
outputs = model(features, training=True) outputs = model(features, training=True)
# Casting output layer as float32 is necessary when mixed_precision is # Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32. # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
outputs = tf.nest.map_structure( outputs = tf.nest.map_structure(
...@@ -264,6 +283,11 @@ class ImageClassificationTask(base_task.Task): ...@@ -264,6 +283,11 @@ class ImageClassificationTask(base_task.Task):
optimizer.apply_gradients(list(zip(grads, tvars))) optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss} logs = {self.loss: loss}
# Convert logits to softmax for metric computation if needed.
if hasattr(self.task_config.model,
'output_softmax') and self.task_config.model.output_softmax:
outputs = tf.nn.softmax(outputs, axis=-1)
if metrics: if metrics:
self.process_metrics(metrics, labels, outputs) self.process_metrics(metrics, labels, outputs)
elif model.compiled_metrics: elif model.compiled_metrics:
...@@ -300,6 +324,10 @@ class ImageClassificationTask(base_task.Task): ...@@ -300,6 +324,10 @@ class ImageClassificationTask(base_task.Task):
aux_losses=model.losses) aux_losses=model.losses)
logs = {self.loss: loss} logs = {self.loss: loss}
# Convert logits to softmax for metric computation if needed.
if hasattr(self.task_config.model,
'output_softmax') and self.task_config.model.output_softmax:
outputs = tf.nn.softmax(outputs, axis=-1)
if metrics: if metrics:
self.process_metrics(metrics, labels, outputs) self.process_metrics(metrics, labels, outputs)
elif model.compiled_metrics: elif model.compiled_metrics:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment