Commit 5e5e0706 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Remove MaskedLMAccuracy and use SparseCategoricalAccuracy.

PiperOrigin-RevId: 315402552
parent 4c77f5b1
...@@ -35,21 +35,6 @@ class MaskedLMConfig(cfg.TaskConfig): ...@@ -35,21 +35,6 @@ class MaskedLMConfig(cfg.TaskConfig):
validation_data: cfg.DataConfig = cfg.DataConfig() validation_data: cfg.DataConfig = cfg.DataConfig()
class MaskedLMAccuracy(tf.keras.metrics.Mean):
"""The weighted accuracy metric for the masked language model."""
def __init__(self, name=None, dtype=None):
super(MaskedLMAccuracy, self).__init__(name=name, dtype=dtype)
def update_state(self, y_true, y_pred, sample_weight=None):
masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
y_true, y_pred)
numerator = tf.reduce_sum(masked_lm_accuracy * sample_weight)
denominator = tf.reduce_sum(sample_weight) + 1e-5
masked_lm_accuracy = numerator / denominator
return super(MaskedLMAccuracy, self).update_state(masked_lm_accuracy)
@base_task.register_task_cls(MaskedLMConfig) @base_task.register_task_cls(MaskedLMConfig)
class MaskedLMTask(base_task.Task): class MaskedLMTask(base_task.Task):
"""Mock task object for testing.""" """Mock task object for testing."""
...@@ -111,7 +96,7 @@ class MaskedLMTask(base_task.Task): ...@@ -111,7 +96,7 @@ class MaskedLMTask(base_task.Task):
def build_metrics(self, training=None): def build_metrics(self, training=None):
del training del training
metrics = [ metrics = [
MaskedLMAccuracy(name='masked_lm_accuracy'), tf.keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy'),
tf.keras.metrics.Mean(name='lm_example_loss') tf.keras.metrics.Mean(name='lm_example_loss')
] ]
# TODO(hongkuny): rethink how to manage metrics creation with heads. # TODO(hongkuny): rethink how to manage metrics creation with heads.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment