Commit 6e5057bd authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 320536301
parent b9160cc6
...@@ -62,10 +62,10 @@ class MaskedLMTask(base_task.Task): ...@@ -62,10 +62,10 @@ class MaskedLMTask(base_task.Task):
sentence_labels = labels['next_sentence_labels'] sentence_labels = labels['next_sentence_labels']
sentence_outputs = tf.cast( sentence_outputs = tf.cast(
model_outputs['next_sentence'], dtype=tf.float32) model_outputs['next_sentence'], dtype=tf.float32)
sentence_loss = tf.keras.losses.sparse_categorical_crossentropy( sentence_loss = tf.reduce_mean(
sentence_labels, tf.keras.losses.sparse_categorical_crossentropy(sentence_labels,
sentence_outputs, sentence_outputs,
from_logits=True) from_logits=True))
metrics['next_sentence_loss'].update_state(sentence_loss) metrics['next_sentence_loss'].update_state(sentence_loss)
total_loss = mlm_loss + sentence_loss total_loss = mlm_loss + sentence_loss
else: else:
......
...@@ -91,7 +91,7 @@ class SentencePredictionTask(base_task.Task): ...@@ -91,7 +91,7 @@ class SentencePredictionTask(base_task.Task):
if aux_losses: if aux_losses:
loss += tf.add_n(aux_losses) loss += tf.add_n(aux_losses)
return loss return tf.reduce_mean(loss)
def build_inputs(self, params, input_context=None): def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for sentence_prediction task.""" """Returns tf.data.Dataset for sentence_prediction task."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment