Commit 5dcfd2c5 authored by Terry Huang's avatar Terry Huang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 331906632
parent bfeab591
......@@ -63,13 +63,15 @@ class MaskedLMTask(base_task.Task):
model_outputs,
metrics,
aux_losses=None) -> tf.Tensor:
with tf.name_scope('MaskedLMTask/losses'):
metrics = dict([(metric.name, metric) for metric in metrics])
lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
labels['masked_lm_ids'],
tf.cast(model_outputs['lm_output'], tf.float32),
from_logits=True)
lm_label_weights = labels['masked_lm_weights']
lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
lm_numerator_loss = tf.reduce_sum(lm_prediction_losses *
lm_label_weights)
lm_denominator_loss = tf.reduce_sum(lm_label_weights)
mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
metrics['lm_example_loss'].update_state(mlm_loss)
......@@ -128,10 +130,11 @@ class MaskedLMTask(base_task.Task):
return metrics
def process_metrics(self, metrics, labels, model_outputs):
with tf.name_scope('MaskedLMTask/process_metrics'):
metrics = dict([(metric.name, metric) for metric in metrics])
if 'masked_lm_accuracy' in metrics:
metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
model_outputs['lm_output'],
metrics['masked_lm_accuracy'].update_state(
labels['masked_lm_ids'], model_outputs['lm_output'],
labels['masked_lm_weights'])
if 'next_sentence_accuracy' in metrics:
metrics['next_sentence_accuracy'].update_state(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment