Commit 5a1faffd authored by Changming Sun's avatar Changming Sun Committed by k-w-w
Browse files

official/mnist: support savedmodel (#2967)

With examples, and updates to the README
parent f40184cc
......@@ -20,3 +20,36 @@ python mnist.py
The model will begin training and will automatically evaluate itself on the
validation data.
## Exporting the model
You can export the model into Tensorflow [SavedModel](https://www.tensorflow.org/programmers_guide/saved_model) format by using the argument `--export_dir`:
```
python mnist.py --export_dir /tmp/mnist_saved_model
```
The SavedModel will be saved in a timestamped directory under `/tmp/mnist_saved_model/` (e.g. `/tmp/mnist_saved_model/1513630966/`).
**Getting predictions with SavedModel**
Use [`saved_model_cli`](https://www.tensorflow.org/programmers_guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel.
```
saved_model_cli run --dir /tmp/mnist_saved_model/TIMESTAMP --tag_set serve --signature_def classify --inputs image_raw=examples.npy
```
`examples.npy` contains the data from `example5.png` and `example3.png` in a numpy array, in that order. The array values are normalized to values between 0 and 1.
The output should look similar to below:
```
Result for output key classes:
[5 3]
Result for output key probabilities:
[[ 1.53558474e-07 1.95694142e-13 1.31193523e-09 5.47467265e-03
5.85711526e-22 9.94520664e-01 3.48423509e-06 2.65365645e-17
9.78631419e-07 3.15522470e-08]
[ 1.22413359e-04 5.87615965e-08 1.72251271e-06 9.39960718e-01
3.30306928e-11 2.87386645e-02 2.82353517e-02 8.21146413e-18
2.52568233e-03 4.15460236e-04]]
```
......@@ -58,6 +58,10 @@ parser.add_argument(
'with CPU. If left unspecified, the data format will be chosen '
'automatically based on whether TensorFlow was built for CPU or GPU.')
parser.add_argument(
'--export_dir',
type=str,
help='The directory where the exported SavedModel will be stored.')
def train_dataset(data_dir):
"""Returns a tf.data.Dataset yielding (image, label) pairs for training."""
......@@ -152,6 +156,9 @@ def mnist_model(inputs, mode, data_format):
def mnist_model_fn(features, labels, mode, params):
"""Model function for MNIST."""
if mode == tf.estimator.ModeKeys.PREDICT and isinstance(features,dict):
features = features['image_raw']
logits = mnist_model(features, mode, params['data_format'])
predictions = {
......@@ -160,7 +167,9 @@ def mnist_model_fn(features, labels, mode, params):
}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
export_outputs={'classify': tf.estimator.export.PredictOutput(predictions)}
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions,
export_outputs=export_outputs)
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
......@@ -222,6 +231,13 @@ def main(unused_argv):
print()
print('Evaluation results:\n\t%s' % eval_results)
# Export the model
if FLAGS.export_dir is not None:
image = tf.placeholder(tf.float32,[None, 28, 28])
serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
{"image_raw":image})
mnist_classifier.export_savedmodel(FLAGS.export_dir, serving_input_fn)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment