Commit 6e02cb91 authored by Alex Tamkin's avatar Alex Tamkin Committed by Christopher Shallue
Browse files

Add multi-class confusion matrix metrics.

PiperOrigin-RevId: 208862798
parent 87820577
...@@ -30,7 +30,7 @@ def _metric_variable(name, shape, dtype): ...@@ -30,7 +30,7 @@ def _metric_variable(name, shape, dtype):
collections=[tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.METRIC_VARIABLES]) collections=[tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.METRIC_VARIABLES])
def _build_metrics(labels, predictions, weights, batch_losses): def _build_metrics(labels, predictions, weights, batch_losses, output_dim=1):
"""Builds TensorFlow operations to compute model evaluation metrics. """Builds TensorFlow operations to compute model evaluation metrics.
Args: Args:
...@@ -38,14 +38,16 @@ def _build_metrics(labels, predictions, weights, batch_losses): ...@@ -38,14 +38,16 @@ def _build_metrics(labels, predictions, weights, batch_losses):
predictions: Tensor with shape [batch_size, output_dim]. predictions: Tensor with shape [batch_size, output_dim].
weights: Tensor with shape [batch_size]. weights: Tensor with shape [batch_size].
batch_losses: Tensor with shape [batch_size]. batch_losses: Tensor with shape [batch_size].
output_dim: Dimension of the model output.
Returns: Returns:
A dictionary {metric_name: (metric_value, update_op)}. A dictionary {metric_name: (metric_value, update_op)}.
""" """
# Compute the predicted labels. # Compute the predicted labels.
assert len(predictions.shape) == 2 assert len(predictions.shape) == 2
binary_classification = (predictions.shape[1] == 1) binary_classification = output_dim == 1
if binary_classification: if binary_classification:
assert predictions.shape[1] == 1
predictions = tf.squeeze(predictions, axis=[1]) predictions = tf.squeeze(predictions, axis=[1])
predicted_labels = tf.to_int32( predicted_labels = tf.to_int32(
tf.greater(predictions, 0.5), name="predicted_labels") tf.greater(predictions, 0.5), name="predicted_labels")
...@@ -73,35 +75,31 @@ def _build_metrics(labels, predictions, weights, batch_losses): ...@@ -73,35 +75,31 @@ def _build_metrics(labels, predictions, weights, batch_losses):
metrics["losses/weighted_cross_entropy"] = tf.metrics.mean( metrics["losses/weighted_cross_entropy"] = tf.metrics.mean(
batch_losses, weights=weights, name="cross_entropy_loss") batch_losses, weights=weights, name="cross_entropy_loss")
# Possibly create additional metrics for binary classification. def _count_condition(name, labels_value, predicted_value):
"""Creates a counter for given values of predictions and labels."""
count = _metric_variable(name, [], tf.float32)
is_equal = tf.to_float(
tf.logical_and(
tf.equal(labels, labels_value),
tf.equal(predicted_labels, predicted_value)))
update_op = tf.assign_add(count, tf.reduce_sum(weights * is_equal))
return count.read_value(), update_op
# Confusion matrix metrics.
num_labels = 2 if binary_classification else output_dim
for gold_label in range(num_labels):
for pred_label in range(num_labels):
metric_name = "confusion_matrix/label_{}_pred_{}".format(
gold_label, pred_label)
metrics[metric_name] = _count_condition(
metric_name, labels_value=gold_label, predicted_value=pred_label)
# Possibly create AUC metric for binary classification.
if binary_classification: if binary_classification:
labels = tf.cast(labels, dtype=tf.bool) labels = tf.cast(labels, dtype=tf.bool)
predicted_labels = tf.cast(predicted_labels, dtype=tf.bool)
# AUC.
metrics["auc"] = tf.metrics.auc( metrics["auc"] = tf.metrics.auc(
labels, predictions, weights=weights, num_thresholds=1000) labels, predictions, weights=weights, num_thresholds=1000)
def _count_condition(name, labels_value, predicted_value):
"""Creates a counter for given values of predictions and labels."""
count = _metric_variable(name, [], tf.float32)
is_equal = tf.to_float(
tf.logical_and(
tf.equal(labels, labels_value),
tf.equal(predicted_labels, predicted_value)))
update_op = tf.assign_add(count, tf.reduce_sum(weights * is_equal))
return count.read_value(), update_op
# Confusion matrix metrics.
metrics["confusion_matrix/true_positives"] = _count_condition(
"true_positives", labels_value=True, predicted_value=True)
metrics["confusion_matrix/false_positives"] = _count_condition(
"false_positives", labels_value=False, predicted_value=True)
metrics["confusion_matrix/true_negatives"] = _count_condition(
"true_negatives", labels_value=False, predicted_value=False)
metrics["confusion_matrix/false_negatives"] = _count_condition(
"false_negatives", labels_value=True, predicted_value=False)
return metrics return metrics
...@@ -130,7 +128,12 @@ def create_metric_fn(model): ...@@ -130,7 +128,12 @@ def create_metric_fn(model):
} }
def metric_fn(labels, predictions, weights, batch_losses): def metric_fn(labels, predictions, weights, batch_losses):
return _build_metrics(labels, predictions, weights, batch_losses) return _build_metrics(
labels,
predictions,
weights,
batch_losses,
output_dim=model.hparams.output_dim)
return metric_fn, metric_fn_inputs return metric_fn, metric_fn_inputs
......
...@@ -30,15 +30,23 @@ def _unpack_metric_map(names_to_tuples): ...@@ -30,15 +30,23 @@ def _unpack_metric_map(names_to_tuples):
return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops)) return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
class _MockHparams(object):
"""Mock Hparams class to support accessing with dot notation."""
pass
class _MockModel(object): class _MockModel(object):
"""Mock model for testing.""" """Mock model for testing."""
def __init__(self, labels, predictions, weights, batch_losses): def __init__(self, labels, predictions, weights, batch_losses, output_dim):
self.labels = tf.constant(labels, dtype=tf.int32) self.labels = tf.constant(labels, dtype=tf.int32)
self.predictions = tf.constant(predictions, dtype=tf.float32) self.predictions = tf.constant(predictions, dtype=tf.float32)
self.weights = None if weights is None else tf.constant( self.weights = None if weights is None else tf.constant(
weights, dtype=tf.float32) weights, dtype=tf.float32)
self.batch_losses = tf.constant(batch_losses, dtype=tf.float32) self.batch_losses = tf.constant(batch_losses, dtype=tf.float32)
self.hparams = _MockHparams()
self.hparams.output_dim = output_dim
class MetricsTest(tf.test.TestCase): class MetricsTest(tf.test.TestCase):
...@@ -48,13 +56,13 @@ class MetricsTest(tf.test.TestCase): ...@@ -48,13 +56,13 @@ class MetricsTest(tf.test.TestCase):
predictions = [ predictions = [
[0.7, 0.2, 0.1, 0.0], # Predicted label = 0 [0.7, 0.2, 0.1, 0.0], # Predicted label = 0
[0.2, 0.4, 0.2, 0.2], # Predicted label = 1 [0.2, 0.4, 0.2, 0.2], # Predicted label = 1
[0.0, 0.0, 0.0, 1.0], # Predicted label = 4 [0.0, 0.0, 0.0, 1.0], # Predicted label = 3
[0.1, 0.1, 0.7, 0.1], # Predicted label = 3 [0.1, 0.1, 0.7, 0.1], # Predicted label = 2
] ]
weights = None weights = None
batch_losses = [0, 0, 4, 2] batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses) model = _MockModel(labels, predictions, weights, batch_losses, output_dim=4)
metric_map = metrics.create_metrics(model) metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map) value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer() initializer = tf.local_variables_initializer()
...@@ -68,6 +76,22 @@ class MetricsTest(tf.test.TestCase): ...@@ -68,6 +76,22 @@ class MetricsTest(tf.test.TestCase):
"accuracy/num_correct": 2, "accuracy/num_correct": 2,
"accuracy/accuracy": 0.5, "accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5, "losses/weighted_cross_entropy": 1.5,
"confusion_matrix/label_0_pred_0": 1,
"confusion_matrix/label_0_pred_1": 0,
"confusion_matrix/label_0_pred_2": 0,
"confusion_matrix/label_0_pred_3": 0,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 1,
"confusion_matrix/label_1_pred_2": 0,
"confusion_matrix/label_1_pred_3": 0,
"confusion_matrix/label_2_pred_0": 0,
"confusion_matrix/label_2_pred_1": 0,
"confusion_matrix/label_2_pred_2": 0,
"confusion_matrix/label_2_pred_3": 1,
"confusion_matrix/label_3_pred_0": 0,
"confusion_matrix/label_3_pred_1": 0,
"confusion_matrix/label_3_pred_2": 1,
"confusion_matrix/label_3_pred_3": 0
}, sess.run(value_ops)) }, sess.run(value_ops))
sess.run(update_ops) sess.run(update_ops)
...@@ -76,6 +100,22 @@ class MetricsTest(tf.test.TestCase): ...@@ -76,6 +100,22 @@ class MetricsTest(tf.test.TestCase):
"accuracy/num_correct": 4, "accuracy/num_correct": 4,
"accuracy/accuracy": 0.5, "accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5, "losses/weighted_cross_entropy": 1.5,
"confusion_matrix/label_0_pred_0": 2,
"confusion_matrix/label_0_pred_1": 0,
"confusion_matrix/label_0_pred_2": 0,
"confusion_matrix/label_0_pred_3": 0,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 2,
"confusion_matrix/label_1_pred_2": 0,
"confusion_matrix/label_1_pred_3": 0,
"confusion_matrix/label_2_pred_0": 0,
"confusion_matrix/label_2_pred_1": 0,
"confusion_matrix/label_2_pred_2": 0,
"confusion_matrix/label_2_pred_3": 2,
"confusion_matrix/label_3_pred_0": 0,
"confusion_matrix/label_3_pred_1": 0,
"confusion_matrix/label_3_pred_2": 2,
"confusion_matrix/label_3_pred_3": 0
}, sess.run(value_ops)) }, sess.run(value_ops))
def testMultiClassificationWithWeights(self): def testMultiClassificationWithWeights(self):
...@@ -83,13 +123,13 @@ class MetricsTest(tf.test.TestCase): ...@@ -83,13 +123,13 @@ class MetricsTest(tf.test.TestCase):
predictions = [ predictions = [
[0.7, 0.2, 0.1, 0.0], # Predicted label = 0 [0.7, 0.2, 0.1, 0.0], # Predicted label = 0
[0.2, 0.4, 0.2, 0.2], # Predicted label = 1 [0.2, 0.4, 0.2, 0.2], # Predicted label = 1
[0.0, 0.0, 0.0, 1.0], # Predicted label = 4 [0.0, 0.0, 0.0, 1.0], # Predicted label = 3
[0.1, 0.1, 0.7, 0.1], # Predicted label = 3 [0.1, 0.1, 0.7, 0.1], # Predicted label = 2
] ]
weights = [0, 1, 0, 1] weights = [0, 1, 0, 1]
batch_losses = [0, 0, 4, 2] batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses) model = _MockModel(labels, predictions, weights, batch_losses, output_dim=4)
metric_map = metrics.create_metrics(model) metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map) value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer() initializer = tf.local_variables_initializer()
...@@ -103,6 +143,22 @@ class MetricsTest(tf.test.TestCase): ...@@ -103,6 +143,22 @@ class MetricsTest(tf.test.TestCase):
"accuracy/num_correct": 1, "accuracy/num_correct": 1,
"accuracy/accuracy": 0.5, "accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1, "losses/weighted_cross_entropy": 1,
"confusion_matrix/label_0_pred_0": 0,
"confusion_matrix/label_0_pred_1": 0,
"confusion_matrix/label_0_pred_2": 0,
"confusion_matrix/label_0_pred_3": 0,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 1,
"confusion_matrix/label_1_pred_2": 0,
"confusion_matrix/label_1_pred_3": 0,
"confusion_matrix/label_2_pred_0": 0,
"confusion_matrix/label_2_pred_1": 0,
"confusion_matrix/label_2_pred_2": 0,
"confusion_matrix/label_2_pred_3": 0,
"confusion_matrix/label_3_pred_0": 0,
"confusion_matrix/label_3_pred_1": 0,
"confusion_matrix/label_3_pred_2": 1,
"confusion_matrix/label_3_pred_3": 0
}, sess.run(value_ops)) }, sess.run(value_ops))
sess.run(update_ops) sess.run(update_ops)
...@@ -111,6 +167,22 @@ class MetricsTest(tf.test.TestCase): ...@@ -111,6 +167,22 @@ class MetricsTest(tf.test.TestCase):
"accuracy/num_correct": 2, "accuracy/num_correct": 2,
"accuracy/accuracy": 0.5, "accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1, "losses/weighted_cross_entropy": 1,
"confusion_matrix/label_0_pred_0": 0,
"confusion_matrix/label_0_pred_1": 0,
"confusion_matrix/label_0_pred_2": 0,
"confusion_matrix/label_0_pred_3": 0,
"confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/label_1_pred_1": 2,
"confusion_matrix/label_1_pred_2": 0,
"confusion_matrix/label_1_pred_3": 0,
"confusion_matrix/label_2_pred_0": 0,
"confusion_matrix/label_2_pred_1": 0,
"confusion_matrix/label_2_pred_2": 0,
"confusion_matrix/label_2_pred_3": 0,
"confusion_matrix/label_3_pred_0": 0,
"confusion_matrix/label_3_pred_1": 0,
"confusion_matrix/label_3_pred_2": 2,
"confusion_matrix/label_3_pred_3": 0
}, sess.run(value_ops)) }, sess.run(value_ops))
def testBinaryClassificationWithoutWeights(self): def testBinaryClassificationWithoutWeights(self):
...@@ -124,7 +196,7 @@ class MetricsTest(tf.test.TestCase): ...@@ -124,7 +196,7 @@ class MetricsTest(tf.test.TestCase):
weights = None weights = None
batch_losses = [0, 0, 4, 2] batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses) model = _MockModel(labels, predictions, weights, batch_losses, output_dim=1)
metric_map = metrics.create_metrics(model) metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map) value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer() initializer = tf.local_variables_initializer()
...@@ -139,10 +211,10 @@ class MetricsTest(tf.test.TestCase): ...@@ -139,10 +211,10 @@ class MetricsTest(tf.test.TestCase):
"accuracy/accuracy": 0.5, "accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5, "losses/weighted_cross_entropy": 1.5,
"auc": 0.25, "auc": 0.25,
"confusion_matrix/true_positives": 1, "confusion_matrix/label_0_pred_0": 1,
"confusion_matrix/true_negatives": 1, "confusion_matrix/label_0_pred_1": 1,
"confusion_matrix/false_positives": 1, "confusion_matrix/label_1_pred_0": 1,
"confusion_matrix/false_negatives": 1, "confusion_matrix/label_1_pred_1": 1,
}, sess.run(value_ops)) }, sess.run(value_ops))
sess.run(update_ops) sess.run(update_ops)
...@@ -152,10 +224,10 @@ class MetricsTest(tf.test.TestCase): ...@@ -152,10 +224,10 @@ class MetricsTest(tf.test.TestCase):
"accuracy/accuracy": 0.5, "accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1.5, "losses/weighted_cross_entropy": 1.5,
"auc": 0.25, "auc": 0.25,
"confusion_matrix/true_positives": 2, "confusion_matrix/label_0_pred_0": 2,
"confusion_matrix/true_negatives": 2, "confusion_matrix/label_0_pred_1": 2,
"confusion_matrix/false_positives": 2, "confusion_matrix/label_1_pred_0": 2,
"confusion_matrix/false_negatives": 2, "confusion_matrix/label_1_pred_1": 2,
}, sess.run(value_ops)) }, sess.run(value_ops))
def testBinaryClassificationWithWeights(self): def testBinaryClassificationWithWeights(self):
...@@ -169,7 +241,7 @@ class MetricsTest(tf.test.TestCase): ...@@ -169,7 +241,7 @@ class MetricsTest(tf.test.TestCase):
weights = [0, 1, 0, 1] weights = [0, 1, 0, 1]
batch_losses = [0, 0, 4, 2] batch_losses = [0, 0, 4, 2]
model = _MockModel(labels, predictions, weights, batch_losses) model = _MockModel(labels, predictions, weights, batch_losses, output_dim=1)
metric_map = metrics.create_metrics(model) metric_map = metrics.create_metrics(model)
value_ops, update_ops = _unpack_metric_map(metric_map) value_ops, update_ops = _unpack_metric_map(metric_map)
initializer = tf.local_variables_initializer() initializer = tf.local_variables_initializer()
...@@ -184,10 +256,10 @@ class MetricsTest(tf.test.TestCase): ...@@ -184,10 +256,10 @@ class MetricsTest(tf.test.TestCase):
"accuracy/accuracy": 0.5, "accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1, "losses/weighted_cross_entropy": 1,
"auc": 0, "auc": 0,
"confusion_matrix/true_positives": 1, "confusion_matrix/label_0_pred_0": 0,
"confusion_matrix/true_negatives": 0, "confusion_matrix/label_0_pred_1": 1,
"confusion_matrix/false_positives": 1, "confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/false_negatives": 0, "confusion_matrix/label_1_pred_1": 1,
}, sess.run(value_ops)) }, sess.run(value_ops))
sess.run(update_ops) sess.run(update_ops)
...@@ -197,10 +269,10 @@ class MetricsTest(tf.test.TestCase): ...@@ -197,10 +269,10 @@ class MetricsTest(tf.test.TestCase):
"accuracy/accuracy": 0.5, "accuracy/accuracy": 0.5,
"losses/weighted_cross_entropy": 1, "losses/weighted_cross_entropy": 1,
"auc": 0, "auc": 0,
"confusion_matrix/true_positives": 2, "confusion_matrix/label_0_pred_0": 0,
"confusion_matrix/true_negatives": 0, "confusion_matrix/label_0_pred_1": 2,
"confusion_matrix/false_positives": 2, "confusion_matrix/label_1_pred_0": 0,
"confusion_matrix/false_negatives": 0, "confusion_matrix/label_1_pred_1": 2,
}, sess.run(value_ops)) }, sess.run(value_ops))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment