Commit 40013f67 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 390235315
parent d4bc6160
......@@ -13,10 +13,10 @@
# limitations under the License.
"""Sentence prediction (classification) task."""
import dataclasses
from typing import List, Union, Optional
from absl import logging
import dataclasses
import numpy as np
import orbit
from scipy import stats
......@@ -140,15 +140,26 @@ class SentencePredictionTask(base_task.Task):
del training
if self.task_config.model.num_classes == 1:
metrics = [tf.keras.metrics.MeanSquaredError()]
elif self.task_config.model.num_classes == 2:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy'),
tf.keras.metrics.AUC(name='auc', curve='PR'),
]
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')
tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy'),
]
return metrics
def process_metrics(self, metrics, labels, model_outputs):
for metric in metrics:
metric.update_state(labels[self.label_field], model_outputs)
if metric.name == 'auc':
# Convert the logit to probability and extract the probability of True..
metric.update_state(
labels[self.label_field],
tf.expand_dims(tf.nn.softmax(model_outputs)[:, 1], axis=1))
if metric.name == 'cls_accuracy':
metric.update_state(labels[self.label_field], model_outputs)
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
compiled_metrics.update_state(labels[self.label_field], model_outputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment