Commit 37c12026 authored by Will Cromar's avatar Will Cromar Committed by A. Unique TensorFlower
Browse files

Add 2.x version of MNIST model to model garden.

PiperOrigin-RevId: 277946653
parent 24c619ff
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a simple model on the MNIST dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
import tensorflow_datasets as tfds
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers
from official.vision.image_classification import common
FLAGS = flags.FLAGS
def build_model():
"""Constructs the ML model used to predict handwritten digits."""
image = tf.keras.layers.Input(shape=(28, 28, 1))
y = tf.keras.layers.Conv2D(filters=32,
kernel_size=5,
padding='same',
activation='relu')(image)
y = tf.keras.layers.MaxPooling2D(pool_size=(2, 2),
strides=(2, 2),
padding='same')(y)
y = tf.keras.layers.Conv2D(filters=32,
kernel_size=5,
padding='same',
activation='relu')(y)
y = tf.keras.layers.MaxPooling2D(pool_size=(2, 2),
strides=(2, 2),
padding='same')(y)
y = tf.keras.layers.Flatten()(y)
y = tf.keras.layers.Dense(1024, activation='relu')(y)
y = tf.keras.layers.Dropout(0.4)(y)
probs = tf.keras.layers.Dense(10, activation='softmax')(y)
model = tf.keras.models.Model(image, probs, name='mnist')
return model
@tfds.decode.make_decoder(output_dtype=tf.float32)
def decode_image(example, feature):
"""Convert image to float32 and normalize from [0, 255] to [0.0, 1.0]."""
return tf.cast(feature.decode_example(example), dtype=tf.float32) / 255
def run(flags_obj, strategy_override=None):
"""Run MNIST model training and eval loop using native Keras APIs.
Args:
flags_obj: An object containing parsed flag values.
strategy_override: A `tf.distribute.Strategy` object to use for model.
Returns:
Dictionary of training and eval stats.
"""
strategy = strategy_override or distribution_utils.get_distribution_strategy(
distribution_strategy=flags_obj.distribution_strategy,
num_gpus=flags_obj.num_gpus,
tpu_address=flags_obj.tpu)
strategy_scope = distribution_utils.get_strategy_scope(strategy)
mnist = tfds.builder('mnist', data_dir=flags_obj.data_dir)
if flags_obj.download:
mnist.download_and_prepare()
mnist_train, mnist_test = mnist.as_dataset(
split=['train', 'test'],
decoders={'image': decode_image()}, # pylint: disable=no-value-for-parameter
as_supervised=True)
train_input_dataset = mnist_train.cache().repeat().shuffle(
buffer_size=50000).batch(flags_obj.batch_size)
eval_input_dataset = mnist_test.cache().repeat().batch(flags_obj.batch_size)
with strategy_scope:
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
0.05, decay_steps=100000, decay_rate=0.96)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
model = build_model()
model.compile(
optimizer=optimizer,
loss='sparse_categorical_crossentropy',
metrics=['sparse_categorical_accuracy'])
num_train_examples = mnist.info.splits['train'].num_examples
train_steps = num_train_examples // flags_obj.batch_size
train_epochs = flags_obj.train_epochs
ckpt_full_path = os.path.join(flags_obj.model_dir, 'model.ckpt-{epoch:04d}')
callbacks = [
tf.keras.callbacks.ModelCheckpoint(
ckpt_full_path, save_weights_only=True),
tf.keras.callbacks.TensorBoard(log_dir=flags_obj.model_dir),
]
num_eval_examples = mnist.info.splits['test'].num_examples
num_eval_steps = num_eval_examples // flags_obj.batch_size
history = model.fit(
train_input_dataset,
epochs=train_epochs,
steps_per_epoch=train_steps,
callbacks=callbacks,
validation_steps=num_eval_steps,
validation_data=eval_input_dataset,
validation_freq=flags_obj.epochs_between_evals)
export_path = os.path.join(flags_obj.model_dir, 'saved_model')
model.save(export_path, include_optimizer=False)
eval_output = model.evaluate(
eval_input_dataset, steps=num_eval_steps, verbose=2)
stats = common.build_stats(history, eval_output, callbacks)
return stats
def define_mnist_flags():
"""Define command line flags for MNIST model."""
flags_core.define_base(
clean=True,
num_gpu=True,
train_epochs=True,
epochs_between_evals=True,
distribution_strategy=True)
flags_core.define_device()
flags_core.define_distribution()
flags.DEFINE_bool('download', False,
'Whether to download data to `--data_dir`.')
FLAGS.set_default('batch_size', 1024)
def main(_):
model_helpers.apply_clean(FLAGS)
stats = run(flags.FLAGS)
logging.info('Run stats:\n%s', stats)
if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
define_mnist_flags()
app.run(main)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test the Keras MNIST model on GPU."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl.testing import parameterized
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.utils.misc import keras_utils
from official.utils.testing import integration
from official.vision.image_classification import mnist_main
def eager_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
mode="eager",
)
class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
"""Unit tests for sample Keras MNIST model."""
_tempdir = None
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(KerasMnistTest, cls).setUpClass()
mnist_main.define_mnist_flags()
def tearDown(self):
super(KerasMnistTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir())
@combinations.generate(eager_strategy_combinations())
def test_end_to_end(self, distribution):
"""Test Keras MNIST model with `strategy`."""
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
extra_flags = [
"-train_epochs", "1",
# Let TFDS find the metadata folder automatically
"--data_dir="
]
def _mock_dataset(self, *args, **kwargs): # pylint: disable=unused-argument
"""Generate mock dataset with TPU-compatible dtype (instead of uint8)."""
return tf.data.Dataset.from_tensor_slices({
"image": tf.ones(shape=(10, 28, 28, 1), dtype=tf.int32),
"label": tf.range(10),
})
run = functools.partial(mnist_main.run, strategy_override=distribution)
with tfds.testing.mock_data(as_dataset_fn=_mock_dataset):
integration.run_synthetic(
main=run,
synth=False,
tmp_root=self.get_temp_dir(),
extra_flags=extra_flags)
if __name__ == "__main__":
tf.compat.v1.enable_v2_behavior()
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment