Commit c83f3256 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 326068206
parent 878f6caf
...@@ -27,6 +27,7 @@ import tensorflow_hub as hub ...@@ -27,6 +27,7 @@ import tensorflow_hub as hub
from official.core import base_task from official.core import base_task
from official.core import task_factory from official.core import task_factory
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders from official.nlp.configs import encoders
...@@ -103,7 +104,7 @@ class SentencePredictionTask(base_task.Task): ...@@ -103,7 +104,7 @@ class SentencePredictionTask(base_task.Task):
if aux_losses: if aux_losses:
loss += tf.add_n(aux_losses) loss += tf.add_n(aux_losses)
return tf.reduce_mean(loss) return tf_utils.safe_mean(loss)
def build_inputs(self, params, input_context=None): def build_inputs(self, params, input_context=None):
"""Returns tf.data.Dataset for sentence_prediction task.""" """Returns tf.data.Dataset for sentence_prediction task."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment