Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9691ef7a
Commit
9691ef7a
authored
Oct 16, 2019
by
minoring
Browse files
Add compat.v1 to support TF 2.0 in mnist
parent
06412123
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
18 deletions
+20
-18
official/mnist/dataset.py
official/mnist/dataset.py
+8
-8
official/mnist/mnist.py
official/mnist/mnist.py
+9
-7
official/mnist/mnist_test.py
official/mnist/mnist_test.py
+3
-3
No files found.
official/mnist/dataset.py
View file @
9691ef7a
...
...
@@ -35,7 +35,7 @@ def read32(bytestream):
def
check_image_file_header
(
filename
):
"""Validate that filename corresponds to images for the MNIST dataset."""
with
tf
.
gfile
.
Open
(
filename
,
'rb'
)
as
f
:
with
tf
.
io
.
gfile
.
GFile
(
filename
,
'rb'
)
as
f
:
magic
=
read32
(
f
)
read32
(
f
)
# num_images, unused
rows
=
read32
(
f
)
...
...
@@ -51,7 +51,7 @@ def check_image_file_header(filename):
def
check_labels_file_header
(
filename
):
"""Validate that filename corresponds to labels for the MNIST dataset."""
with
tf
.
gfile
.
Open
(
filename
,
'rb'
)
as
f
:
with
tf
.
io
.
gfile
.
GFile
(
filename
,
'rb'
)
as
f
:
magic
=
read32
(
f
)
read32
(
f
)
# num_items, unused
if
magic
!=
2049
:
...
...
@@ -62,17 +62,17 @@ def check_labels_file_header(filename):
def
download
(
directory
,
filename
):
"""Download (and unzip) a file from the MNIST dataset if not already done."""
filepath
=
os
.
path
.
join
(
directory
,
filename
)
if
tf
.
gfile
.
E
xists
(
filepath
):
if
tf
.
io
.
gfile
.
e
xists
(
filepath
):
return
filepath
if
not
tf
.
gfile
.
E
xists
(
directory
):
tf
.
gfile
.
MakeD
ir
s
(
directory
)
if
not
tf
.
io
.
gfile
.
e
xists
(
directory
):
tf
.
io
.
gfile
.
mkd
ir
(
directory
)
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
url
=
'https://storage.googleapis.com/cvdf-datasets/mnist/'
+
filename
+
'.gz'
_
,
zipped_filepath
=
tempfile
.
mkstemp
(
suffix
=
'.gz'
)
print
(
'Downloading %s to %s'
%
(
url
,
zipped_filepath
))
urllib
.
request
.
urlretrieve
(
url
,
zipped_filepath
)
with
gzip
.
open
(
zipped_filepath
,
'rb'
)
as
f_in
,
\
tf
.
gfile
.
Open
(
filepath
,
'wb'
)
as
f_out
:
tf
.
io
.
gfile
.
GFile
(
filepath
,
'wb'
)
as
f_out
:
shutil
.
copyfileobj
(
f_in
,
f_out
)
os
.
remove
(
zipped_filepath
)
return
filepath
...
...
@@ -89,13 +89,13 @@ def dataset(directory, images_file, labels_file):
def
decode_image
(
image
):
# Normalize from [0, 255] to [0.0, 1.0]
image
=
tf
.
decode_raw
(
image
,
tf
.
uint8
)
image
=
tf
.
io
.
decode_raw
(
image
,
tf
.
uint8
)
image
=
tf
.
cast
(
image
,
tf
.
float32
)
image
=
tf
.
reshape
(
image
,
[
784
])
return
image
/
255.0
def
decode_label
(
label
):
label
=
tf
.
decode_raw
(
label
,
tf
.
uint8
)
# tf.string -> [tf.uint8]
label
=
tf
.
io
.
decode_raw
(
label
,
tf
.
uint8
)
# tf.string -> [tf.uint8]
label
=
tf
.
reshape
(
label
,
[])
# label is a scalar
return
tf
.
cast
(
label
,
tf
.
int32
)
...
...
official/mnist/mnist.py
View file @
9691ef7a
...
...
@@ -125,11 +125,12 @@ def model_fn(features, labels, mode, params):
'classify'
:
tf
.
estimator
.
export
.
PredictOutput
(
predictions
)
})
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
optimizer
=
tf
.
train
.
AdamOptimizer
(
learning_rate
=
LEARNING_RATE
)
optimizer
=
tf
.
compat
.
v1
.
train
.
AdamOptimizer
(
learning_rate
=
LEARNING_RATE
)
logits
=
model
(
image
,
training
=
True
)
loss
=
tf
.
losses
.
sparse_softmax_cross_entropy
(
labels
=
labels
,
logits
=
logits
)
accuracy
=
tf
.
metrics
.
accuracy
(
loss
=
tf
.
compat
.
v1
.
losses
.
sparse_softmax_cross_entropy
(
labels
=
labels
,
logits
=
logits
)
accuracy
=
tf
.
compat
.
v1
.
metrics
.
accuracy
(
labels
=
labels
,
predictions
=
tf
.
argmax
(
logits
,
axis
=
1
))
# Name tensors to be logged with LoggingTensorHook.
...
...
@@ -143,7 +144,8 @@ def model_fn(features, labels, mode, params):
return
tf
.
estimator
.
EstimatorSpec
(
mode
=
tf
.
estimator
.
ModeKeys
.
TRAIN
,
loss
=
loss
,
train_op
=
optimizer
.
minimize
(
loss
,
tf
.
train
.
get_or_create_global_step
()))
train_op
=
optimizer
.
minimize
(
loss
,
tf
.
compat
.
v1
.
train
.
get_or_create_global_step
()))
if
mode
==
tf
.
estimator
.
ModeKeys
.
EVAL
:
logits
=
model
(
image
,
training
=
False
)
loss
=
tf
.
losses
.
sparse_softmax_cross_entropy
(
labels
=
labels
,
logits
=
logits
)
...
...
@@ -166,7 +168,7 @@ def run_mnist(flags_obj):
model_helpers
.
apply_clean
(
flags_obj
)
model_function
=
model_fn
session_config
=
tf
.
ConfigProto
(
session_config
=
tf
.
compat
.
v1
.
ConfigProto
(
inter_op_parallelism_threads
=
flags_obj
.
inter_op_parallelism_threads
,
intra_op_parallelism_threads
=
flags_obj
.
intra_op_parallelism_threads
,
allow_soft_placement
=
True
)
...
...
@@ -227,7 +229,7 @@ def run_mnist(flags_obj):
# Export the model
if
flags_obj
.
export_dir
is
not
None
:
image
=
tf
.
placeholder
(
tf
.
float32
,
[
None
,
28
,
28
])
image
=
tf
.
compat
.
v1
.
placeholder
(
tf
.
float32
,
[
None
,
28
,
28
])
input_fn
=
tf
.
estimator
.
export
.
build_raw_serving_input_receiver_fn
({
'image'
:
image
,
})
...
...
@@ -240,6 +242,6 @@ def main(_):
if
__name__
==
'__main__'
:
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
tf
.
compat
.
v1
.
logging
.
set_verbosity
(
tf
.
compat
.
v1
.
logging
.
INFO
)
define_mnist_flags
()
absl_app
.
run
(
main
)
official/mnist/mnist_test.py
View file @
9691ef7a
...
...
@@ -29,8 +29,8 @@ BATCH_SIZE = 100
def
dummy_input_fn
():
image
=
tf
.
random
_
uniform
([
BATCH_SIZE
,
784
])
labels
=
tf
.
random
_
uniform
([
BATCH_SIZE
,
1
],
maxval
=
9
,
dtype
=
tf
.
int32
)
image
=
tf
.
random
.
uniform
([
BATCH_SIZE
,
784
])
labels
=
tf
.
random
.
uniform
([
BATCH_SIZE
,
1
],
maxval
=
9
,
dtype
=
tf
.
int32
)
return
image
,
labels
...
...
@@ -64,7 +64,7 @@ class Tests(tf.test.TestCase):
self
.
assertEqual
(
2
,
global_step
)
self
.
assertEqual
(
accuracy
.
shape
,
())
input_fn
=
lambda
:
tf
.
random
_
uniform
([
3
,
784
])
input_fn
=
lambda
:
tf
.
random
.
uniform
([
3
,
784
])
predictions_generator
=
classifier
.
predict
(
input_fn
)
for
_
in
range
(
3
):
predictions
=
next
(
predictions_generator
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment