Commit 1fef2955 authored by Asim Shankar's avatar Asim Shankar
Browse files

[mnist]: Address comments from PR

https://github.com/tensorflow/models/pull/3024
parent 1a83863d
...@@ -33,7 +33,7 @@ python mnist_test.py --benchmarks=. ...@@ -33,7 +33,7 @@ python mnist_test.py --benchmarks=.
You can export the model into Tensorflow [SavedModel](https://www.tensorflow.org/programmers_guide/saved_model) format by using the argument `--export_dir`: You can export the model into Tensorflow [SavedModel](https://www.tensorflow.org/programmers_guide/saved_model) format by using the argument `--export_dir`:
``` ```
python mnist.py --export_dir /tmp/mnist_saved_model python mnist.py --export_dir /tmp/mnist_saved_model
``` ```
The SavedModel will be saved in a timestamped directory under `/tmp/mnist_saved_model/` (e.g. `/tmp/mnist_saved_model/1513630966/`). The SavedModel will be saved in a timestamped directory under `/tmp/mnist_saved_model/` (e.g. `/tmp/mnist_saved_model/1513630966/`).
...@@ -42,7 +42,7 @@ The SavedModel will be saved in a timestamped directory under `/tmp/mnist_saved_ ...@@ -42,7 +42,7 @@ The SavedModel will be saved in a timestamped directory under `/tmp/mnist_saved_
Use [`saved_model_cli`](https://www.tensorflow.org/programmers_guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel. Use [`saved_model_cli`](https://www.tensorflow.org/programmers_guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel.
``` ```
saved_model_cli run --dir /tmp/mnist_saved_model/TIMESTAMP --tag_set serve --signature_def classify --inputs image_raw=examples.npy saved_model_cli run --dir /tmp/mnist_saved_model/TIMESTAMP --tag_set serve --signature_def classify --inputs image=examples.npy
``` ```
`examples.npy` contains the data from `example5.png` and `example3.png` in a numpy array, in that order. The array values are normalized to values between 0 and 1. `examples.npy` contains the data from `example5.png` and `example3.png` in a numpy array, in that order. The array values are normalized to values between 0 and 1.
......
...@@ -126,64 +126,51 @@ class Model(object): ...@@ -126,64 +126,51 @@ class Model(object):
return self.fc2(y) return self.fc2(y)
def predict_spec(model, image):
"""EstimatorSpec for predictions."""
if isinstance(image, dict):
image = image['image']
logits = model(image, training=False)
predictions = {
'classes': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits),
}
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
'classify': tf.estimator.export.PredictOutput(predictions)
})
def train_spec(model, image, labels):
"""EstimatorSpec for training."""
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
logits = model(image, training=True)
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
accuracy = tf.metrics.accuracy(
labels=tf.argmax(labels, axis=1), predictions=tf.argmax(logits, axis=1))
# Name the accuracy tensor 'train_accuracy' to demonstrate the
# LoggingTensorHook.
tf.identity(accuracy[1], name='train_accuracy')
tf.summary.scalar('train_accuracy', accuracy[1])
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.TRAIN,
loss=loss,
train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
def eval_spec(model, image, labels):
"""EstimatorSpec for evaluation."""
logits = model(image, training=False)
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.EVAL,
loss=loss,
eval_metric_ops={
'accuracy':
tf.metrics.accuracy(
labels=tf.argmax(labels, axis=1),
predictions=tf.argmax(logits, axis=1)),
})
def model_fn(features, labels, mode, params): def model_fn(features, labels, mode, params):
"""The model_fn argument for creating an Estimator.""" """The model_fn argument for creating an Estimator."""
model = Model(params['data_format']) model = Model(params['data_format'])
image = features
if isinstance(image, dict):
image = features['image']
if mode == tf.estimator.ModeKeys.PREDICT: if mode == tf.estimator.ModeKeys.PREDICT:
return predict_spec(model, features) logits = model(image, training=False)
predictions = {
'classes': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits),
}
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
'classify': tf.estimator.export.PredictOutput(predictions)
})
if mode == tf.estimator.ModeKeys.TRAIN: if mode == tf.estimator.ModeKeys.TRAIN:
return train_spec(model, features, labels) optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
logits = model(image, training=True)
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
accuracy = tf.metrics.accuracy(
labels=tf.argmax(labels, axis=1), predictions=tf.argmax(logits, axis=1))
# Name the accuracy tensor 'train_accuracy' to demonstrate the
# LoggingTensorHook.
tf.identity(accuracy[1], name='train_accuracy')
tf.summary.scalar('train_accuracy', accuracy[1])
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.TRAIN,
loss=loss,
train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
if mode == tf.estimator.ModeKeys.EVAL: if mode == tf.estimator.ModeKeys.EVAL:
return eval_spec(model, features, labels) logits = model(image, training=False)
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.EVAL,
loss=loss,
eval_metric_ops={
'accuracy':
tf.metrics.accuracy(
labels=tf.argmax(labels, axis=1),
predictions=tf.argmax(logits, axis=1)),
})
def main(unused_argv): def main(unused_argv):
...@@ -227,7 +214,7 @@ def main(unused_argv): ...@@ -227,7 +214,7 @@ def main(unused_argv):
if FLAGS.export_dir is not None: if FLAGS.export_dir is not None:
image = tf.placeholder(tf.float32, [None, 28, 28]) image = tf.placeholder(tf.float32, [None, 28, 28])
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({ input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
'image': tf.placeholder(tf.float32, [None, 28, 28]) 'image': image,
}) })
mnist_classifier.export_savedmodel(FLAGS.export_dir, input_fn) mnist_classifier.export_savedmodel(FLAGS.export_dir, input_fn)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment