ModelZoo / ResNet50_tensorflow · Commits · 0f2bf200

Using TFRecords instead of TextLineDataset

Commit 0f2bf200, authored Jul 28, 2017 by Marianne Linhares Monteiro; committed by GitHub on Jul 28, 2017. Parent: 6f6bc501.
Showing 1 changed file, with 69 additions and 35 deletions:

tutorials/image/cifar10_estimator/cifar10.py (+69, -35)
@@ -27,8 +27,6 @@ import tensorflow as tf
 HEIGHT = 32
 WIDTH = 32
 DEPTH = 3
-NUM_CLASSES = 10
-
 
 class Cifar10DataSet(object):
   """Cifar10 data set.
@@ -36,45 +34,81 @@ class Cifar10DataSet(object):
   Described by http://www.cs.toronto.edu/~kriz/cifar.html.
   """
 
-  def __init__(self, data_dir):
+  def __init__(self, data_dir, subset='train', use_distortion=True):
     self.data_dir = data_dir
-
-  def read_all_data(self, subset='train'):
-    """Reads from data file and return images and labels in a numpy array."""
-    if subset == 'train':
-      filenames = [os.path.join(self.data_dir, 'data_batch_%d' % i)
-                   for i in xrange(1, 5)]
-    elif subset == 'validation':
-      filenames = [os.path.join(self.data_dir, 'data_batch_5')]
-    elif subset == 'eval':
-      filenames = [os.path.join(self.data_dir, 'test_batch')]
-    else:
-      raise ValueError('Invalid data subset "%s"' % subset)
-
-    inputs = []
-    for filename in filenames:
-      with tf.gfile.Open(filename, 'r') as f:
-        inputs.append(cPickle.load(f))
-    all_images = np.concatenate(
-        [each_input['data'] for each_input in inputs]).astype(np.float32)
-    all_labels = np.concatenate(
-        [each_input['labels'] for each_input in inputs])
-    return all_images, all_labels
+    self.subset = subset
+    self.use_distortion = use_distortion
+
+  def get_filenames(self):
+    if self.subset == 'train':
+      return [os.path.join(self.data_dir, 'data_batch_%d.tfrecords' % i)
+              for i in xrange(1, 5)]
+    elif self.subset == 'validation':
+      return [os.path.join(self.data_dir, 'data_batch_5.tfrecords')]
+    elif self.subset == 'eval':
+      return [os.path.join(self.data_dir, 'test_batch.tfrecords')]
+    else:
+      raise ValueError('Invalid data subset "%s"' % self.subset)
+
+  def parser(self, serialized_example):
+    """Parses a single tf.Example into image and label tensors."""
+    # Dimensions of the images in the CIFAR-10 dataset.
+    # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
+    # input format.
+    features = tf.parse_single_example(
+        serialized_example,
+        features={
+            "image": tf.FixedLenFeature([], tf.string),
+            "label": tf.FixedLenFeature([], tf.int64),
+        })
+    image = tf.decode_raw(features["image"], tf.uint8)
+    image.set_shape([3 * 32 * 32])
+
+    # Reshape from [depth * height * width] to [depth, height, width].
+    image = tf.transpose(tf.reshape(image, [3, 32, 32]), [1, 2, 0])
+    label = tf.cast(features["label"], tf.int32)
-  @staticmethod
-  def preprocess(image, is_training, distortion):
-    with tf.name_scope('preprocess'):
-      # Read image layout as flattened CHW.
-      image = tf.reshape(image, [DEPTH, HEIGHT, WIDTH])
-      # Convert to NHWC layout, compatible with TF image preprocessing APIs
-      image = tf.transpose(image, [1, 2, 0])
-      if is_training and distortion:
-        # Pad 4 pixels on each dimension of feature map, done in mini-batch
-        image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
-        image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
-        image = tf.image.random_flip_left_right(image)
-      return image
+
+    # Custom preprocessing.
+    image = self.preprocess(image)
+
+    print(image, label)
+    return image, label
+
+  def make_batch(self, batch_size):
+    """Read the images and labels from 'filenames'."""
+    filenames = self.get_filenames()
-    record_bytes = (32 * 32 * 3) + 1
+
+    # Repeat infinitely.
+    dataset = tf.contrib.data.TFRecordDataset(filenames).repeat()
+
+    # Parse records.
+    dataset = dataset.map(
+        self.parser, num_threads=batch_size, output_buffer_size=2 * batch_size)
+
+    # Potentially shuffle records.
+    if self.subset == 'train':
+      min_queue_examples = int(
+          Cifar10DataSet.num_examples_per_epoch(self.subset) * 0.4)
+      # Ensure that the capacity is sufficiently large to provide good random
+      # shuffling.
+      dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)
+
+    # Batch it up.
+    dataset = dataset.batch(batch_size)
+    iterator = dataset.make_one_shot_iterator()
+    image_batch, label_batch = iterator.get_next()
+
+    print(image_batch, label_batch)
+    return image_batch, label_batch
+
+  def preprocess(self, image):
+    """Preprocess a single image in [height, width, depth] layout."""
+    if self.subset == 'train' and self.use_distortion:
+      # Pad 4 pixels on each dimension of feature map, done in mini-batch
+      image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
+      image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
+      image = tf.image.random_flip_left_right(image)
+    return image
 
   @staticmethod
   def num_examples_per_epoch(subset='train'):
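Note: the new parser() assumes the CIFAR-10 data has already been converted from the original python pickle batches into the data_batch_*.tfrecords / test_batch.tfrecords files that get_filenames() lists. That conversion script is not part of this diff. Below is a minimal sketch of what it could look like, in the same Python 2 / TF 1.x style as the file above; the helper name and data_dir path are illustrative, not from this commit. Each record carries the raw uint8 image bytes in the pickles' [depth, height, width] order under "image" and an int64 under "label", which is exactly the layout parser() decodes.

import cPickle
import os

import tensorflow as tf


def convert_batch(pickle_file, out_file):
  """Hypothetical helper: rewrite one CIFAR-10 python pickle as TFRecords."""
  with tf.gfile.Open(pickle_file, 'r') as f:
    batch = cPickle.load(f)  # 'data': uint8 array of shape [N, 3072], CHW-flattened
  writer = tf.python_io.TFRecordWriter(out_file)
  for data, label in zip(batch['data'], batch['labels']):
    example = tf.train.Example(features=tf.train.Features(feature={
        'image': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[data.tobytes()])),
        'label': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label])),
    }))
    writer.write(example.SerializeToString())
  writer.close()


data_dir = '/tmp/cifar-10-batches-py'  # assumed location of the python pickles
for i in xrange(1, 6):
  convert_batch(os.path.join(data_dir, 'data_batch_%d' % i),
                os.path.join(data_dir, 'data_batch_%d.tfrecords' % i))
convert_batch(os.path.join(data_dir, 'test_batch'),
              os.path.join(data_dir, 'test_batch.tfrecords'))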
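For reference, a minimal usage sketch of the changed class, assuming the module is importable as cifar10 and the .tfrecords files live under the data_dir shown (both assumptions, not stated in the commit). make_batch() returns tensors from a one-shot iterator, so in this TF 1.x graph-mode code a batch is materialized with a session:

import tensorflow as tf

from cifar10 import Cifar10DataSet  # the module changed in this commit

# Build one training batch from the TFRecord files, with distortion enabled.
dataset = Cifar10DataSet('/tmp/cifar-10-batches-py',  # assumed data_dir
                         subset='train', use_distortion=True)
image_batch, label_batch = dataset.make_batch(batch_size=128)

with tf.Session() as sess:
  images, labels = sess.run([image_batch, label_batch])
  print(images.shape, labels.shape)  # (128, 32, 32, 3) and (128,)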