Commit 0734276a authored by Will Cromar's avatar Will Cromar Committed by A. Unique TensorFlower
Browse files

Fix mnist_test.py

PiperOrigin-RevId: 278024052
parent ffa522ea
...@@ -69,11 +69,13 @@ def decode_image(example, feature): ...@@ -69,11 +69,13 @@ def decode_image(example, feature):
return tf.cast(feature.decode_example(example), dtype=tf.float32) / 255 return tf.cast(feature.decode_example(example), dtype=tf.float32) / 255
def run(flags_obj, strategy_override=None): def run(flags_obj, datasets_override=None, strategy_override=None):
"""Run MNIST model training and eval loop using native Keras APIs. """Run MNIST model training and eval loop using native Keras APIs.
Args: Args:
flags_obj: An object containing parsed flag values. flags_obj: An object containing parsed flag values.
datasets_override: A pair of `tf.data.Dataset` objects to train the model,
representing the train and test sets.
strategy_override: A `tf.distribute.Strategy` object to use for model. strategy_override: A `tf.distribute.Strategy` object to use for model.
Returns: Returns:
...@@ -90,7 +92,7 @@ def run(flags_obj, strategy_override=None): ...@@ -90,7 +92,7 @@ def run(flags_obj, strategy_override=None):
if flags_obj.download: if flags_obj.download:
mnist.download_and_prepare() mnist.download_and_prepare()
mnist_train, mnist_test = mnist.as_dataset( mnist_train, mnist_test = datasets_override or mnist.as_dataset(
split=['train', 'test'], split=['train', 'test'],
decoders={'image': decode_image()}, # pylint: disable=no-value-for-parameter decoders={'image': decode_image()}, # pylint: disable=no-value-for-parameter
as_supervised=True) as_supervised=True)
......
...@@ -67,16 +67,19 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase): ...@@ -67,16 +67,19 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
"--data_dir=" "--data_dir="
] ]
def _mock_dataset(self, *args, **kwargs): # pylint: disable=unused-argument dummy_data = (
"""Generate mock dataset with TPU-compatible dtype (instead of uint8).""" tf.ones(shape=(10, 28, 28, 1), dtype=tf.int32),
return tf.data.Dataset.from_tensor_slices({ tf.range(10),
"image": tf.ones(shape=(10, 28, 28, 1), dtype=tf.int32), )
"label": tf.range(10), datasets = (
}) tf.data.Dataset.from_tensor_slices(dummy_data),
tf.data.Dataset.from_tensor_slices(dummy_data),
)
run = functools.partial(mnist_main.run, strategy_override=distribution) run = functools.partial(mnist_main.run,
datasets_override=datasets,
strategy_override=distribution)
with tfds.testing.mock_data(as_dataset_fn=_mock_dataset):
integration.run_synthetic( integration.run_synthetic(
main=run, main=run,
synth=False, synth=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment