Commit 5e5e0706 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Remove MaskedLMAccuracy and use SparseCategoricalAccuracy.

PiperOrigin-RevId: 315402552
parent 4c77f5b1
......@@ -35,21 +35,6 @@ class MaskedLMConfig(cfg.TaskConfig):
validation_data: cfg.DataConfig = cfg.DataConfig()
class MaskedLMAccuracy(tf.keras.metrics.Mean):
"""The weighted accuracy metric for the masked language model."""
def __init__(self, name=None, dtype=None):
super(MaskedLMAccuracy, self).__init__(name=name, dtype=dtype)
def update_state(self, y_true, y_pred, sample_weight=None):
masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
y_true, y_pred)
numerator = tf.reduce_sum(masked_lm_accuracy * sample_weight)
denominator = tf.reduce_sum(sample_weight) + 1e-5
masked_lm_accuracy = numerator / denominator
return super(MaskedLMAccuracy, self).update_state(masked_lm_accuracy)
@base_task.register_task_cls(MaskedLMConfig)
class MaskedLMTask(base_task.Task):
"""Mock task object for testing."""
......@@ -111,7 +96,7 @@ class MaskedLMTask(base_task.Task):
def build_metrics(self, training=None):
del training
metrics = [
MaskedLMAccuracy(name='masked_lm_accuracy'),
tf.keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy'),
tf.keras.metrics.Mean(name='lm_example_loss')
]
# TODO(hongkuny): rethink how to manage metrics creation with heads.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment