Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0f2bf200
"tools/vscode:/vscode.git/clone" did not exist on "9a336696104834d836812037f43489b0d36f51ea"
Commit
0f2bf200
authored
Jul 28, 2017
by
Marianne Linhares Monteiro
Committed by
GitHub
Jul 28, 2017
Browse files
Using TFRecords instead of TextLineDataset
parent
6f6bc501
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
69 additions
and
35 deletions
+69
-35
tutorials/image/cifar10_estimator/cifar10.py
tutorials/image/cifar10_estimator/cifar10.py
+69
-35
No files found.
tutorials/image/cifar10_estimator/cifar10.py
View file @
0f2bf200
...
@@ -27,8 +27,6 @@ import tensorflow as tf
...
@@ -27,8 +27,6 @@ import tensorflow as tf
HEIGHT = 32
WIDTH = 32
DEPTH = 3
NUM_CLASSES = 10


class Cifar10DataSet(object):
  """Cifar10 data set.

  Described by http://www.cs.toronto.edu/~kriz/cifar.html.
  """

  def __init__(self, data_dir, subset='train', use_distortion=True):
    """Creates a dataset reader for one CIFAR-10 subset.

    Args:
      data_dir: Directory containing the `*.tfrecords` files.
      subset: One of 'train', 'validation' or 'eval'.
      use_distortion: If True, apply random crop/flip distortions to
        training images (only has an effect when subset == 'train').
    """
    self.data_dir = data_dir
    self.subset = subset
    self.use_distortion = use_distortion

  def get_filenames(self):
    """Returns the list of TFRecord file paths for the configured subset.

    Raises:
      ValueError: If `self.subset` is not a known subset name.
    """
    if self.subset == 'train':
      # Batches 1-4 are used for training; batch 5 is held out as the
      # validation set below.
      return [os.path.join(self.data_dir, 'data_batch_%d.tfrecords' % i)
              for i in range(1, 5)]
    elif self.subset == 'validation':
      return [os.path.join(self.data_dir, 'data_batch_5.tfrecords')]
    elif self.subset == 'eval':
      return [os.path.join(self.data_dir, 'test_batch.tfrecords')]
    else:
      raise ValueError('Invalid data subset "%s"' % self.subset)

  def parser(self, serialized_example):
    """Parses a single tf.Example into image and label tensors."""
    # Dimensions of the images in the CIFAR-10 dataset.
    # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of
    # the input format.
    features = tf.parse_single_example(
        serialized_example,
        features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64),
        })
    image = tf.decode_raw(features["image"], tf.uint8)
    image.set_shape([DEPTH * HEIGHT * WIDTH])

    # Reshape from [depth * height * width] to [depth, height, width], then
    # transpose to the [height, width, depth] layout expected by the
    # tf.image preprocessing ops used in preprocess().
    image = tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0])
    label = tf.cast(features["label"], tf.int32)

    # Custom preprocessing (random distortions when training).
    image = self.preprocess(image)

    return image, label

  def make_batch(self, batch_size):
    """Reads the subset's TFRecord files and returns a batch of examples.

    Args:
      batch_size: Number of examples per returned batch.

    Returns:
      A tuple `(image_batch, label_batch)` of tensors produced by a
      one-shot dataset iterator.
    """
    filenames = self.get_filenames()
    # Repeat infinitely.
    dataset = tf.contrib.data.TFRecordDataset(filenames).repeat()

    # Parse records.
    dataset = dataset.map(
        self.parser, num_threads=batch_size,
        output_buffer_size=2 * batch_size)

    # Potentially shuffle records.
    if self.subset == 'train':
      min_queue_examples = int(
          Cifar10DataSet.num_examples_per_epoch(self.subset) * 0.4)
      # Ensure that the capacity is sufficiently large to provide good
      # random shuffling.
      dataset = dataset.shuffle(
          buffer_size=min_queue_examples + 3 * batch_size)

    # Batch it up.
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    image_batch, label_batch = iterator.get_next()

    return image_batch, label_batch

  def preprocess(self, image):
    """Preprocess a single image in [height, width, depth] layout."""
    if self.subset == 'train' and self.use_distortion:
      # Pad 4 pixels on each dimension of feature map, done in mini-batch.
      image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)
      image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
      image = tf.image.random_flip_left_right(image)
    return image
@
staticmethod
@
staticmethod
def
num_examples_per_epoch
(
subset
=
'train'
):
def
num_examples_per_epoch
(
subset
=
'train'
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment