Unverified commit 2c74e416, authored by Asim Shankar and committed by GitHub

Merge pull request #3942 from asimshankar/mnist_sequential_and_estimator

official/mnist: Use tf.keras.Sequential
parents 720d3363 cfe945ef
...
@@ -30,7 +30,7 @@ from official.utils.misc import model_helpers
 
 LEARNING_RATE = 1e-4
 
-class Model(tf.keras.Model):
+def create_model(data_format):
   """Model to recognize digits in the MNIST dataset.
 
   Network structure is equivalent to:
@@ -38,60 +38,55 @@ class Model(tf.keras.Model):
   and
   https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py
 
-  But written as a tf.keras.Model using the tf.layers API.
-  """
+  But uses the tf.keras API.
+
+  Args:
+    data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
+      typically faster on GPUs while 'channels_last' is typically faster on
+      CPUs. See
+      https://www.tensorflow.org/performance/performance_guide#data_formats
+
+  Returns:
+    A tf.keras.Model.
+  """
+  if data_format == 'channels_first':
+    input_shape = [1, 28, 28]
+  else:
+    assert data_format == 'channels_last'
+    input_shape = [28, 28, 1]
 
-  def __init__(self, data_format):
-    """Creates a model for classifying a hand-written digit.
-
-    Args:
-      data_format: Either 'channels_first' or 'channels_last'.
-        'channels_first' is typically faster on GPUs while 'channels_last' is
-        typically faster on CPUs. See
-        https://www.tensorflow.org/performance/performance_guide#data_formats
-    """
-    super(Model, self).__init__()
-    if data_format == 'channels_first':
-      self._input_shape = [-1, 1, 28, 28]
-    else:
-      assert data_format == 'channels_last'
-      self._input_shape = [-1, 28, 28, 1]
-
-    self.conv1 = tf.layers.Conv2D(
-        32, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
-    self.conv2 = tf.layers.Conv2D(
-        64, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
-    self.fc1 = tf.layers.Dense(1024, activation=tf.nn.relu)
-    self.fc2 = tf.layers.Dense(10)
-    self.dropout = tf.layers.Dropout(0.4)
-    self.max_pool2d = tf.layers.MaxPooling2D(
-        (2, 2), (2, 2), padding='same', data_format=data_format)
-
-  def __call__(self, inputs, training):
-    """Add operations to classify a batch of input images.
-
-    Args:
-      inputs: A Tensor representing a batch of input images.
-      training: A boolean. Set to True to add operations required only when
-        training the classifier.
-
-    Returns:
-      A logits Tensor with shape [<batch_size>, 10].
-    """
-    y = tf.reshape(inputs, self._input_shape)
-    y = self.conv1(y)
-    y = self.max_pool2d(y)
-    y = self.conv2(y)
-    y = self.max_pool2d(y)
-    y = tf.layers.flatten(y)
-    y = self.fc1(y)
-    y = self.dropout(y, training=training)
-    return self.fc2(y)
+  l = tf.keras.layers
+  max_pool = l.MaxPooling2D(
+      (2, 2), (2, 2), padding='same', data_format=data_format)
+  # The model consists of a sequential chain of layers, so tf.keras.Sequential
+  # (a subclass of tf.keras.Model) makes for a compact description.
+  return tf.keras.Sequential(
+      [
+          l.Reshape(input_shape),
+          l.Conv2D(
+              32,
+              5,
+              padding='same',
+              data_format=data_format,
+              activation=tf.nn.relu),
+          max_pool,
+          l.Conv2D(
+              64,
+              5,
+              padding='same',
+              data_format=data_format,
+              activation=tf.nn.relu),
+          max_pool,
+          l.Flatten(),
+          l.Dense(1024, activation=tf.nn.relu),
+          l.Dropout(0.4),
+          l.Dense(10)
+      ])
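As a side note (not part of the commit), here is a minimal sketch of exercising the Sequential model returned by the new create_model, assuming a TensorFlow 1.x build with eager execution enabled; the batch size and random input are purely illustrative:

import numpy as np
import tensorflow as tf

tf.enable_eager_execution()

# Build the model as the diff does and push a batch of flattened 28x28 images
# through it; the returned tf.keras.Sequential is callable like any tf.keras.Model.
model = create_model('channels_last')
images = tf.constant(np.random.rand(8, 28 * 28), dtype=tf.float32)
logits = model(images, training=False)
print(logits.shape)  # (8, 10)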
 def model_fn(features, labels, mode, params):
   """The model_fn argument for creating an Estimator."""
-  model = Model(params['data_format'])
+  model = create_model(params['data_format'])
   image = features
   if isinstance(image, dict):
     image = features['image']
@@ -141,8 +136,7 @@ def model_fn(features, labels, mode, params):
         eval_metric_ops={
             'accuracy':
                 tf.metrics.accuracy(
-                    labels=labels,
-                    predictions=tf.argmax(logits, axis=1)),
+                    labels=labels, predictions=tf.argmax(logits, axis=1)),
         })
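To place that hunk in context, a condensed sketch of how an EVAL branch of such a model_fn typically combines the Keras model, the loss, and the accuracy metric; the helper name and everything not visible in the diff are assumptions rather than code from the file:

import tensorflow as tf

def eval_spec_sketch(features, labels, params):
  # Hypothetical helper, not in the repository: builds an EVAL EstimatorSpec
  # around the model produced by create_model().
  model = create_model(params['data_format'])
  logits = model(features, training=False)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  return tf.estimator.EstimatorSpec(
      mode=tf.estimator.ModeKeys.EVAL,
      loss=loss,
      eval_metric_ops={
          'accuracy':
              tf.metrics.accuracy(
                  labels=labels, predictions=tf.argmax(logits, axis=1)),
      })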
@@ -232,8 +226,8 @@ def main(argv):
     eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
     print('\nEvaluation results:\n\t%s\n' % eval_results)
 
-    if model_helpers.past_stop_threshold(
-        flags.stop_threshold, eval_results['accuracy']):
+    if model_helpers.past_stop_threshold(flags.stop_threshold,
+                                         eval_results['accuracy']):
       break
 
   # Export the model
...
@@ -116,7 +116,7 @@ def main(argv):
   test_ds = mnist_dataset.test(flags.data_dir).batch(flags.batch_size)
 
   # Create the model and optimizer
-  model = mnist.Model(data_format)
+  model = mnist.create_model(data_format)
   optimizer = tf.train.MomentumOptimizer(flags.lr, flags.momentum)
 
   # Create file writers for writing TensorBoard summaries.
...
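For readers following the eager-mode caller above, a minimal sketch of one training step with this model and optimizer, assuming TF 1.x eager execution; the step function and its names are illustrative and do not appear in the diff:

import tensorflow as tf

def train_step_sketch(model, optimizer, images, labels):
  # Hypothetical helper (not part of the commit): one gradient update in eager mode.
  with tf.GradientTape() as tape:
    logits = model(images, training=True)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  grads = tape.gradient(loss, model.variables)
  optimizer.apply_gradients(zip(grads, model.variables))
  return loss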
@@ -40,7 +40,7 @@ def random_dataset():
 
 def train(defun=False):
-  model = mnist.Model(data_format())
+  model = mnist.create_model(data_format())
   if defun:
     model.call = tfe.defun(model.call)
   optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
@@ -51,7 +51,7 @@ def train(defun=False):
 
 def evaluate(defun=False):
-  model = mnist.Model(data_format())
+  model = mnist.create_model(data_format())
   dataset = random_dataset()
   if defun:
     model.call = tfe.defun(model.call)
...
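Because these benchmarks compile the model's call method with tfe.defun, here is a small isolated sketch of that pattern, assuming eager execution and the TF 1.x tf.contrib.eager module; it is illustrative and not taken from the test file:

import tensorflow as tf
import tensorflow.contrib.eager as tfe
from official.mnist import mnist

tf.enable_eager_execution()

# A Sequential returned by create_model() exposes .call just like the old
# subclassed Model, so it can still be staged into a graph function with defun.
model = mnist.create_model('channels_last')
model.call = tfe.defun(model.call)
logits = model(tf.zeros([1, 28 * 28]), training=False)  # runs the compiled call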
@@ -86,7 +86,7 @@ def model_fn(features, labels, mode, params):
   if isinstance(image, dict):
     image = features["image"]
 
-  model = mnist.Model("channels_last")
+  model = mnist.create_model("channels_last")
   logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
   loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
...