ModelZoo / ResNet50_tensorflow / Commits

Commit 6e52c271
Authored Nov 06, 2017 by Neal Wu
Parent: 807d6bde

Separate parse_and_preprocess into two different dataset.map calls, which also keeps tests passing
Showing 1 changed file with 19 additions and 24 deletions.

official/resnet/cifar10_main.py (+19, -24)
```diff
@@ -108,45 +108,38 @@ def parse_record(raw_record):
   # Convert bytes to a vector of uint8 that is record_bytes long.
   record_vector = tf.decode_raw(raw_record, tf.uint8)
 
-  # The first byte represents the label, which we convert from uint8 to int32.
+  # The first byte represents the label, which we convert from uint8 to int32
+  # and then to one-hot.
   label = tf.cast(record_vector[0], tf.int32)
+  label = tf.one_hot(label, _NUM_CLASSES)
 
   # The remaining bytes after the label represent the image, which we reshape
   # from [depth * height * width] to [depth, height, width].
-  depth_major = tf.reshape(record_vector[label_bytes:record_bytes],
-                           [_DEPTH, _HEIGHT, _WIDTH])
+  depth_major = tf.reshape(
+      record_vector[label_bytes:record_bytes], [_DEPTH, _HEIGHT, _WIDTH])
 
   # Convert from [depth, height, width] to [height, width, depth], and cast as
   # float32.
   image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
 
-  return image, tf.one_hot(label, _NUM_CLASSES)
+  return image, label
 
 
-def train_preprocess_fn(image):
-  """Preprocess a single training image of layout [height, width, depth]."""
-  # Resize the image to add four extra pixels on each side.
-  image = tf.image.resize_image_with_crop_or_pad(image, _HEIGHT + 8, _WIDTH + 8)
-
-  # Randomly crop a [_HEIGHT, _WIDTH] section of the image.
-  image = tf.random_crop(image, [_HEIGHT, _WIDTH, _DEPTH])
-
-  # Randomly flip the image horizontally.
-  image = tf.image.random_flip_left_right(image)
-
-  return image
-
-
-def parse_and_preprocess(record, is_training):
-  """Parse and preprocess records in the CIFAR-10 dataset."""
-  image, label = parse_record(record)
-  if is_training:
-    image = train_preprocess_fn(image)
+def preprocess_image(image, is_training):
+  """Preprocess a single image of layout [height, width, depth]."""
+  if is_training:
+    # Resize the image to add four extra pixels on each side.
+    image = tf.image.resize_image_with_crop_or_pad(
+        image, _HEIGHT + 8, _WIDTH + 8)
+
+    # Randomly crop a [_HEIGHT, _WIDTH] section of the image.
+    image = tf.random_crop(image, [_HEIGHT, _WIDTH, _DEPTH])
+
+    # Randomly flip the image horizontally.
+    image = tf.image.random_flip_left_right(image)
 
   # Subtract off the mean and divide by the variance of the pixels.
   image = tf.image.per_image_standardization(image)
-
-  return image, label
+  return image
 
 
 def input_fn(is_training, data_dir, batch_size, num_epochs=1):
```
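After this hunk, `parse_record` is purely deterministic parsing, and all randomized augmentation lives in `preprocess_image`. Below is a minimal, hypothetical sketch (not from this repository's test suite) of how the parsing half can now be exercised on its own; it condenses the logic above into a self-contained snippet and assumes TensorFlow 1.x plus this file's constants (_HEIGHT = 32, _WIDTH = 32, _DEPTH = 3, _NUM_CLASSES = 10 for CIFAR-10):

```python
# Hypothetical self-contained check, not part of this commit: it condenses
# the parse_record logic from the hunk above so the deterministic parsing
# path can be run and asserted on without any training-time randomness.
# Assumes TensorFlow 1.x; the constants match CIFAR-10 as used in this file.
import tensorflow as tf

_HEIGHT, _WIDTH, _DEPTH, _NUM_CLASSES = 32, 32, 3, 10


def parse_record(raw_record):
  """Condensed copy of the parsing logic from the diff above."""
  label_bytes = 1
  record_bytes = label_bytes + _HEIGHT * _WIDTH * _DEPTH
  # Convert the raw bytes to a vector of uint8 that is record_bytes long.
  record_vector = tf.decode_raw(raw_record, tf.uint8)
  # First byte is the label: uint8 -> int32 -> one-hot.
  label = tf.one_hot(tf.cast(record_vector[0], tf.int32), _NUM_CLASSES)
  # Remaining bytes are the image in [depth, height, width] order.
  depth_major = tf.reshape(
      record_vector[label_bytes:record_bytes], [_DEPTH, _HEIGHT, _WIDTH])
  image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
  return image, label


# One synthetic CIFAR-10 record: a label byte of 7 followed by
# 32 * 32 * 3 constant pixel bytes.
raw = bytes([7]) + bytes([128]) * (_HEIGHT * _WIDTH * _DEPTH)

with tf.Session() as sess:
  image, label = sess.run(parse_record(tf.constant(raw)))
  assert image.shape == (_HEIGHT, _WIDTH, _DEPTH)
  assert label.argmax() == 7  # the one-hot label round-trips
```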
```diff
@@ -168,8 +161,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   # randomness, while smaller sizes have better performance.
   dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
 
-  dataset = dataset.map(
-      lambda record: parse_and_preprocess(record, is_training))
+  dataset = dataset.map(parse_record)
+  dataset = dataset.map(
+      lambda image, label: (preprocess_image(image, is_training), label))
+
   dataset = dataset.prefetch(2 * batch_size)
 
   # We call repeat after shuffling, rather than before, to prevent separate
```
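For context, here is a sketch of the resulting pipeline shape: parsing and preprocessing run as two separate `dataset.map` stages, and only the second stage depends on `is_training`. This assumes TensorFlow 1.4+ (where `tf.data` is in core) and that `parse_record` and `preprocess_image` from the diff above are in scope; `make_dataset` is a hypothetical helper name, not from this file:

```python
# Hypothetical usage sketch, assuming TensorFlow 1.4+ and that parse_record
# and preprocess_image (as defined in the diff above) are importable.
import tensorflow as tf

_RECORD_BYTES = 1 + 32 * 32 * 3  # one label byte plus the CIFAR-10 image


def make_dataset(filenames, is_training, batch_size):
  """Two-stage input pipeline mirroring the hunk above."""
  dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
  # Stage 1: deterministic parsing of raw bytes into (image, label).
  dataset = dataset.map(parse_record)
  # Stage 2: preprocessing; only this map depends on is_training.
  dataset = dataset.map(
      lambda image, label: (preprocess_image(image, is_training), label))
  return dataset.prefetch(2 * batch_size)
```

Keeping the first map free of `is_training` means the parsing stage behaves identically in training and evaluation, which appears to be what the commit message means by keeping tests passing.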