Commit 5dcfd2c5 authored by Terry Huang's avatar Terry Huang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 331906632
parent bfeab591
...@@ -63,13 +63,15 @@ class MaskedLMTask(base_task.Task): ...@@ -63,13 +63,15 @@ class MaskedLMTask(base_task.Task):
model_outputs, model_outputs,
metrics, metrics,
aux_losses=None) -> tf.Tensor: aux_losses=None) -> tf.Tensor:
with tf.name_scope('MaskedLMTask/losses'):
metrics = dict([(metric.name, metric) for metric in metrics]) metrics = dict([(metric.name, metric) for metric in metrics])
lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy( lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
labels['masked_lm_ids'], labels['masked_lm_ids'],
tf.cast(model_outputs['lm_output'], tf.float32), tf.cast(model_outputs['lm_output'], tf.float32),
from_logits=True) from_logits=True)
lm_label_weights = labels['masked_lm_weights'] lm_label_weights = labels['masked_lm_weights']
lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights) lm_numerator_loss = tf.reduce_sum(lm_prediction_losses *
lm_label_weights)
lm_denominator_loss = tf.reduce_sum(lm_label_weights) lm_denominator_loss = tf.reduce_sum(lm_label_weights)
mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss) mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
metrics['lm_example_loss'].update_state(mlm_loss) metrics['lm_example_loss'].update_state(mlm_loss)
...@@ -128,10 +130,11 @@ class MaskedLMTask(base_task.Task): ...@@ -128,10 +130,11 @@ class MaskedLMTask(base_task.Task):
return metrics return metrics
def process_metrics(self, metrics, labels, model_outputs): def process_metrics(self, metrics, labels, model_outputs):
with tf.name_scope('MaskedLMTask/process_metrics'):
metrics = dict([(metric.name, metric) for metric in metrics]) metrics = dict([(metric.name, metric) for metric in metrics])
if 'masked_lm_accuracy' in metrics: if 'masked_lm_accuracy' in metrics:
metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'], metrics['masked_lm_accuracy'].update_state(
model_outputs['lm_output'], labels['masked_lm_ids'], model_outputs['lm_output'],
labels['masked_lm_weights']) labels['masked_lm_weights'])
if 'next_sentence_accuracy' in metrics: if 'next_sentence_accuracy' in metrics:
metrics['next_sentence_accuracy'].update_state( metrics['next_sentence_accuracy'].update_state(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment