Commit 7acce35d authored by Jing Li's avatar Jing Li Committed by A. Unique TensorFlower
Browse files

Fix bug in test dataset generation

PiperOrigin-RevId: 273066504
parent f8d9c9b8
...@@ -128,13 +128,14 @@ def get_input_dataset(flags_obj, strategy): ...@@ -128,13 +128,14 @@ def get_input_dataset(flags_obj, strategy):
input_context=ctx) input_context=ctx)
return test_ds return test_ds
if strategy: if strategy:
if isinstance(strategy, tf.distribute.experimental.TPUStrategy): if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
test_ds = strategy.experimental_distribute_datasets_from_function(_test_data_fn) test_ds = strategy.experimental_distribute_datasets_from_function(
_test_data_fn)
else:
test_ds = strategy.experimental_distribute_dataset(_test_data_fn())
else: else:
test_ds = strategy.experimental_distribute_dataset(_test_data_fn()) test_ds = _test_data_fn()
else:
test_ds = _test_data_fn()
return train_ds, test_ds return train_ds, test_ds
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment