Commit e48a403e authored by Asim Shankar's avatar Asim Shankar
Browse files

official/mnist: Use tf.keras.Sequential to simplify network definition.

parent aad56e4c
......@@ -29,7 +29,7 @@ from official.utils.logs import hooks_helper
LEARNING_RATE = 1e-4
class Model(tf.keras.Model):
def create_model(data_format):
"""Model to recognize digits in the MNIST dataset.
Network structure is equivalent to:
......@@ -37,60 +37,41 @@ class Model(tf.keras.Model):
and
https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py
But written as a tf.keras.Model using the tf.layers API.
"""
But uses the tf.keras API.
Args:
data_format: Either 'channels_first' or 'channels_last'.
'channels_first' is typically faster on GPUs while 'channels_last' is
typically faster on CPUs. See
https://www.tensorflow.org/performance/performance_guide#data_formats
def __init__(self, data_format):
"""Creates a model for classifying a hand-written digit.
Args:
data_format: Either 'channels_first' or 'channels_last'.
'channels_first' is typically faster on GPUs while 'channels_last' is
typically faster on CPUs. See
https://www.tensorflow.org/performance/performance_guide#data_formats
"""
super(Model, self).__init__()
if data_format == 'channels_first':
self._input_shape = [-1, 1, 28, 28]
else:
assert data_format == 'channels_last'
self._input_shape = [-1, 28, 28, 1]
self.conv1 = tf.layers.Conv2D(
32, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
self.conv2 = tf.layers.Conv2D(
64, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
self.fc1 = tf.layers.Dense(1024, activation=tf.nn.relu)
self.fc2 = tf.layers.Dense(10)
self.dropout = tf.layers.Dropout(0.4)
self.max_pool2d = tf.layers.MaxPooling2D(
(2, 2), (2, 2), padding='same', data_format=data_format)
def __call__(self, inputs, training):
"""Add operations to classify a batch of input images.
Args:
inputs: A Tensor representing a batch of input images.
training: A boolean. Set to True to add operations required only when
training the classifier.
Returns:
A logits Tensor with shape [<batch_size>, 10].
"""
y = tf.reshape(inputs, self._input_shape)
y = self.conv1(y)
y = self.max_pool2d(y)
y = self.conv2(y)
y = self.max_pool2d(y)
y = tf.layers.flatten(y)
y = self.fc1(y)
y = self.dropout(y, training=training)
return self.fc2(y)
Returns:
A tf.keras.Model.
"""
input_shape = None
if data_format == 'channels_first':
input_shape = [1, 28, 28]
else:
assert data_format == 'channels_last'
input_shape = [28, 28, 1]
L = tf.keras.layers
max_pool = L.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)
return tf.keras.Sequential([
L.Reshape(input_shape),
L.Conv2D(32, 5, padding='same', data_format=data_format, activation=tf.nn.relu),
max_pool,
L.Conv2D(64, 5, padding='same', data_format=data_format, activation=tf.nn.relu),
max_pool,
L.Flatten(),
L.Dense(1024, activation=tf.nn.relu),
L.Dropout(0.4),
L.Dense(10)])
def model_fn(features, labels, mode, params):
"""The model_fn argument for creating an Estimator."""
model = Model(params['data_format'])
model = create_model(params['data_format'])
image = features
if isinstance(image, dict):
image = features['image']
......
......@@ -116,7 +116,7 @@ def main(argv):
test_ds = mnist_dataset.test(flags.data_dir).batch(flags.batch_size)
# Create the model and optimizer
model = mnist.Model(data_format)
model = mnist.create_model(data_format)
optimizer = tf.train.MomentumOptimizer(flags.lr, flags.momentum)
# Create file writers for writing TensorBoard summaries.
......
......@@ -40,7 +40,7 @@ def random_dataset():
def train(defun=False):
model = mnist.Model(data_format())
model = mnist.create_model(data_format())
if defun:
model.call = tfe.defun(model.call)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
......@@ -51,7 +51,7 @@ def train(defun=False):
def evaluate(defun=False):
model = mnist.Model(data_format())
model = mnist.create_model(data_format())
dataset = random_dataset()
if defun:
model.call = tfe.defun(model.call)
......
......@@ -86,7 +86,7 @@ def model_fn(features, labels, mode, params):
if isinstance(image, dict):
image = features["image"]
model = mnist.Model("channels_last")
model = mnist.create_model("channels_last")
logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment