Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
223e0f3a
Unverified
Commit
223e0f3a
authored
Oct 01, 2018
by
Asim Shankar
Committed by
GitHub
Oct 01, 2018
Browse files
Merge pull request #5379 from aman2930/master
Enabling prediction in mnist_tpu.
parents
1b6ca655
ea24314d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
2 deletions
+31
-2
official/mnist/mnist_tpu.py
official/mnist/mnist_tpu.py
+31
-2
No files found.
official/mnist/mnist_tpu.py
View file @
223e0f3a
...
...
@@ -65,6 +65,7 @@ tf.flags.DEFINE_integer("eval_steps", 0,
tf
.
flags
.
DEFINE_float
(
"learning_rate"
,
0.05
,
"Learning rate."
)
tf
.
flags
.
DEFINE_bool
(
"use_tpu"
,
True
,
"Use TPUs rather than plain CPUs"
)
tf
.
flags
.
DEFINE_bool
(
"enable_predict"
,
True
,
"Do some predictions at the end"
)
tf
.
flags
.
DEFINE_integer
(
"iterations"
,
50
,
"Number of iterations per TPU training loop."
)
tf
.
flags
.
DEFINE_integer
(
"num_shards"
,
8
,
"Number of shards (TPU chips)."
)
...
...
@@ -82,13 +83,20 @@ def model_fn(features, labels, mode, params):
"""model_fn constructs the ML model used to predict handwritten digits."""
del
params
if
mode
==
tf
.
estimator
.
ModeKeys
.
PREDICT
:
raise
RuntimeError
(
"mode {} is not supported yet"
.
format
(
mode
))
image
=
features
if
isinstance
(
image
,
dict
):
image
=
features
[
"image"
]
model
=
mnist
.
create_model
(
"channels_last"
)
if
mode
==
tf
.
estimator
.
ModeKeys
.
PREDICT
:
logits
=
model
(
image
,
training
=
False
)
predictions
=
{
'class_ids'
:
tf
.
argmax
(
logits
,
axis
=
1
),
'probabilities'
:
tf
.
nn
.
softmax
(
logits
),
}
return
tf
.
contrib
.
tpu
.
TPUEstimatorSpec
(
mode
,
predictions
=
predictions
)
logits
=
model
(
image
,
training
=
(
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
))
loss
=
tf
.
losses
.
sparse_softmax_cross_entropy
(
labels
=
labels
,
logits
=
logits
)
...
...
@@ -134,6 +142,14 @@ def eval_input_fn(params):
return
images
,
labels
def
predict_input_fn
(
params
):
batch_size
=
params
[
"batch_size"
]
data_dir
=
params
[
"data_dir"
]
# Take out top 10 samples from test data to make the predictions.
ds
=
dataset
.
test
(
data_dir
).
take
(
10
).
batch
(
batch_size
)
return
ds
def
main
(
argv
):
del
argv
# Unused.
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
...
...
@@ -157,6 +173,7 @@ def main(argv):
use_tpu
=
FLAGS
.
use_tpu
,
train_batch_size
=
FLAGS
.
batch_size
,
eval_batch_size
=
FLAGS
.
batch_size
,
predict_batch_size
=
FLAGS
.
batch_size
,
params
=
{
"data_dir"
:
FLAGS
.
data_dir
},
config
=
run_config
)
# TPUEstimator.train *requires* a max_steps argument.
...
...
@@ -168,6 +185,18 @@ def main(argv):
if
FLAGS
.
eval_steps
:
estimator
.
evaluate
(
input_fn
=
eval_input_fn
,
steps
=
FLAGS
.
eval_steps
)
# Run prediction on top few samples of test data.
if
FLAGS
.
enable_predict
:
predictions
=
estimator
.
predict
(
input_fn
=
predict_input_fn
)
for
pred_dict
in
predictions
:
template
=
(
'Prediction is "{}" ({:.1f}%).'
)
class_id
=
pred_dict
[
'class_ids'
]
probability
=
pred_dict
[
'probabilities'
][
class_id
]
print
(
template
.
format
(
class_id
,
100
*
probability
))
if
__name__
==
"__main__"
:
tf
.
app
.
run
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment