Commit dc9c75dd authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Fix a unit test: nlp.tasks.sentence_prediction_test.py

PiperOrigin-RevId: 315778170
parent 5a3af75c
...@@ -14,9 +14,8 @@ ...@@ -14,9 +14,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for official.nlp.tasks.sentence_prediction.""" """Tests for official.nlp.tasks.sentence_prediction."""
import functools
import os import os
import orbit
# pylint: disable=g-bad-import-order
import tensorflow as tf import tensorflow as tf
from official.nlp.bert import configs from official.nlp.bert import configs
...@@ -34,8 +33,8 @@ class SentencePredictionTaskTest(tf.test.TestCase): ...@@ -34,8 +33,8 @@ class SentencePredictionTaskTest(tf.test.TestCase):
metrics = task.build_metrics() metrics = task.build_metrics()
strategy = tf.distribute.get_strategy() strategy = tf.distribute.get_strategy()
dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs, dataset = strategy.experimental_distribute_datasets_from_function(
config.train_data) functools.partial(task.build_inputs, config.train_data))
iterator = iter(dataset) iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(lr=0.1) optimizer = tf.keras.optimizers.SGD(lr=0.1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment