Commit 4c21dfa8 authored by Andrew M. Dai's avatar Andrew M. Dai
Browse files

Fix github issue #3269 where the accuracy is wrongly underestimated for binary classification.

PiperOrigin-RevId: 186265033
parent af6527c9
...@@ -254,7 +254,7 @@ def predictions(logits): ...@@ -254,7 +254,7 @@ def predictions(logits):
with tf.name_scope('predictions'): with tf.name_scope('predictions'):
# For binary classification # For binary classification
if inner_dim == 1: if inner_dim == 1:
pred = tf.cast(tf.greater(tf.squeeze(logits, -1), 0.5), tf.int64) pred = tf.cast(tf.greater(tf.squeeze(logits, -1), 0.), tf.int64)
# For multi-class classification # For multi-class classification
else: else:
pred = tf.argmax(logits, 2) pred = tf.argmax(logits, 2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment