Commit e48a403e authored by Asim Shankar's avatar Asim Shankar
Browse files

official/mnist: Use tf.keras.Sequential to simplify network definition.

parent aad56e4c
...@@ -29,7 +29,7 @@ from official.utils.logs import hooks_helper ...@@ -29,7 +29,7 @@ from official.utils.logs import hooks_helper
LEARNING_RATE = 1e-4 LEARNING_RATE = 1e-4
def create_model(data_format):
  """Builds a tf.keras model to recognize digits in the MNIST dataset.

  Network structure is equivalent to:
  https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py
  but uses the tf.keras API (a Sequential model).

  Args:
    data_format: Either 'channels_first' or 'channels_last'.
      'channels_first' is typically faster on GPUs while 'channels_last' is
      typically faster on CPUs. See
      https://www.tensorflow.org/performance/performance_guide#data_formats

  Returns:
    A tf.keras.Model producing 10-class logits (no softmax; the loss is
    expected to apply it).

  Raises:
    ValueError: if `data_format` is not one of the two supported values.
  """
  if data_format == 'channels_first':
    input_shape = [1, 28, 28]
  elif data_format == 'channels_last':
    input_shape = [28, 28, 1]
  else:
    # Raise explicitly rather than `assert`: asserts are stripped under
    # `python -O`, which would let a bad value fall through silently.
    raise ValueError(
        "data_format must be 'channels_first' or 'channels_last', got %r"
        % (data_format,))

  layers = tf.keras.layers
  # The pooling layer is stateless, so the same object is reused for both
  # pooling steps.
  max_pool = layers.MaxPooling2D(
      (2, 2), (2, 2), padding='same', data_format=data_format)
  return tf.keras.Sequential([
      # Inputs arrive as flat 784-vectors; reshape into image tensors.
      layers.Reshape(input_shape),
      layers.Conv2D(
          32, 5, padding='same', data_format=data_format,
          activation=tf.nn.relu),
      max_pool,
      layers.Conv2D(
          64, 5, padding='same', data_format=data_format,
          activation=tf.nn.relu),
      max_pool,
      layers.Flatten(),
      layers.Dense(1024, activation=tf.nn.relu),
      layers.Dropout(0.4),
      layers.Dense(10)])
def model_fn(features, labels, mode, params): def model_fn(features, labels, mode, params):
"""The model_fn argument for creating an Estimator.""" """The model_fn argument for creating an Estimator."""
model = Model(params['data_format']) model = create_model(params['data_format'])
image = features image = features
if isinstance(image, dict): if isinstance(image, dict):
image = features['image'] image = features['image']
......
...@@ -116,7 +116,7 @@ def main(argv): ...@@ -116,7 +116,7 @@ def main(argv):
test_ds = mnist_dataset.test(flags.data_dir).batch(flags.batch_size) test_ds = mnist_dataset.test(flags.data_dir).batch(flags.batch_size)
# Create the model and optimizer # Create the model and optimizer
model = mnist.Model(data_format) model = mnist.create_model(data_format)
optimizer = tf.train.MomentumOptimizer(flags.lr, flags.momentum) optimizer = tf.train.MomentumOptimizer(flags.lr, flags.momentum)
# Create file writers for writing TensorBoard summaries. # Create file writers for writing TensorBoard summaries.
......
...@@ -40,7 +40,7 @@ def random_dataset(): ...@@ -40,7 +40,7 @@ def random_dataset():
def train(defun=False): def train(defun=False):
model = mnist.Model(data_format()) model = mnist.create_model(data_format())
if defun: if defun:
model.call = tfe.defun(model.call) model.call = tfe.defun(model.call)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
...@@ -51,7 +51,7 @@ def train(defun=False): ...@@ -51,7 +51,7 @@ def train(defun=False):
def evaluate(defun=False): def evaluate(defun=False):
model = mnist.Model(data_format()) model = mnist.create_model(data_format())
dataset = random_dataset() dataset = random_dataset()
if defun: if defun:
model.call = tfe.defun(model.call) model.call = tfe.defun(model.call)
......
...@@ -86,7 +86,7 @@ def model_fn(features, labels, mode, params): ...@@ -86,7 +86,7 @@ def model_fn(features, labels, mode, params):
if isinstance(image, dict): if isinstance(image, dict):
image = features["image"] image = features["image"]
model = mnist.Model("channels_last") model = mnist.create_model("channels_last")
logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN)) logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment