Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
807d6bde
Commit
807d6bde
authored
Nov 06, 2017
by
Kathy Wu
Browse files
Fixed cifar 10 tests
parent
e5f88ad6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
10 deletions
+16
-10
official/resnet/cifar10_main.py
official/resnet/cifar10_main.py
+15
-9
official/resnet/cifar10_test.py
official/resnet/cifar10_test.py
+1
-1
No files found.
official/resnet/cifar10_main.py
View file @
807d6bde
...
@@ -97,8 +97,8 @@ def get_filenames(is_training, data_dir):
...
@@ -97,8 +97,8 @@ def get_filenames(is_training, data_dir):
return
[
os
.
path
.
join
(
data_dir
,
'test_batch.bin'
)]
return
[
os
.
path
.
join
(
data_dir
,
'test_batch.bin'
)]
def
parse_
and_preprocess_
record
(
raw_record
,
is_training
):
def
parse_record
(
raw_record
):
"""Parse
and preprocess a
CIFAR-10 image and label from a raw record."""
"""Parse CIFAR-10 image and label from a raw record."""
# Every record consists of a label followed by the image, with a fixed number
# Every record consists of a label followed by the image, with a fixed number
# of bytes for each.
# of bytes for each.
label_bytes
=
1
label_bytes
=
1
...
@@ -120,12 +120,6 @@ def parse_and_preprocess_record(raw_record, is_training):
...
@@ -120,12 +120,6 @@ def parse_and_preprocess_record(raw_record, is_training):
# float32.
# float32.
image
=
tf
.
cast
(
tf
.
transpose
(
depth_major
,
[
1
,
2
,
0
]),
tf
.
float32
)
image
=
tf
.
cast
(
tf
.
transpose
(
depth_major
,
[
1
,
2
,
0
]),
tf
.
float32
)
if
is_training
:
image
=
train_preprocess_fn
(
image
)
# Subtract off the mean and divide by the variance of the pixels.
image
=
tf
.
image
.
per_image_standardization
(
image
)
return
image
,
tf
.
one_hot
(
label
,
_NUM_CLASSES
)
return
image
,
tf
.
one_hot
(
label
,
_NUM_CLASSES
)
...
@@ -143,6 +137,18 @@ def train_preprocess_fn(image):
...
@@ -143,6 +137,18 @@ def train_preprocess_fn(image):
return
image
return
image
def
parse_and_preprocess
(
record
,
is_training
):
"""Parse and preprocess records in the CIFAR-10 dataset."""
image
,
label
=
parse_record
(
record
)
if
is_training
:
image
=
train_preprocess_fn
(
image
)
# Subtract off the mean and divide by the variance of the pixels.
image
=
tf
.
image
.
per_image_standardization
(
image
)
return
image
,
label
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
num_epochs
=
1
):
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
num_epochs
=
1
):
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
...
@@ -163,7 +169,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
...
@@ -163,7 +169,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
dataset
=
dataset
.
shuffle
(
buffer_size
=
_SHUFFLE_BUFFER
)
dataset
=
dataset
.
shuffle
(
buffer_size
=
_SHUFFLE_BUFFER
)
dataset
=
dataset
.
map
(
dataset
=
dataset
.
map
(
lambda
record
:
parse_and_preprocess
_record
(
record
,
is_training
))
lambda
record
:
parse_and_preprocess
(
record
,
is_training
))
dataset
=
dataset
.
prefetch
(
2
*
batch_size
)
dataset
=
dataset
.
prefetch
(
2
*
batch_size
)
# We call repeat after shuffling, rather than before, to prevent separate
# We call repeat after shuffling, rather than before, to prevent separate
...
...
official/resnet/cifar10_test.py
View file @
807d6bde
...
@@ -44,7 +44,7 @@ class BaseTest(tf.test.TestCase):
...
@@ -44,7 +44,7 @@ class BaseTest(tf.test.TestCase):
data_file
.
close
()
data_file
.
close
()
fake_dataset
=
cifar10_main
.
record_dataset
(
filename
)
fake_dataset
=
cifar10_main
.
record_dataset
(
filename
)
fake_dataset
=
fake_dataset
.
map
(
cifar10_main
.
dataset_parser
)
fake_dataset
=
fake_dataset
.
map
(
cifar10_main
.
parse_record
)
image
,
label
=
fake_dataset
.
make_one_shot_iterator
().
get_next
()
image
,
label
=
fake_dataset
.
make_one_shot_iterator
().
get_next
()
self
.
assertEqual
(
label
.
get_shape
().
as_list
(),
[
10
])
self
.
assertEqual
(
label
.
get_shape
().
as_list
(),
[
10
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment