Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4d79fee3
"docs/vscode:/vscode.git/clone" did not exist on "b2323aa2b76ffa90a71507f09c18792d1dba2523"
Commit
4d79fee3
authored
Dec 20, 2017
by
Asim Shankar
Browse files
[mnist]: Address Neal's comment
parent
49997c1f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
0 deletions
+35
-0
official/mnist/mnist_test.py
official/mnist/mnist_test.py
+35
-0
No files found.
official/mnist/mnist_test.py
View file @
4d79fee3
...
...
@@ -62,6 +62,41 @@ class Tests(tf.test.TestCase):
self
.
assertEqual
(
predictions
[
'probabilities'
].
shape
,
(
10
,))
self
.
assertEqual
(
predictions
[
'classes'
].
shape
,
())
def
mnist_model_fn_helper
(
self
,
mode
):
features
,
labels
=
dummy_input_fn
()
image_count
=
features
.
shape
[
0
]
spec
=
mnist
.
model_fn
(
features
,
labels
,
mode
,
{
'data_format'
:
'channels_last'
})
if
mode
==
tf
.
estimator
.
ModeKeys
.
PREDICT
:
predictions
=
spec
.
predictions
self
.
assertAllEqual
(
predictions
[
'probabilities'
].
shape
,
(
image_count
,
10
))
self
.
assertEqual
(
predictions
[
'probabilities'
].
dtype
,
tf
.
float32
)
self
.
assertAllEqual
(
predictions
[
'classes'
].
shape
,
(
image_count
,))
self
.
assertEqual
(
predictions
[
'classes'
].
dtype
,
tf
.
int64
)
if
mode
!=
tf
.
estimator
.
ModeKeys
.
PREDICT
:
loss
=
spec
.
loss
self
.
assertAllEqual
(
loss
.
shape
,
())
self
.
assertEqual
(
loss
.
dtype
,
tf
.
float32
)
if
mode
==
tf
.
estimator
.
ModeKeys
.
EVAL
:
eval_metric_ops
=
spec
.
eval_metric_ops
self
.
assertAllEqual
(
eval_metric_ops
[
'accuracy'
][
0
].
shape
,
())
self
.
assertAllEqual
(
eval_metric_ops
[
'accuracy'
][
1
].
shape
,
())
self
.
assertEqual
(
eval_metric_ops
[
'accuracy'
][
0
].
dtype
,
tf
.
float32
)
self
.
assertEqual
(
eval_metric_ops
[
'accuracy'
][
1
].
dtype
,
tf
.
float32
)
def
test_mnist_model_fn_train_mode
(
self
):
self
.
mnist_model_fn_helper
(
tf
.
estimator
.
ModeKeys
.
TRAIN
)
def
test_mnist_model_fn_eval_mode
(
self
):
self
.
mnist_model_fn_helper
(
tf
.
estimator
.
ModeKeys
.
EVAL
)
def
test_mnist_model_fn_predict_mode
(
self
):
self
.
mnist_model_fn_helper
(
tf
.
estimator
.
ModeKeys
.
PREDICT
)
class
Benchmarks
(
tf
.
test
.
Benchmark
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment