"tests/vscode:/vscode.git/clone" did not exist on "f064b3bf73e479051ed4255d98afad4259a6f012"
Commit 6e5057bd authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 320536301
parent b9160cc6
......@@ -62,10 +62,10 @@ class MaskedLMTask(base_task.Task):
sentence_labels = labels['next_sentence_labels']
sentence_outputs = tf.cast(
model_outputs['next_sentence'], dtype=tf.float32)
sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
sentence_labels,
sentence_outputs,
from_logits=True)
sentence_loss = tf.reduce_mean(
tf.keras.losses.sparse_categorical_crossentropy(sentence_labels,
sentence_outputs,
from_logits=True))
metrics['next_sentence_loss'].update_state(sentence_loss)
total_loss = mlm_loss + sentence_loss
else:
......
......@@ -91,7 +91,7 @@ class SentencePredictionTask(base_task.Task):
if aux_losses:
loss += tf.add_n(aux_losses)
return loss
return tf.reduce_mean(loss)
def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for sentence_prediction task."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment