Commit 6f38af5f authored by Ryan Sepassi's avatar Ryan Sepassi Committed by GitHub
Browse files

Fix KL when num_classes != 2

Fix to [issue 1724](https://github.com/tensorflow/models/issues/1724)
parent 6685fb8b
......@@ -212,6 +212,7 @@ def _kl_divergence_with_logits(q_logits, p_logits, weights):
# For softmax regression
else:
q = tf.nn.softmax(q_logits)
kl = tf.reduce_sum(
q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment