Commit 1fef2955 authored by Asim Shankar's avatar Asim Shankar
Browse files

[mnist]: Address comments from PR

https://github.com/tensorflow/models/pull/3024
parent 1a83863d
...@@ -42,7 +42,7 @@ The SavedModel will be saved in a timestamped directory under `/tmp/mnist_saved_ ...@@ -42,7 +42,7 @@ The SavedModel will be saved in a timestamped directory under `/tmp/mnist_saved_
Use [`saved_model_cli`](https://www.tensorflow.org/programmers_guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel. Use [`saved_model_cli`](https://www.tensorflow.org/programmers_guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel.
``` ```
saved_model_cli run --dir /tmp/mnist_saved_model/TIMESTAMP --tag_set serve --signature_def classify --inputs image_raw=examples.npy saved_model_cli run --dir /tmp/mnist_saved_model/TIMESTAMP --tag_set serve --signature_def classify --inputs image=examples.npy
``` ```
`examples.npy` contains the data from `example5.png` and `example3.png` in a numpy array, in that order. The array values are normalized to values between 0 and 1. `examples.npy` contains the data from `example5.png` and `example3.png` in a numpy array, in that order. The array values are normalized to values between 0 and 1.
......
...@@ -126,10 +126,14 @@ class Model(object): ...@@ -126,10 +126,14 @@ class Model(object):
return self.fc2(y) return self.fc2(y)
def predict_spec(model, image): def model_fn(features, labels, mode, params):
"""EstimatorSpec for predictions.""" """The model_fn argument for creating an Estimator."""
model = Model(params['data_format'])
image = features
if isinstance(image, dict): if isinstance(image, dict):
image = image['image'] image = features['image']
if mode == tf.estimator.ModeKeys.PREDICT:
logits = model(image, training=False) logits = model(image, training=False)
predictions = { predictions = {
'classes': tf.argmax(logits, axis=1), 'classes': tf.argmax(logits, axis=1),
...@@ -141,10 +145,7 @@ def predict_spec(model, image): ...@@ -141,10 +145,7 @@ def predict_spec(model, image):
export_outputs={ export_outputs={
'classify': tf.estimator.export.PredictOutput(predictions) 'classify': tf.estimator.export.PredictOutput(predictions)
}) })
if mode == tf.estimator.ModeKeys.TRAIN:
def train_spec(model, image, labels):
"""EstimatorSpec for training."""
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4) optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
logits = model(image, training=True) logits = model(image, training=True)
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits) loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
...@@ -158,10 +159,7 @@ def train_spec(model, image, labels): ...@@ -158,10 +159,7 @@ def train_spec(model, image, labels):
mode=tf.estimator.ModeKeys.TRAIN, mode=tf.estimator.ModeKeys.TRAIN,
loss=loss, loss=loss,
train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step())) train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
if mode == tf.estimator.ModeKeys.EVAL:
def eval_spec(model, image, labels):
"""EstimatorSpec for evaluation."""
logits = model(image, training=False) logits = model(image, training=False)
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits) loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
return tf.estimator.EstimatorSpec( return tf.estimator.EstimatorSpec(
...@@ -175,17 +173,6 @@ def eval_spec(model, image, labels): ...@@ -175,17 +173,6 @@ def eval_spec(model, image, labels):
}) })
def model_fn(features, labels, mode, params):
"""The model_fn argument for creating an Estimator."""
model = Model(params['data_format'])
if mode == tf.estimator.ModeKeys.PREDICT:
return predict_spec(model, features)
if mode == tf.estimator.ModeKeys.TRAIN:
return train_spec(model, features, labels)
if mode == tf.estimator.ModeKeys.EVAL:
return eval_spec(model, features, labels)
def main(unused_argv): def main(unused_argv):
data_format = FLAGS.data_format data_format = FLAGS.data_format
if data_format is None: if data_format is None:
...@@ -227,7 +214,7 @@ def main(unused_argv): ...@@ -227,7 +214,7 @@ def main(unused_argv):
if FLAGS.export_dir is not None: if FLAGS.export_dir is not None:
image = tf.placeholder(tf.float32, [None, 28, 28]) image = tf.placeholder(tf.float32, [None, 28, 28])
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({ input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
'image': tf.placeholder(tf.float32, [None, 28, 28]) 'image': image,
}) })
mnist_classifier.export_savedmodel(FLAGS.export_dir, input_fn) mnist_classifier.export_savedmodel(FLAGS.export_dir, input_fn)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment