ModelZoo / ResNet50_tensorflow · Commits

Commit 6e02cb91, authored Aug 15, 2018 by Alex Tamkin; committed by Christopher Shallue on Oct 16, 2018.

    Add multi-class confusion matrix metrics.

    PiperOrigin-RevId: 208862798

Parent: 87820577

Showing 2 changed files, with 127 additions and 52 deletions (+127, -52):
research/astronet/astronet/ops/metrics.py       (+30, -27)
research/astronet/astronet/ops/metrics_test.py  (+97, -25)
research/astronet/astronet/ops/metrics.py (view file @ 6e02cb91)
@@ -30,7 +30,7 @@ def _metric_variable(name, shape, dtype):
       collections=[tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.METRIC_VARIABLES])
 
 
-def _build_metrics(labels, predictions, weights, batch_losses):
+def _build_metrics(labels, predictions, weights, batch_losses, output_dim=1):
   """Builds TensorFlow operations to compute model evaluation metrics.
 
   Args:
@@ -38,14 +38,16 @@ def _build_metrics(labels, predictions, weights, batch_losses):
     predictions: Tensor with shape [batch_size, output_dim].
     weights: Tensor with shape [batch_size].
     batch_losses: Tensor with shape [batch_size].
+    output_dim: Dimension of the model output.
 
   Returns:
     A dictionary {metric_name: (metric_value, update_op)}.
   """
   # Compute the predicted labels.
   assert len(predictions.shape) == 2
-  binary_classification = (predictions.shape[1] == 1)
+  binary_classification = output_dim == 1
   if binary_classification:
+    assert predictions.shape[1] == 1
     predictions = tf.squeeze(predictions, axis=[1])
     predicted_labels = tf.to_int32(
         tf.greater(predictions, 0.5), name="predicted_labels")
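Note: the multi-class branch of the predicted-label computation is elided in this hunk; judging from the "Predicted label = …" comments in the tests below, it presumably takes an argmax over the output dimension. A minimal sketch of both paths in the same TF 1.x style, with the argmax branch explicitly an assumption:

    import tensorflow as tf  # TF 1.x APIs, matching the code under review

    # Hypothetical inputs, for illustration only.
    binary_preds = tf.constant([[0.2], [0.8]])         # output_dim == 1
    multi_preds = tf.constant([[0.7, 0.2, 0.1, 0.0]])  # output_dim == 4

    # Binary path, as in the diff: squeeze, then threshold at 0.5.
    binary_labels = tf.to_int32(
        tf.greater(tf.squeeze(binary_preds, axis=[1]), 0.5))

    # Multi-class path (elided in the diff; assumed to be an argmax).
    multi_labels = tf.to_int32(tf.argmax(multi_preds, axis=1))

    with tf.Session() as sess:
      print(sess.run(binary_labels))  # [0 1]
      print(sess.run(multi_labels))   # [0]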
@@ -73,35 +75,31 @@ def _build_metrics(labels, predictions, weights, batch_losses):
   metrics["losses/weighted_cross_entropy"] = tf.metrics.mean(
       batch_losses, weights=weights, name="cross_entropy_loss")
 
-  # Possibly create additional metrics for binary classification.
+  def _count_condition(name, labels_value, predicted_value):
+    """Creates a counter for given values of predictions and labels."""
+    count = _metric_variable(name, [], tf.float32)
+    is_equal = tf.to_float(
+        tf.logical_and(tf.equal(labels, labels_value),
+                       tf.equal(predicted_labels, predicted_value)))
+    update_op = tf.assign_add(count, tf.reduce_sum(weights * is_equal))
+    return count.read_value(), update_op
+
+  # Confusion matrix metrics.
+  num_labels = 2 if binary_classification else output_dim
+  for gold_label in range(num_labels):
+    for pred_label in range(num_labels):
+      metric_name = "confusion_matrix/label_{}_pred_{}".format(
+          gold_label, pred_label)
+      metrics[metric_name] = _count_condition(
+          metric_name, labels_value=gold_label, predicted_value=pred_label)
+
+  # Possibly create AUC metric for binary classification.
   if binary_classification:
     labels = tf.cast(labels, dtype=tf.bool)
     predicted_labels = tf.cast(predicted_labels, dtype=tf.bool)
 
     # AUC.
     metrics["auc"] = tf.metrics.auc(
         labels, predictions, weights=weights, num_thresholds=1000)
 
-    def _count_condition(name, labels_value, predicted_value):
-      """Creates a counter for given values of predictions and labels."""
-      count = _metric_variable(name, [], tf.float32)
-      is_equal = tf.to_float(
-          tf.logical_and(tf.equal(labels, labels_value),
-                         tf.equal(predicted_labels, predicted_value)))
-      update_op = tf.assign_add(count, tf.reduce_sum(weights * is_equal))
-      return count.read_value(), update_op
-
     # Confusion matrix metrics.
     metrics["confusion_matrix/true_positives"] = _count_condition(
         "true_positives", labels_value=True, predicted_value=True)
     metrics["confusion_matrix/false_positives"] = _count_condition(
         "false_positives", labels_value=False, predicted_value=True)
     metrics["confusion_matrix/true_negatives"] = _count_condition(
         "true_negatives", labels_value=False, predicted_value=False)
     metrics["confusion_matrix/false_negatives"] = _count_condition(
         "false_negatives", labels_value=True, predicted_value=False)
 
   return metrics
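Each confusion-matrix entry here is a streaming metric: the value op reads an accumulated counter, and the update op adds the weighted number of matching examples from the current batch. A NumPy sketch of the same counting rule (the helper name mirrors _count_condition but is otherwise hypothetical), using the label/prediction pattern from the tests below:

    import numpy as np

    def count_condition(labels, predicted, weights, labels_value, predicted_value):
      """Weighted count of examples with the given (label, prediction) pair."""
      is_equal = np.logical_and(labels == labels_value,
                                predicted == predicted_value)
      return np.sum(weights * is_equal.astype(np.float32))

    labels = np.array([0, 1, 2, 3])
    predicted = np.array([0, 1, 3, 2])
    weights = np.ones(4, dtype=np.float32)

    # Assemble the full matrix, matching the metric names label_{i}_pred_{j}.
    cm = np.array([[count_condition(labels, predicted, weights, i, j)
                    for j in range(4)] for i in range(4)])
    print(cm)  # rows = true label, columns = predicted label; diagonal = correct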
@@ -130,7 +128,12 @@ def create_metric_fn(model):
   }
 
   def metric_fn(labels, predictions, weights, batch_losses):
-    return _build_metrics(labels, predictions, weights, batch_losses)
+    return _build_metrics(
+        labels,
+        predictions,
+        weights,
+        batch_losses,
+        output_dim=model.hparams.output_dim)
 
   return metric_fn, metric_fn_inputs
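create_metric_fn returns a (metric_fn, inputs) pair rather than the metrics themselves, which is the shape TPUEstimator's eval_metrics argument expects. A hedged sketch of one plausible consumer; model, model.total_loss, and the surrounding model_fn are assumptions, not part of this diff:

    # Sketch only: wiring the pair into a TF 1.x TPUEstimatorSpec.
    metric_fn, metric_fn_inputs = create_metric_fn(model)  # `model` is assumed context

    spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=model.total_loss,  # hypothetical: the model's scalar loss tensor
        eval_metrics=(metric_fn, metric_fn_inputs))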
research/astronet/astronet/ops/metrics_test.py (view file @ 6e02cb91)
@@ -30,15 +30,23 @@ def _unpack_metric_map(names_to_tuples):
   return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops))
 
 
+class _MockHparams(object):
+  """Mock Hparams class to support accessing with dot notation."""
+  pass
+
+
 class _MockModel(object):
   """Mock model for testing."""
 
-  def __init__(self, labels, predictions, weights, batch_losses):
+  def __init__(self, labels, predictions, weights, batch_losses, output_dim):
     self.labels = tf.constant(labels, dtype=tf.int32)
     self.predictions = tf.constant(predictions, dtype=tf.float32)
     self.weights = None if weights is None else tf.constant(
         weights, dtype=tf.float32)
     self.batch_losses = tf.constant(batch_losses, dtype=tf.float32)
+    self.hparams = _MockHparams()
+    self.hparams.output_dim = output_dim
 
 
 class MetricsTest(tf.test.TestCase):
@@ -48,13 +56,13 @@ class MetricsTest(tf.test.TestCase):
     predictions = [
         [0.7, 0.2, 0.1, 0.0],  # Predicted label = 0
         [0.2, 0.4, 0.2, 0.2],  # Predicted label = 1
-        [0.0, 0.0, 0.0, 1.0],  # Predicted label = 4
-        [0.1, 0.1, 0.7, 0.1],  # Predicted label = 3
+        [0.0, 0.0, 0.0, 1.0],  # Predicted label = 3
+        [0.1, 0.1, 0.7, 0.1],  # Predicted label = 2
     ]
     weights = None
     batch_losses = [0, 0, 4, 2]
-    model = _MockModel(labels, predictions, weights, batch_losses)
+    model = _MockModel(labels, predictions, weights, batch_losses, output_dim=4)
     metric_map = metrics.create_metrics(model)
     value_ops, update_ops = _unpack_metric_map(metric_map)
     initializer = tf.local_variables_initializer()
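The elided test body presumably follows the usual streaming-metric pattern: initialize the local metric variables, run the update ops once per pass over the batch, and read the value ops. That is why the expected counts in the next two hunks double after the second update. A sketch, assuming the standard tf.test.TestCase session helper:

    with self.test_session() as sess:
      sess.run(initializer)         # metric counters live in LOCAL_VARIABLES
      sess.run(update_ops)          # first pass over the constant batch
      first = sess.run(value_ops)   # e.g. label_0_pred_0 == 1
      sess.run(update_ops)          # second pass: every counter doubles
      second = sess.run(value_ops)  # e.g. label_0_pred_0 == 2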
@@ -68,6 +76,22 @@ class MetricsTest(tf.test.TestCase):
           "accuracy/num_correct": 2,
           "accuracy/accuracy": 0.5,
           "losses/weighted_cross_entropy": 1.5,
+          "confusion_matrix/label_0_pred_0": 1,
+          "confusion_matrix/label_0_pred_1": 0,
+          "confusion_matrix/label_0_pred_2": 0,
+          "confusion_matrix/label_0_pred_3": 0,
+          "confusion_matrix/label_1_pred_0": 0,
+          "confusion_matrix/label_1_pred_1": 1,
+          "confusion_matrix/label_1_pred_2": 0,
+          "confusion_matrix/label_1_pred_3": 0,
+          "confusion_matrix/label_2_pred_0": 0,
+          "confusion_matrix/label_2_pred_1": 0,
+          "confusion_matrix/label_2_pred_2": 0,
+          "confusion_matrix/label_2_pred_3": 1,
+          "confusion_matrix/label_3_pred_0": 0,
+          "confusion_matrix/label_3_pred_1": 0,
+          "confusion_matrix/label_3_pred_2": 1,
+          "confusion_matrix/label_3_pred_3": 0,
       }, sess.run(value_ops))
 
     sess.run(update_ops)
@@ -76,6 +100,22 @@ class MetricsTest(tf.test.TestCase):
           "accuracy/num_correct": 4,
           "accuracy/accuracy": 0.5,
           "losses/weighted_cross_entropy": 1.5,
+          "confusion_matrix/label_0_pred_0": 2,
+          "confusion_matrix/label_0_pred_1": 0,
+          "confusion_matrix/label_0_pred_2": 0,
+          "confusion_matrix/label_0_pred_3": 0,
+          "confusion_matrix/label_1_pred_0": 0,
+          "confusion_matrix/label_1_pred_1": 2,
+          "confusion_matrix/label_1_pred_2": 0,
+          "confusion_matrix/label_1_pred_3": 0,
+          "confusion_matrix/label_2_pred_0": 0,
+          "confusion_matrix/label_2_pred_1": 0,
+          "confusion_matrix/label_2_pred_2": 0,
+          "confusion_matrix/label_2_pred_3": 2,
+          "confusion_matrix/label_3_pred_0": 0,
+          "confusion_matrix/label_3_pred_1": 0,
+          "confusion_matrix/label_3_pred_2": 2,
+          "confusion_matrix/label_3_pred_3": 0,
       }, sess.run(value_ops))
 
   def testMultiClassificationWithWeights(self):
@@ -83,13 +123,13 @@ class MetricsTest(tf.test.TestCase):
     predictions = [
         [0.7, 0.2, 0.1, 0.0],  # Predicted label = 0
         [0.2, 0.4, 0.2, 0.2],  # Predicted label = 1
-        [0.0, 0.0, 0.0, 1.0],  # Predicted label = 4
-        [0.1, 0.1, 0.7, 0.1],  # Predicted label = 3
+        [0.0, 0.0, 0.0, 1.0],  # Predicted label = 3
+        [0.1, 0.1, 0.7, 0.1],  # Predicted label = 2
     ]
     weights = [0, 1, 0, 1]
     batch_losses = [0, 0, 4, 2]
-    model = _MockModel(labels, predictions, weights, batch_losses)
+    model = _MockModel(labels, predictions, weights, batch_losses, output_dim=4)
     metric_map = metrics.create_metrics(model)
     value_ops, update_ops = _unpack_metric_map(metric_map)
     initializer = tf.local_variables_initializer()
@@ -103,6 +143,22 @@ class MetricsTest(tf.test.TestCase):
           "accuracy/num_correct": 1,
           "accuracy/accuracy": 0.5,
           "losses/weighted_cross_entropy": 1,
+          "confusion_matrix/label_0_pred_0": 0,
+          "confusion_matrix/label_0_pred_1": 0,
+          "confusion_matrix/label_0_pred_2": 0,
+          "confusion_matrix/label_0_pred_3": 0,
+          "confusion_matrix/label_1_pred_0": 0,
+          "confusion_matrix/label_1_pred_1": 1,
+          "confusion_matrix/label_1_pred_2": 0,
+          "confusion_matrix/label_1_pred_3": 0,
+          "confusion_matrix/label_2_pred_0": 0,
+          "confusion_matrix/label_2_pred_1": 0,
+          "confusion_matrix/label_2_pred_2": 0,
+          "confusion_matrix/label_2_pred_3": 0,
+          "confusion_matrix/label_3_pred_0": 0,
+          "confusion_matrix/label_3_pred_1": 0,
+          "confusion_matrix/label_3_pred_2": 1,
+          "confusion_matrix/label_3_pred_3": 0,
       }, sess.run(value_ops))
 
     sess.run(update_ops)
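A quick check of the weighted expectations above: with weights = [0, 1, 0, 1], only examples 1 and 3 contribute.

    # Example 1: label 1, predicted 1, weight 1 -> label_1_pred_1 == 1 (correct)
    # Example 3: label 3, predicted 2, weight 1 -> label_3_pred_2 == 1 (incorrect)
    num_correct = 1
    accuracy = 1.0 / (1 + 1)                       # 0.5 over the weighted examples
    weighted_xent = (0*0 + 1*0 + 0*4 + 1*2) / 2.0  # 1.0, as asserted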
@@ -111,6 +167,22 @@ class MetricsTest(tf.test.TestCase):
           "accuracy/num_correct": 2,
           "accuracy/accuracy": 0.5,
           "losses/weighted_cross_entropy": 1,
+          "confusion_matrix/label_0_pred_0": 0,
+          "confusion_matrix/label_0_pred_1": 0,
+          "confusion_matrix/label_0_pred_2": 0,
+          "confusion_matrix/label_0_pred_3": 0,
+          "confusion_matrix/label_1_pred_0": 0,
+          "confusion_matrix/label_1_pred_1": 2,
+          "confusion_matrix/label_1_pred_2": 0,
+          "confusion_matrix/label_1_pred_3": 0,
+          "confusion_matrix/label_2_pred_0": 0,
+          "confusion_matrix/label_2_pred_1": 0,
+          "confusion_matrix/label_2_pred_2": 0,
+          "confusion_matrix/label_2_pred_3": 0,
+          "confusion_matrix/label_3_pred_0": 0,
+          "confusion_matrix/label_3_pred_1": 0,
+          "confusion_matrix/label_3_pred_2": 2,
+          "confusion_matrix/label_3_pred_3": 0,
       }, sess.run(value_ops))
 
   def testBinaryClassificationWithoutWeights(self):
@@ -124,7 +196,7 @@ class MetricsTest(tf.test.TestCase):
     weights = None
     batch_losses = [0, 0, 4, 2]
-    model = _MockModel(labels, predictions, weights, batch_losses)
+    model = _MockModel(labels, predictions, weights, batch_losses, output_dim=1)
     metric_map = metrics.create_metrics(model)
     value_ops, update_ops = _unpack_metric_map(metric_map)
     initializer = tf.local_variables_initializer()
@@ -139,10 +211,10 @@ class MetricsTest(tf.test.TestCase):
           "accuracy/accuracy": 0.5,
           "losses/weighted_cross_entropy": 1.5,
           "auc": 0.25,
           "confusion_matrix/true_positives": 1,
           "confusion_matrix/true_negatives": 1,
           "confusion_matrix/false_positives": 1,
           "confusion_matrix/false_negatives": 1,
+          "confusion_matrix/label_0_pred_0": 1,
+          "confusion_matrix/label_0_pred_1": 1,
+          "confusion_matrix/label_1_pred_0": 1,
+          "confusion_matrix/label_1_pred_1": 1,
       }, sess.run(value_ops))
 
     sess.run(update_ops)
@@ -152,10 +224,10 @@ class MetricsTest(tf.test.TestCase):
           "accuracy/accuracy": 0.5,
           "losses/weighted_cross_entropy": 1.5,
           "auc": 0.25,
           "confusion_matrix/true_positives": 2,
           "confusion_matrix/true_negatives": 2,
           "confusion_matrix/false_positives": 2,
           "confusion_matrix/false_negatives": 2,
+          "confusion_matrix/label_0_pred_0": 2,
+          "confusion_matrix/label_0_pred_1": 2,
+          "confusion_matrix/label_1_pred_0": 2,
+          "confusion_matrix/label_1_pred_1": 2,
       }, sess.run(value_ops))
 
   def testBinaryClassificationWithWeights(self):
@@ -169,7 +241,7 @@ class MetricsTest(tf.test.TestCase):
     weights = [0, 1, 0, 1]
     batch_losses = [0, 0, 4, 2]
-    model = _MockModel(labels, predictions, weights, batch_losses)
+    model = _MockModel(labels, predictions, weights, batch_losses, output_dim=1)
     metric_map = metrics.create_metrics(model)
     value_ops, update_ops = _unpack_metric_map(metric_map)
     initializer = tf.local_variables_initializer()
@@ -184,10 +256,10 @@ class MetricsTest(tf.test.TestCase):
           "accuracy/accuracy": 0.5,
           "losses/weighted_cross_entropy": 1,
           "auc": 0,
           "confusion_matrix/true_positives": 1,
           "confusion_matrix/true_negatives": 0,
           "confusion_matrix/false_positives": 1,
           "confusion_matrix/false_negatives": 0,
+          "confusion_matrix/label_0_pred_0": 0,
+          "confusion_matrix/label_0_pred_1": 1,
+          "confusion_matrix/label_1_pred_0": 0,
+          "confusion_matrix/label_1_pred_1": 1,
       }, sess.run(value_ops))
 
     sess.run(update_ops)
@@ -197,10 +269,10 @@ class MetricsTest(tf.test.TestCase):
           "accuracy/accuracy": 0.5,
           "losses/weighted_cross_entropy": 1,
           "auc": 0,
           "confusion_matrix/true_positives": 2,
           "confusion_matrix/true_negatives": 0,
           "confusion_matrix/false_positives": 2,
           "confusion_matrix/false_negatives": 0,
+          "confusion_matrix/label_0_pred_0": 0,
+          "confusion_matrix/label_0_pred_1": 2,
+          "confusion_matrix/label_1_pred_0": 0,
+          "confusion_matrix/label_1_pred_1": 2,
       }, sess.run(value_ops))