Commit 5dcfd2c5 authored by Terry Huang's avatar Terry Huang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 331906632
parent bfeab591
...@@ -63,31 +63,33 @@ class MaskedLMTask(base_task.Task): ...@@ -63,31 +63,33 @@ class MaskedLMTask(base_task.Task):
model_outputs, model_outputs,
metrics, metrics,
aux_losses=None) -> tf.Tensor: aux_losses=None) -> tf.Tensor:
metrics = dict([(metric.name, metric) for metric in metrics]) with tf.name_scope('MaskedLMTask/losses'):
lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy( metrics = dict([(metric.name, metric) for metric in metrics])
labels['masked_lm_ids'], lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
tf.cast(model_outputs['lm_output'], tf.float32), labels['masked_lm_ids'],
from_logits=True) tf.cast(model_outputs['lm_output'], tf.float32),
lm_label_weights = labels['masked_lm_weights'] from_logits=True)
lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights) lm_label_weights = labels['masked_lm_weights']
lm_denominator_loss = tf.reduce_sum(lm_label_weights) lm_numerator_loss = tf.reduce_sum(lm_prediction_losses *
mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss) lm_label_weights)
metrics['lm_example_loss'].update_state(mlm_loss) lm_denominator_loss = tf.reduce_sum(lm_label_weights)
if 'next_sentence_labels' in labels: mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
sentence_labels = labels['next_sentence_labels'] metrics['lm_example_loss'].update_state(mlm_loss)
sentence_outputs = tf.cast( if 'next_sentence_labels' in labels:
model_outputs['next_sentence'], dtype=tf.float32) sentence_labels = labels['next_sentence_labels']
sentence_loss = tf.reduce_mean( sentence_outputs = tf.cast(
tf.keras.losses.sparse_categorical_crossentropy( model_outputs['next_sentence'], dtype=tf.float32)
sentence_labels, sentence_outputs, from_logits=True)) sentence_loss = tf.reduce_mean(
metrics['next_sentence_loss'].update_state(sentence_loss) tf.keras.losses.sparse_categorical_crossentropy(
total_loss = mlm_loss + sentence_loss sentence_labels, sentence_outputs, from_logits=True))
else: metrics['next_sentence_loss'].update_state(sentence_loss)
total_loss = mlm_loss total_loss = mlm_loss + sentence_loss
else:
if aux_losses: total_loss = mlm_loss
total_loss += tf.add_n(aux_losses)
return total_loss if aux_losses:
total_loss += tf.add_n(aux_losses)
return total_loss
def build_inputs(self, params, input_context=None): def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for pretraining.""" """Returns tf.data.Dataset for pretraining."""
...@@ -128,14 +130,15 @@ class MaskedLMTask(base_task.Task): ...@@ -128,14 +130,15 @@ class MaskedLMTask(base_task.Task):
return metrics return metrics
def process_metrics(self, metrics, labels, model_outputs): def process_metrics(self, metrics, labels, model_outputs):
metrics = dict([(metric.name, metric) for metric in metrics]) with tf.name_scope('MaskedLMTask/process_metrics'):
if 'masked_lm_accuracy' in metrics: metrics = dict([(metric.name, metric) for metric in metrics])
metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'], if 'masked_lm_accuracy' in metrics:
model_outputs['lm_output'], metrics['masked_lm_accuracy'].update_state(
labels['masked_lm_weights']) labels['masked_lm_ids'], model_outputs['lm_output'],
if 'next_sentence_accuracy' in metrics: labels['masked_lm_weights'])
metrics['next_sentence_accuracy'].update_state( if 'next_sentence_accuracy' in metrics:
labels['next_sentence_labels'], model_outputs['next_sentence']) metrics['next_sentence_accuracy'].update_state(
labels['next_sentence_labels'], model_outputs['next_sentence'])
def train_step(self, inputs, model: tf.keras.Model, def train_step(self, inputs, model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer, metrics): optimizer: tf.keras.optimizers.Optimizer, metrics):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment