"...bert-large_oneflow.git" did not exist on "5988d2cc317ac8cb8e21f84ec17dbd59e805df6c"
Commit d1718519 authored by Asim Shankar's avatar Asim Shankar
Browse files

official/mnist: Linter fixes

parent 4e0ca759
...@@ -41,33 +41,43 @@ def create_model(data_format): ...@@ -41,33 +41,43 @@ def create_model(data_format):
But uses the tf.keras API. But uses the tf.keras API.
Args: Args:
data_format: Either 'channels_first' or 'channels_last'. data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
'channels_first' is typically faster on GPUs while 'channels_last' is typically faster on GPUs while 'channels_last' is typically faster on
typically faster on CPUs. See CPUs. See
https://www.tensorflow.org/performance/performance_guide#data_formats https://www.tensorflow.org/performance/performance_guide#data_formats
Returns: Returns:
A tf.keras.Model. A tf.keras.Model.
""" """
input_shape = None
if data_format == 'channels_first': if data_format == 'channels_first':
input_shape = [1, 28, 28] input_shape = [1, 28, 28]
else: else:
assert data_format == 'channels_last' assert data_format == 'channels_last'
input_shape = [28, 28, 1] input_shape = [28, 28, 1]
L = tf.keras.layers l = tf.keras.layers
max_pool = L.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format) max_pool = l.MaxPooling2D(
return tf.keras.Sequential([ (2, 2), (2, 2), padding='same', data_format=data_format)
L.Reshape(input_shape), return tf.keras.Sequential(
L.Conv2D(32, 5, padding='same', data_format=data_format, activation=tf.nn.relu), [
max_pool, l.Reshape(input_shape),
L.Conv2D(64, 5, padding='same', data_format=data_format, activation=tf.nn.relu), l.Conv2D(
max_pool, 32,
L.Flatten(), 5,
L.Dense(1024, activation=tf.nn.relu), padding='same',
L.Dropout(0.4), data_format=data_format,
L.Dense(10)]) activation=tf.nn.relu), max_pool,
l.Conv2D(
64,
5,
padding='same',
data_format=data_format,
activation=tf.nn.relu), max_pool,
l.Flatten(),
l.Dense(1024, activation=tf.nn.relu),
l.Dropout(0.4),
l.Dense(10)
])
def model_fn(features, labels, mode, params): def model_fn(features, labels, mode, params):
...@@ -122,8 +132,7 @@ def model_fn(features, labels, mode, params): ...@@ -122,8 +132,7 @@ def model_fn(features, labels, mode, params):
eval_metric_ops={ eval_metric_ops={
'accuracy': 'accuracy':
tf.metrics.accuracy( tf.metrics.accuracy(
labels=labels, labels=labels, predictions=tf.argmax(logits, axis=1)),
predictions=tf.argmax(logits, axis=1)),
}) })
...@@ -213,8 +222,8 @@ def main(argv): ...@@ -213,8 +222,8 @@ def main(argv):
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn) eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print('\nEvaluation results:\n\t%s\n' % eval_results) print('\nEvaluation results:\n\t%s\n' % eval_results)
if model_helpers.past_stop_threshold( if model_helpers.past_stop_threshold(flags.stop_threshold,
flags.stop_threshold, eval_results['accuracy']): eval_results['accuracy']):
break break
# Export the model # Export the model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment