Commit dc9c75dd authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Fix a unit test: nlp.tasks.sentence_prediction_test.py

PiperOrigin-RevId: 315778170
parent 5a3af75c
......@@ -14,9 +14,8 @@
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.tasks.sentence_prediction."""
import functools
import os
import orbit
# pylint: disable=g-bad-import-order
import tensorflow as tf
from official.nlp.bert import configs
......@@ -34,8 +33,8 @@ class SentencePredictionTaskTest(tf.test.TestCase):
metrics = task.build_metrics()
strategy = tf.distribute.get_strategy()
dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
config.train_data)
dataset = strategy.experimental_distribute_datasets_from_function(
functools.partial(task.build_inputs, config.train_data))
iterator = iter(dataset)
optimizer = tf.keras.optimizers.SGD(lr=0.1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment