Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8e4a1e2e
Unverified
Commit
8e4a1e2e
authored
Jan 02, 2018
by
Asim Shankar
Committed by
GitHub
Jan 02, 2018
Browse files
Merge pull request #3093 from asimshankar/mnist
[mnist]: Use FixedLengthRecordDataset
parents
a3669a93
4a36e31b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
118 additions
and
17 deletions
+118
-17
official/mnist/dataset.py
official/mnist/dataset.py
+112
-0
official/mnist/mnist.py
official/mnist/mnist.py
+6
-17
No files found.
official/mnist/dataset.py
0 → 100644
View file @
8e4a1e2e
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""tf.data.Dataset interface to the MNIST dataset."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
shutil
import
gzip
import
numpy
as
np
from
six.moves
import
urllib
import
tensorflow
as
tf
def read32(bytestream):
  """Read 4 bytes from bytestream as an unsigned 32-bit integer."""
  # The MNIST IDX header fields are big-endian uint32 values.
  big_endian_u32 = np.dtype(np.uint32).newbyteorder('>')
  raw = bytestream.read(4)
  return np.frombuffer(raw, dtype=big_endian_u32)[0]
def check_image_file_header(filename):
  """Validate that filename corresponds to images for the MNIST dataset.

  Checks the IDX magic number (2051) and the 28x28 image dimensions.

  Args:
    filename: Path to a local MNIST image file.

  Raises:
    ValueError: If the magic number or image dimensions are wrong.
  """
  # The header is binary (big-endian uint32 fields), so the file must be
  # opened in binary mode: text mode hands str (not bytes) to np.frombuffer
  # and can corrupt data on platforms that translate line endings.
  with open(filename, 'rb') as f:
    magic = read32(f)
    read32(f)  # num_images, unused; read only to advance past the field.
    rows = read32(f)
    cols = read32(f)
    if magic != 2051:
      raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
                                                                     f.name))
    if rows != 28 or cols != 28:
      raise ValueError(
          'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
          (f.name, rows, cols))
def check_labels_file_header(filename):
  """Validate that filename corresponds to labels for the MNIST dataset.

  Checks the IDX magic number (2049).

  Args:
    filename: Path to a local MNIST label file.

  Raises:
    ValueError: If the magic number is wrong.
  """
  # Binary mode is required: read32/np.frombuffer expects bytes, not str.
  with open(filename, 'rb') as f:
    magic = read32(f)
    read32(f)  # num_items, unused; read only to advance the stream.
    if magic != 2049:
      raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
                                                                     f.name))
def download(directory, filename):
  """Download (and unzip) a file from the MNIST dataset if not already done.

  Args:
    directory: Local directory to store the decompressed file in.
    filename: Base name (without '.gz') of the MNIST file to fetch.

  Returns:
    The path to the decompressed local file.
  """
  if not tf.gfile.Exists(directory):
    tf.gfile.MakeDirs(directory)
  filepath = os.path.join(directory, filename)
  if tf.gfile.Exists(filepath):
    return filepath
  # CVDF mirror of http://yann.lecun.com/exdb/mnist/
  url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
  zipped_filepath = filepath + '.gz'
  print('Downloading %s to %s' % (url, zipped_filepath))
  urllib.request.urlretrieve(url, zipped_filepath)
  try:
    with gzip.open(zipped_filepath, 'rb') as f_in, \
        open(filepath, 'wb') as f_out:
      shutil.copyfileobj(f_in, f_out)
  except Exception:
    # Don't leave a truncated decompressed file behind: a later call would
    # see it, skip the download, and hand corrupt data to the pipeline.
    if os.path.exists(filepath):
      os.remove(filepath)
    raise
  finally:
    # The .gz temp file is no longer needed whether or not extraction
    # succeeded.
    os.remove(zipped_filepath)
  return filepath
def dataset(directory, images_file, labels_file):
  """Download/validate MNIST files and build a Dataset of (image, label)."""
  images_file = download(directory, images_file)
  labels_file = download(directory, labels_file)

  check_image_file_header(images_file)
  check_labels_file_header(labels_file)

  def decode_image(raw_record):
    # Normalize from [0, 255] to [0.0, 1.0]
    pixels = tf.decode_raw(raw_record, tf.uint8)
    pixels = tf.cast(pixels, tf.float32)
    pixels = tf.reshape(pixels, [784])
    return pixels / 255.0

  def one_hot_label(raw_record):
    digit = tf.decode_raw(raw_record, tf.uint8)  # tf.string -> tf.uint8
    digit = tf.reshape(digit, [])  # label is a scalar
    return tf.one_hot(digit, 10)

  # Each image record is 28*28 bytes after a 16-byte IDX header; each label
  # record is one byte after an 8-byte IDX header.
  images = tf.data.FixedLengthRecordDataset(
      images_file, 28 * 28, header_bytes=16).map(decode_image)
  labels = tf.data.FixedLengthRecordDataset(
      labels_file, 1, header_bytes=8).map(one_hot_label)
  return tf.data.Dataset.zip((images, labels))
def train(directory):
  """tf.data.Dataset object for MNIST training data."""
  images = 'train-images-idx3-ubyte'
  labels = 'train-labels-idx1-ubyte'
  return dataset(directory, images, labels)
def test(directory):
  """tf.data.Dataset object for MNIST test data."""
  images = 't10k-images-idx3-ubyte'
  labels = 't10k-labels-idx1-ubyte'
  return dataset(directory, images, labels)
official/mnist/mnist.py
View file @
8e4a1e2e
...
@@ -22,19 +22,7 @@ import os
...
@@ -22,19 +22,7 @@ import os
import
sys
import
sys
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.examples.tutorials.mnist
import
input_data
import
dataset
def
train_dataset
(
data_dir
):
"""Returns a tf.data.Dataset yielding (image, label) pairs for training."""
data
=
input_data
.
read_data_sets
(
data_dir
,
one_hot
=
True
).
train
return
tf
.
data
.
Dataset
.
from_tensor_slices
((
data
.
images
,
data
.
labels
))
def
eval_dataset
(
data_dir
):
"""Returns a tf.data.Dataset yielding (image, label) pairs for evaluation."""
data
=
input_data
.
read_data_sets
(
data_dir
,
one_hot
=
True
).
test
return
tf
.
data
.
Dataset
.
from_tensors
((
data
.
images
,
data
.
labels
))
class
Model
(
object
):
class
Model
(
object
):
...
@@ -151,10 +139,10 @@ def main(unused_argv):
...
@@ -151,10 +139,10 @@ def main(unused_argv):
# When choosing shuffle buffer sizes, larger sizes result in better
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes use less memory. MNIST is a small
# randomness, while smaller sizes use less memory. MNIST is a small
# enough dataset that we can easily shuffle the full epoch.
# enough dataset that we can easily shuffle the full epoch.
dataset
=
train
_dataset
(
FLAGS
.
data_dir
)
ds
=
dataset
.
train
(
FLAGS
.
data_dir
)
d
ataset
=
dataset
.
shuffle
(
buffer_size
=
50000
).
batch
(
FLAGS
.
batch_size
).
repeat
(
d
s
=
ds
.
cache
()
.
shuffle
(
buffer_size
=
50000
).
batch
(
FLAGS
.
batch_size
).
repeat
(
FLAGS
.
train_epochs
)
FLAGS
.
train_epochs
)
(
images
,
labels
)
=
d
ataset
.
make_one_shot_iterator
().
get_next
()
(
images
,
labels
)
=
d
s
.
make_one_shot_iterator
().
get_next
()
return
(
images
,
labels
)
return
(
images
,
labels
)
# Set up training hook that logs the training accuracy every 100 steps.
# Set up training hook that logs the training accuracy every 100 steps.
...
@@ -165,7 +153,8 @@ def main(unused_argv):
...
@@ -165,7 +153,8 @@ def main(unused_argv):
# Evaluate the model and print results
# Evaluate the model and print results
def
eval_input_fn
():
def
eval_input_fn
():
return
eval_dataset
(
FLAGS
.
data_dir
).
make_one_shot_iterator
().
get_next
()
return
dataset
.
test
(
FLAGS
.
data_dir
).
batch
(
FLAGS
.
batch_size
).
make_one_shot_iterator
().
get_next
()
eval_results
=
mnist_classifier
.
evaluate
(
input_fn
=
eval_input_fn
)
eval_results
=
mnist_classifier
.
evaluate
(
input_fn
=
eval_input_fn
)
print
()
print
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment