Commit 52188963 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 326068206
parent 6df2c663
......@@ -27,6 +27,7 @@ import tensorflow_hub as hub
from official.core import base_task
from official.core import task_factory
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders
......@@ -103,7 +104,7 @@ class SentencePredictionTask(base_task.Task):
if aux_losses:
loss += tf.add_n(aux_losses)
return tf.reduce_mean(loss)
return tf_utils.safe_mean(loss)
def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for sentence_prediction task."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment