ModelZoo / ResNet50_tensorflow / Commits / f88def23
Unverified commit f88def23, authored Nov 07, 2017 by k-w-w, committed by GitHub on Nov 07, 2017

Merge pull request #2690 from tensorflow/tf-data

Changing tf.contrib.data to tf.data for release of tf 1.4

Parents: 4cfa0d3b, ae5adb59
Showing 5 changed files with 49 additions and 47 deletions (+49 −47)
official/mnist/mnist.py           +4  −4
official/resnet/cifar10_main.py   +35 −34
official/resnet/cifar10_test.py   +1  −1
official/resnet/imagenet_main.py  +6  −6
official/wide_deep/wide_deep.py   +3  −2
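All five files follow the same two-step migration that the diffs below show: dataset constructors move from tf.contrib.data to tf.data, and the 1.3-era map(..., num_threads=N, output_buffer_size=M) arguments are replaced by map(..., num_parallel_calls=N) plus an explicit prefetch(). A minimal before/after sketch of that pattern (parse_fn, the filenames, and the buffer sizes here are illustrative placeholders, not taken from these files):

    import tensorflow as tf  # assumes TensorFlow 1.4

    # Before (TF <= 1.3, tf.contrib.data):
    #   dataset = tf.contrib.data.TFRecordDataset(filenames)
    #   dataset = dataset.map(parse_fn, num_threads=4, output_buffer_size=batch_size)

    # After (TF 1.4, tf.data):
    def make_dataset(filenames, parse_fn, batch_size):
      dataset = tf.data.TFRecordDataset(filenames)            # constructor moves out of contrib
      dataset = dataset.map(parse_fn, num_parallel_calls=4)   # num_threads -> num_parallel_calls
      dataset = dataset.prefetch(batch_size)                  # output_buffer_size -> explicit prefetch
      return dataset.batch(batch_size)

The idea behind the split is that prefetch() buffers elements independently of how the map is parallelized, which is why the buffer argument disappears from map() in every file below.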
official/mnist/mnist.py  (view file @ f88def23)

@@ -53,7 +53,7 @@ _NUM_IMAGES = {

 def input_fn(is_training, filename, batch_size=1, num_epochs=1):
-  """A simple input_fn using the contrib.data input pipeline."""
+  """A simple input_fn using the tf.data input pipeline."""

   def example_parser(serialized_example):
     """Parses a single tf.Example into image and label tensors."""

@@ -71,8 +71,9 @@ def input_fn(is_training, filename, batch_size=1, num_epochs=1):
     label = tf.cast(features['label'], tf.int32)
     return image, tf.one_hot(label, 10)

-  dataset = tf.contrib.data.TFRecordDataset([filename])
+  dataset = tf.data.TFRecordDataset([filename])

   # Apply dataset transformations
   if is_training:
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes have better performance. Because MNIST is

@@ -84,8 +85,7 @@ def input_fn(is_training, filename, batch_size=1, num_epochs=1):
   dataset = dataset.repeat(num_epochs)

   # Map example_parser over dataset, and batch results by up to batch_size
-  dataset = dataset.map(
-      example_parser, num_threads=1, output_buffer_size=batch_size)
+  dataset = dataset.map(example_parser).prefetch(batch_size)
   dataset = dataset.batch(batch_size)
   iterator = dataset.make_one_shot_iterator()
   images, labels = iterator.get_next()
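For reference, a minimal, self-contained sketch of consuming a TFRecord pipeline in the updated style shown above. The parser, feature keys, and the 'train.tfrecords' filename are illustrative assumptions, not the repository's code:

    import tensorflow as tf  # assumes TensorFlow 1.4

    def example_parser(serialized_example):
      # Illustrative parser: decode one tf.Example into an (image, one-hot label) pair.
      features = tf.parse_single_example(
          serialized_example,
          features={'image_raw': tf.FixedLenFeature([], tf.string),
                    'label': tf.FixedLenFeature([], tf.int64)})
      image = tf.decode_raw(features['image_raw'], tf.uint8)
      image = tf.cast(tf.reshape(image, [784]), tf.float32)
      label = tf.cast(features['label'], tf.int32)
      return image, tf.one_hot(label, 10)

    # Same pipeline shape as the updated input_fn: map -> prefetch -> batch.
    dataset = tf.data.TFRecordDataset(['train.tfrecords'])  # hypothetical file
    dataset = dataset.map(example_parser).prefetch(32)
    dataset = dataset.batch(32)
    images, labels = dataset.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
      batch_images, batch_labels = sess.run([images, labels])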
official/resnet/cifar10_main.py  (view file @ f88def23)

@@ -71,13 +71,11 @@ _NUM_IMAGES = {
     'validation': 10000,
 }

-_SHUFFLE_BUFFER = 20000
-

 def record_dataset(filenames):
   """Returns an input pipeline Dataset from `filenames`."""
   record_bytes = _HEIGHT * _WIDTH * _DEPTH + 1
-  return tf.contrib.data.FixedLengthRecordDataset(filenames, record_bytes)
+  return tf.data.FixedLengthRecordDataset(filenames, record_bytes)


 def get_filenames(is_training, data_dir):

@@ -97,74 +95,77 @@ def get_filenames(is_training, data_dir):
     return [os.path.join(data_dir, 'test_batch.bin')]


-def dataset_parser(value):
-  """Parse a CIFAR-10 record from value."""
+def parse_record(raw_record):
+  """Parse CIFAR-10 image and label from a raw record."""
   # Every record consists of a label followed by the image, with a fixed number
   # of bytes for each.
   label_bytes = 1
   image_bytes = _HEIGHT * _WIDTH * _DEPTH
   record_bytes = label_bytes + image_bytes

-  # Convert from a string to a vector of uint8 that is record_bytes long.
-  raw_record = tf.decode_raw(value, tf.uint8)
+  # Convert bytes to a vector of uint8 that is record_bytes long.
+  record_vector = tf.decode_raw(raw_record, tf.uint8)

-  # The first byte represents the label, which we convert from uint8 to int32.
-  label = tf.cast(raw_record[0], tf.int32)
+  # The first byte represents the label, which we convert from uint8 to int32
+  # and then to one-hot.
+  label = tf.cast(record_vector[0], tf.int32)
+  label = tf.one_hot(label, _NUM_CLASSES)

   # The remaining bytes after the label represent the image, which we reshape
   # from [depth * height * width] to [depth, height, width].
-  depth_major = tf.reshape(raw_record[label_bytes:record_bytes],
-                           [_DEPTH, _HEIGHT, _WIDTH])
+  depth_major = tf.reshape(record_vector[label_bytes:record_bytes],
+                           [_DEPTH, _HEIGHT, _WIDTH])

   # Convert from [depth, height, width] to [height, width, depth], and cast as
   # float32.
   image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)

-  return image, tf.one_hot(label, _NUM_CLASSES)
+  return image, label


-def train_preprocess_fn(image, label):
-  """Preprocess a single training image of layout [height, width, depth]."""
-  # Resize the image to add four extra pixels on each side.
-  image = tf.image.resize_image_with_crop_or_pad(image, _HEIGHT + 8, _WIDTH + 8)
+def preprocess_image(image, is_training):
+  """Preprocess a single image of layout [height, width, depth]."""
+  if is_training:
+    # Resize the image to add four extra pixels on each side.
+    image = tf.image.resize_image_with_crop_or_pad(
+        image, _HEIGHT + 8, _WIDTH + 8)

-  # Randomly crop a [_HEIGHT, _WIDTH] section of the image.
-  image = tf.random_crop(image, [_HEIGHT, _WIDTH, _DEPTH])
+    # Randomly crop a [_HEIGHT, _WIDTH] section of the image.
+    image = tf.random_crop(image, [_HEIGHT, _WIDTH, _DEPTH])

-  # Randomly flip the image horizontally.
-  image = tf.image.random_flip_left_right(image)
+    # Randomly flip the image horizontally.
+    image = tf.image.random_flip_left_right(image)

-  return image, label
+  # Subtract off the mean and divide by the variance of the pixels.
+  image = tf.image.per_image_standardization(image)
+  return image


 def input_fn(is_training, data_dir, batch_size, num_epochs=1):
-  """Input_fn using the contrib.data input pipeline for CIFAR-10 dataset.
+  """Input_fn using the tf.data input pipeline for CIFAR-10 dataset.

   Args:
     is_training: A boolean denoting whether the input is for training.
     data_dir: The directory containing the input data.
     batch_size: The number of samples per batch.
     num_epochs: The number of epochs to repeat the dataset.

   Returns:
     A tuple of images and labels.
   """
   dataset = record_dataset(get_filenames(is_training, data_dir))
-  dataset = dataset.map(dataset_parser, num_threads=1,
-                        output_buffer_size=2 * batch_size)

-  # For training, preprocess the image and shuffle.
   if is_training:
-    dataset = dataset.map(train_preprocess_fn, num_threads=1,
-                          output_buffer_size=2 * batch_size)
-
     # When choosing shuffle buffer sizes, larger sizes result in better
-    # randomness, while smaller sizes have better performance.
-    dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
+    # randomness, while smaller sizes have better performance. Because CIFAR-10
+    # is a relatively small dataset, we choose to shuffle the full epoch.
+    dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])

-  # Subtract off the mean and divide by the variance of the pixels.
+  dataset = dataset.map(parse_record)
   dataset = dataset.map(
-      lambda image, label: (tf.image.per_image_standardization(image), label),
-      num_threads=1,
-      output_buffer_size=2 * batch_size)
+      lambda image, label: (preprocess_image(image, is_training), label))

+  dataset = dataset.prefetch(2 * batch_size)
+
   # We call repeat after shuffling, rather than before, to prevent separate
   # epochs from blending together.
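Beyond the contrib rename, this file also reorders the pipeline: raw records are shuffled first, then decoded by parse_record, then preprocessed per image, then prefetched. A condensed sketch of the resulting flow, assuming the functions above are importable as official.resnet.cifar10_main (an assumed import path) and that the repeat/batch/iterator tail, which falls outside the hunk, follows the same pattern as the other files:

    import tensorflow as tf  # assumes TensorFlow 1.4
    from official.resnet import cifar10_main  # assumed import path within the models repo

    def cifar10_pipeline(is_training, data_dir, batch_size, num_epochs=1):
      # Same ordering as the new input_fn: shuffle raw records, then parse,
      # then preprocess, then prefetch/repeat/batch.
      dataset = cifar10_main.record_dataset(
          cifar10_main.get_filenames(is_training, data_dir))
      if is_training:
        dataset = dataset.shuffle(buffer_size=50000)  # CIFAR-10 has 50000 training images
      dataset = dataset.map(cifar10_main.parse_record)
      dataset = dataset.map(
          lambda image, label: (cifar10_main.preprocess_image(image, is_training), label))
      dataset = dataset.prefetch(2 * batch_size)       # replaces output_buffer_size
      dataset = dataset.repeat(num_epochs)
      dataset = dataset.batch(batch_size)
      return dataset.make_one_shot_iterator().get_next()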
official/resnet/cifar10_test.py  (view file @ f88def23)

@@ -44,7 +44,7 @@ class BaseTest(tf.test.TestCase):
     data_file.close()

     fake_dataset = cifar10_main.record_dataset(filename)
-    fake_dataset = fake_dataset.map(cifar10_main.dataset_parser)
+    fake_dataset = fake_dataset.map(cifar10_main.parse_record)
     image, label = fake_dataset.make_one_shot_iterator().get_next()

     self.assertEqual(label.get_shape().as_list(), [10])
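A minimal sketch of the same kind of check, mapping the renamed parser over a fake fixed-length record dataset and asserting the one-hot label shape. The fixture below is a stand-in, not the repository's test data, and the import path is an assumption:

    import os
    import numpy as np
    import tensorflow as tf  # assumes TensorFlow 1.4
    from official.resnet import cifar10_main  # assumed import path

    class FakeRecordTest(tf.test.TestCase):

      def test_label_is_one_hot(self):
        # One fake CIFAR-10-style record: 1 label byte + 3*32*32 image bytes.
        record = np.zeros(1 + 3 * 32 * 32, dtype=np.uint8).tobytes()
        filename = os.path.join(self.get_temp_dir(), 'fake.bin')
        with open(filename, 'wb') as f:
          f.write(record)

        dataset = tf.data.FixedLengthRecordDataset([filename], 1 + 3 * 32 * 32)
        dataset = dataset.map(cifar10_main.parse_record)
        image, label = dataset.make_one_shot_iterator().get_next()

        # Static shape checks, mirroring the assertion in the diff above.
        self.assertEqual(label.get_shape().as_list(), [10])
        self.assertEqual(image.get_shape().as_list(), [32, 32, 3])

    if __name__ == '__main__':
      tf.test.main()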
official/resnet/imagenet_main.py  (view file @ f88def23)

@@ -134,23 +134,23 @@ def dataset_parser(value, is_training):

 def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   """Input function which provides batches for train or eval."""
-  dataset = tf.contrib.data.Dataset.from_tensor_slices(
+  dataset = tf.data.Dataset.from_tensor_slices(
       filenames(is_training, data_dir))

   if is_training:
     dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)

-  dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)
-  dataset = dataset.map(lambda value: dataset_parser(value, is_training),
-                        num_threads=5,
-                        output_buffer_size=batch_size)
+  dataset = dataset.flat_map(tf.data.TFRecordDataset)

   if is_training:
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes have better performance.
     dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)

+  dataset = dataset.map(lambda value: dataset_parser(value, is_training),
+                        num_parallel_calls=5)
+  dataset = dataset.prefetch(batch_size)
+
   # We call repeat after shuffling, rather than before, to prevent separate
   # epochs from blending together.
   dataset = dataset.repeat(num_epochs)
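Here the hunk also moves record-level parsing after the record shuffle and splits buffering into an explicit prefetch. A condensed, self-contained sketch of the resulting flow; the shuffle-buffer constants and the final batch step are stand-ins, since the real values are defined elsewhere in imagenet_main.py:

    import tensorflow as tf  # assumes TensorFlow 1.4

    _FILE_SHUFFLE_BUFFER = 1024  # stand-in values; the real constants live
    _SHUFFLE_BUFFER = 1500       # earlier in imagenet_main.py

    def sketch_input_fn(file_list, parser, is_training, batch_size, num_epochs=1):
      dataset = tf.data.Dataset.from_tensor_slices(file_list)
      if is_training:
        dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)  # shuffle file order
      dataset = dataset.flat_map(tf.data.TFRecordDataset)            # files -> records
      if is_training:
        dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)       # shuffle records
      dataset = dataset.map(lambda value: parser(value, is_training),
                            num_parallel_calls=5)                    # parse in parallel
      dataset = dataset.prefetch(batch_size)
      dataset = dataset.repeat(num_epochs)
      return dataset.batch(batch_size)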
official/wide_deep/wide_deep.py  (view file @ f88def23)

@@ -178,12 +178,13 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
     return features, tf.equal(labels, '>50K')

   # Extract lines from input files using the Dataset API.
-  dataset = tf.contrib.data.TextLineDataset(data_file)
-  dataset = dataset.map(parse_csv, num_threads=5)
+  dataset = tf.data.TextLineDataset(data_file)

   if shuffle:
     dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)

+  dataset = dataset.map(parse_csv, num_parallel_calls=5)
+
   # We call repeat after shuffling, rather than before, to prevent separate
   # epochs from blending together.
   dataset = dataset.repeat(num_epochs)
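The same shuffle-before-parse reordering appears here for the census CSV data. A minimal, self-contained sketch of a TextLineDataset pipeline in the new style; the column schema and 'data.csv' filename below are illustrative, not the wide_deep ones:

    import tensorflow as tf  # assumes TensorFlow 1.4

    _CSV_COLUMN_DEFAULTS = [[0.0], [0.0], ['']]  # illustrative 3-column schema

    def parse_csv(line):
      # Decode one CSV line into (features, label); the last column is the label.
      x1, x2, label = tf.decode_csv(line, record_defaults=_CSV_COLUMN_DEFAULTS)
      features = {'x1': x1, 'x2': x2}
      return features, tf.equal(label, '>50K')

    dataset = tf.data.TextLineDataset('data.csv')           # hypothetical file
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.map(parse_csv, num_parallel_calls=5)  # replaces num_threads
    dataset = dataset.repeat(1)
    dataset = dataset.batch(32)
    features, labels = dataset.make_one_shot_iterator().get_next()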