Commit d1718519 authored by Asim Shankar's avatar Asim Shankar
Browse files

official/mnist: Linter fixes

parent 4e0ca759
......@@ -41,33 +41,43 @@ def create_model(data_format):
But uses the tf.keras API.
Args:
data_format: Either 'channels_first' or 'channels_last'.
'channels_first' is typically faster on GPUs while 'channels_last' is
typically faster on CPUs. See
data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
typically faster on GPUs while 'channels_last' is typically faster on
CPUs. See
https://www.tensorflow.org/performance/performance_guide#data_formats
Returns:
A tf.keras.Model.
"""
input_shape = None
if data_format == 'channels_first':
input_shape = [1, 28, 28]
else:
assert data_format == 'channels_last'
input_shape = [28, 28, 1]
L = tf.keras.layers
max_pool = L.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)
return tf.keras.Sequential([
L.Reshape(input_shape),
L.Conv2D(32, 5, padding='same', data_format=data_format, activation=tf.nn.relu),
max_pool,
L.Conv2D(64, 5, padding='same', data_format=data_format, activation=tf.nn.relu),
max_pool,
L.Flatten(),
L.Dense(1024, activation=tf.nn.relu),
L.Dropout(0.4),
L.Dense(10)])
l = tf.keras.layers
max_pool = l.MaxPooling2D(
(2, 2), (2, 2), padding='same', data_format=data_format)
return tf.keras.Sequential(
[
l.Reshape(input_shape),
l.Conv2D(
32,
5,
padding='same',
data_format=data_format,
activation=tf.nn.relu), max_pool,
l.Conv2D(
64,
5,
padding='same',
data_format=data_format,
activation=tf.nn.relu), max_pool,
l.Flatten(),
l.Dense(1024, activation=tf.nn.relu),
l.Dropout(0.4),
l.Dense(10)
])
def model_fn(features, labels, mode, params):
......@@ -122,8 +132,7 @@ def model_fn(features, labels, mode, params):
eval_metric_ops={
'accuracy':
tf.metrics.accuracy(
labels=labels,
predictions=tf.argmax(logits, axis=1)),
labels=labels, predictions=tf.argmax(logits, axis=1)),
})
......@@ -213,8 +222,8 @@ def main(argv):
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print('\nEvaluation results:\n\t%s\n' % eval_results)
if model_helpers.past_stop_threshold(
flags.stop_threshold, eval_results['accuracy']):
if model_helpers.past_stop_threshold(flags.stop_threshold,
eval_results['accuracy']):
break
# Export the model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment