"docs/vscode:/vscode.git/clone" did not exist on "b2323aa2b76ffa90a71507f09c18792d1dba2523"
Commit 4d79fee3 authored by Asim Shankar's avatar Asim Shankar
Browse files

[mnist]: Address Neal's comment

parent 49997c1f
......@@ -62,6 +62,41 @@ class Tests(tf.test.TestCase):
self.assertEqual(predictions['probabilities'].shape, (10,))
self.assertEqual(predictions['classes'].shape, ())
def mnist_model_fn_helper(self, mode):
features, labels = dummy_input_fn()
image_count = features.shape[0]
spec = mnist.model_fn(features, labels, mode, {
'data_format': 'channels_last'
})
if mode == tf.estimator.ModeKeys.PREDICT:
predictions = spec.predictions
self.assertAllEqual(predictions['probabilities'].shape, (image_count, 10))
self.assertEqual(predictions['probabilities'].dtype, tf.float32)
self.assertAllEqual(predictions['classes'].shape, (image_count,))
self.assertEqual(predictions['classes'].dtype, tf.int64)
if mode != tf.estimator.ModeKeys.PREDICT:
loss = spec.loss
self.assertAllEqual(loss.shape, ())
self.assertEqual(loss.dtype, tf.float32)
if mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = spec.eval_metric_ops
self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
def test_mnist_model_fn_train_mode(self):
self.mnist_model_fn_helper(tf.estimator.ModeKeys.TRAIN)
def test_mnist_model_fn_eval_mode(self):
self.mnist_model_fn_helper(tf.estimator.ModeKeys.EVAL)
def test_mnist_model_fn_predict_mode(self):
self.mnist_model_fn_helper(tf.estimator.ModeKeys.PREDICT)
class Benchmarks(tf.test.Benchmark):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment