"vscode:/vscode.git/clone" did not exist on "d48a34c39ebc4befece5febc5dd62b548563fa9e"
Commit 95dc9045 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Fix MLM accuracy bug. We should not just take reduce mean. The sum of weights...

Fix MLM accuracy bug. We should not just take reduce mean. The sum of weights should be the denominator.

PiperOrigin-RevId: 280002181
parent 0177deeb
...@@ -172,7 +172,9 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer): ...@@ -172,7 +172,9 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
"""Adds metrics.""" """Adds metrics."""
masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy( masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
lm_labels, lm_output) lm_labels, lm_output)
masked_lm_accuracy = tf.reduce_mean(masked_lm_accuracy * lm_label_weights) numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
denominator = tf.reduce_sum(lm_label_weights) + 1e-5
masked_lm_accuracy = numerator / denominator
self.add_metric( self.add_metric(
masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean') masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment