Commit 6e02cb91 authored by Alex Tamkin's avatar Alex Tamkin Committed by Christopher Shallue
Browse files

Add multi-class confusion matrix metrics.

PiperOrigin-RevId: 208862798
parent 87820577
......@@ -30,7 +30,7 @@ def _metric_variable(name, shape, dtype):
collections=[tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.METRIC_VARIABLES])
def _build_metrics(labels, predictions, weights, batch_losses):
def _build_metrics(labels, predictions, weights, batch_losses, output_dim=1):
"""Builds TensorFlow operations to compute model evaluation metrics.
Args:
......@@ -38,14 +38,16 @@ def _build_metrics(labels, predictions, weights, batch_losses):
predictions: Tensor with shape [batch_size, output_dim].
weights: Tensor with shape [batch_size].
batch_losses: Tensor with shape [batch_size].
output_dim: Dimension of model output
Returns:
A dictionary {metric_name: (metric_value, update_op).
"""
# Compute the predicted labels.
assert len(predictions.shape) == 2
binary_classification = (predictions.shape[1] == 1)
binary_classification = output_dim == 1
if binary_classification:
assert predictions.shape[1] == 1
predictions = tf.squeeze(predictions, axis=[1])
predicted_labels = tf.to_int32(
tf.greater(predictions, 0.5), name="predicted_labels")
......@@ -73,15 +75,6 @@ def _build_metrics(labels, predictions, weights, batch_losses):
metrics["losses/weighted_cross_entropy"] = tf.metrics.mean(
batch_losses, weights=weights, name="cross_entropy_loss")
# Possibly create additional metrics for binary classification.
if binary_classification:
labels = tf.cast(labels, dtype=tf.bool)
predicted_labels = tf.cast(predicted_labels, dtype=tf.bool)
# AUC.
metrics["auc"] = tf.metrics.auc(
labels, predictions, weights=weights, num_thresholds=1000)
def _count_condition(name, labels_value, predicted_value):
"""Creates a counter for given values of predictions and labels."""
count = _metric_variable(name, [], tf.float32)
......@@ -93,14 +86,19 @@ def _build_metrics(labels, predictions, weights, batch_losses):
return count.read_value(), update_op
# Confusion matrix metrics.
metrics["confusion_matrix/true_positives"] = _count_condition(
"true_positives", labels_value=True, predicted_value=True)
metrics["confusion_matrix/false_positives"] = _count_condition(
"false_positives", labels_value=False, predicted_value=True)
metrics["confusion_matrix/true_negatives"] = _count_condition(
"true_negatives", labels_value=False, predicted_value=False)
metrics["confusion_matrix/false_negatives"] = _count_condition(
"false_negatives", labels_value=True, predicted_value=False)
num_labels = 2 if binary_classification else output_dim
for gold_label in range(num_labels):
for pred_label in range(num_labels):
metric_name = "confusion_matrix/label_{}_pred_{}".format(
gold_label, pred_label)
metrics[metric_name] = _count_condition(
metric_name, labels_value=gold_label, predicted_value=pred_label)
# Possibly create AUC metric for binary classification.
if binary_classification:
labels = tf.cast(labels, dtype=tf.bool)
metrics["auc"] = tf.metrics.auc(
labels, predictions, weights=weights, num_thresholds=1000)
return metrics
......@@ -130,7 +128,12 @@ def create_metric_fn(model):
}
def metric_fn(labels, predictions, weights, batch_losses):
return _build_metrics(labels, predictions, weights, batch_losses)
return _build_metrics(
labels,
predictions,
weights,
batch_losses,
output_dim=model.hparams.output_dim)
return metric_fn, metric_fn_inputs
......
......@@ -30,15 +30,23 @@ def _unpack_metric_map(names_to_tuples):
return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
class _MockHparams(object):
"""Mock Hparams class to support accessing with dot notation."""
pass
class _MockModel(object):
"""Mock model for testing."""
def __init__(self, labels, predictions, weights, batch_losses):
def __init__(self, labels, predictions, weights, batch_losses, output_dim):
self.labels = tf.constant(labels, dtype=tf.int32)
self.predictions = tf.constant(predictions, dtype=tf.float32)
self.weights = None if weights is None else tf.constant(
weights, dtype=tf.float32)
self.batch_losses = tf.constant(batch_losses, dtype=tf.float32)
self.hparams = _MockHparams()
self.hparams.output_dim = output_dim
class MetricsTest(tf.test.TestCase):
......@@ -48,13 +56,13 @@ class MetricsTest(tf.test.TestCase):
predictions = [
[0.7, 0.2, 0.1, 0.0], # Predicted label = 0
[0.2, 0.4, 0.2, 0.2], # Predicted label = 1
[0.0, 0.0, 0.0, 1.0], # Predicted label = 4
[0.1, 0.1, 0.7, 0.1], # Predicted label = 3
[0.0, 0.0, 0.0, 1.0], # Predicted label = 3
[0.1, 0.1, 0.7, 0.1], # Predicted label = 2
]
weights = None
batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses)
model = _MockModel(labels, predictions, weights, batch_losses, output_dim=4)
metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer()
......@@ -68,6 +76,22 @@ class MetricsTest(tf.test.TestCase):
"accuracy/num_correct": 2,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5,
"confusion_matrix/label_0_pred_0": 1,
"confusion_matrix/label_0_pred_1": 0,
"confusion_matrix/label_0_pred_2": 0,
"confusion_matrix/label_0_pred_3": 0,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 1,
"confusion_matrix/label_1_pred_2": 0,
"confusion_matrix/label_1_pred_3": 0,
"confusion_matrix/label_2_pred_0": 0,
"confusion_matrix/label_2_pred_1": 0,
"confusion_matrix/label_2_pred_2": 0,
"confusion_matrix/label_2_pred_3": 1,
"confusion_matrix/label_3_pred_0": 0,
"confusion_matrix/label_3_pred_1": 0,
"confusion_matrix/label_3_pred_2": 1,
"confusion_matrix/label_3_pred_3": 0
}, sess.run(value_ops))
sess.run(update_ops)
......@@ -76,6 +100,22 @@ class MetricsTest(tf.test.TestCase):
"accuracy/num_correct": 4,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5,
"confusion_matrix/label_0_pred_0": 2,
"confusion_matrix/label_0_pred_1": 0,
"confusion_matrix/label_0_pred_2": 0,
"confusion_matrix/label_0_pred_3": 0,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 2,
"confusion_matrix/label_1_pred_2": 0,
"confusion_matrix/label_1_pred_3": 0,
"confusion_matrix/label_2_pred_0": 0,
"confusion_matrix/label_2_pred_1": 0,
"confusion_matrix/label_2_pred_2": 0,
"confusion_matrix/label_2_pred_3": 2,
"confusion_matrix/label_3_pred_0": 0,
"confusion_matrix/label_3_pred_1": 0,
"confusion_matrix/label_3_pred_2": 2,
"confusion_matrix/label_3_pred_3": 0
}, sess.run(value_ops))
def testMultiClassificationWithWeights(self):
......@@ -83,13 +123,13 @@ class MetricsTest(tf.test.TestCase):
predictions = [
[0.7, 0.2, 0.1, 0.0], # Predicted label = 0
[0.2, 0.4, 0.2, 0.2], # Predicted label = 1
[0.0, 0.0, 0.0, 1.0], # Predicted label = 4
[0.1, 0.1, 0.7, 0.1], # Predicted label = 3
[0.0, 0.0, 0.0, 1.0], # Predicted label = 3
[0.1, 0.1, 0.7, 0.1], # Predicted label = 2
]
weights = [0, 1, 0, 1]
batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses)
model = _MockModel(labels, predictions, weights, batch_losses, output_dim=4)
metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer()
......@@ -103,6 +143,22 @@ class MetricsTest(tf.test.TestCase):
"accuracy/num_correct": 1,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1,
"confusion_matrix/label_0_pred_0": 0,
"confusion_matrix/label_0_pred_1": 0,
"confusion_matrix/label_0_pred_2": 0,
"confusion_matrix/label_0_pred_3": 0,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 1,
"confusion_matrix/label_1_pred_2": 0,
"confusion_matrix/label_1_pred_3": 0,
"confusion_matrix/label_2_pred_0": 0,
"confusion_matrix/label_2_pred_1": 0,
"confusion_matrix/label_2_pred_2": 0,
"confusion_matrix/label_2_pred_3": 0,
"confusion_matrix/label_3_pred_0": 0,
"confusion_matrix/label_3_pred_1": 0,
"confusion_matrix/label_3_pred_2": 1,
"confusion_matrix/label_3_pred_3": 0
}, sess.run(value_ops))
sess.run(update_ops)
......@@ -111,6 +167,22 @@ class MetricsTest(tf.test.TestCase):
"accuracy/num_correct": 2,
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1,
"confusion_matrix/label_0_pred_0": 0,
"confusion_matrix/label_0_pred_1": 0,
"confusion_matrix/label_0_pred_2": 0,
"confusion_matrix/label_0_pred_3": 0,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 2,
"confusion_matrix/label_1_pred_2": 0,
"confusion_matrix/label_1_pred_3": 0,
"confusion_matrix/label_2_pred_0": 0,
"confusion_matrix/label_2_pred_1": 0,
"confusion_matrix/label_2_pred_2": 0,
"confusion_matrix/label_2_pred_3": 0,
"confusion_matrix/label_3_pred_0": 0,
"confusion_matrix/label_3_pred_1": 0,
"confusion_matrix/label_3_pred_2": 2,
"confusion_matrix/label_3_pred_3": 0
}, sess.run(value_ops))
def testBinaryClassificationWithoutWeights(self):
......@@ -124,7 +196,7 @@ class MetricsTest(tf.test.TestCase):
weights = None
batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses)
model = _MockModel(labels, predictions, weights, batch_losses, output_dim=1)
metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer()
......@@ -139,10 +211,10 @@ class MetricsTest(tf.test.TestCase):
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5,
"auc": 0.25,
"confusion_matrix/true_positives": 1,
"confusion_matrix/true_negatives": 1,
"confusion_matrix/false_positives": 1,
"confusion_matrix/false_negatives": 1,
"confusion_matrix/label_0_pred_0": 1,
"confusion_matrix/label_0_pred_1": 1,
"confusion_matrix/label_1_pred_0": 1,
"confusion_matrix/label_1_pred_1": 1,
}, sess.run(value_ops))
sess.run(update_ops)
......@@ -152,10 +224,10 @@ class MetricsTest(tf.test.TestCase):
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5,
"auc": 0.25,
"confusion_matrix/true_positives": 2,
"confusion_matrix/true_negatives": 2,
"confusion_matrix/false_positives": 2,
"confusion_matrix/false_negatives": 2,
"confusion_matrix/label_0_pred_0": 2,
"confusion_matrix/label_0_pred_1": 2,
"confusion_matrix/label_1_pred_0": 2,
"confusion_matrix/label_1_pred_1": 2,
}, sess.run(value_ops))
def testBinaryClassificationWithWeights(self):
......@@ -169,7 +241,7 @@ class MetricsTest(tf.test.TestCase):
weights = [0, 1, 0, 1]
batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses)
model = _MockModel(labels, predictions, weights, batch_losses, output_dim=1)
metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer()
......@@ -184,10 +256,10 @@ class MetricsTest(tf.test.TestCase):
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1,
"auc": 0,
"confusion_matrix/true_positives": 1,
"confusion_matrix/true_negatives": 0,
"confusion_matrix/false_positives": 1,
"confusion_matrix/false_negatives": 0,
"confusion_matrix/label_0_pred_0": 0,
"confusion_matrix/label_0_pred_1": 1,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 1,
}, sess.run(value_ops))
sess.run(update_ops)
......@@ -197,10 +269,10 @@ class MetricsTest(tf.test.TestCase):
"accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1,
"auc": 0,
"confusion_matrix/true_positives": 2,
"confusion_matrix/true_negatives": 0,
"confusion_matrix/false_positives": 2,
"confusion_matrix/false_negatives": 0,
"confusion_matrix/label_0_pred_0": 0,
"confusion_matrix/label_0_pred_1": 2,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 2,
}, sess.run(value_ops))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment