Commit 0734276a authored by Will Cromar's avatar Will Cromar Committed by A. Unique TensorFlower
Browse files

Fix mnist_test.py

PiperOrigin-RevId: 278024052
parent ffa522ea
......@@ -69,11 +69,13 @@ def decode_image(example, feature):
return tf.cast(feature.decode_example(example), dtype=tf.float32) / 255
def run(flags_obj, strategy_override=None):
def run(flags_obj, datasets_override=None, strategy_override=None):
"""Run MNIST model training and eval loop using native Keras APIs.
Args:
flags_obj: An object containing parsed flag values.
datasets_override: A pair of `tf.data.Dataset` objects to train the model,
representing the train and test sets.
strategy_override: A `tf.distribute.Strategy` object to use for model.
Returns:
......@@ -90,7 +92,7 @@ def run(flags_obj, strategy_override=None):
if flags_obj.download:
mnist.download_and_prepare()
mnist_train, mnist_test = mnist.as_dataset(
mnist_train, mnist_test = datasets_override or mnist.as_dataset(
split=['train', 'test'],
decoders={'image': decode_image()}, # pylint: disable=no-value-for-parameter
as_supervised=True)
......
......@@ -67,21 +67,24 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
"--data_dir="
]
def _mock_dataset(self, *args, **kwargs): # pylint: disable=unused-argument
"""Generate mock dataset with TPU-compatible dtype (instead of uint8)."""
return tf.data.Dataset.from_tensor_slices({
"image": tf.ones(shape=(10, 28, 28, 1), dtype=tf.int32),
"label": tf.range(10),
})
run = functools.partial(mnist_main.run, strategy_override=distribution)
with tfds.testing.mock_data(as_dataset_fn=_mock_dataset):
integration.run_synthetic(
main=run,
synth=False,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags)
dummy_data = (
tf.ones(shape=(10, 28, 28, 1), dtype=tf.int32),
tf.range(10),
)
datasets = (
tf.data.Dataset.from_tensor_slices(dummy_data),
tf.data.Dataset.from_tensor_slices(dummy_data),
)
run = functools.partial(mnist_main.run,
datasets_override=datasets,
strategy_override=distribution)
integration.run_synthetic(
main=run,
synth=False,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment