Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4702de29
Unverified
Commit
4702de29
authored
Oct 27, 2017
by
Neal Wu
Committed by
GitHub
Oct 27, 2017
Browse files
Use FLAGS in main functions only + Updates to shuffling (#2601)
parent
edcd29f2
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
111 additions
and
87 deletions
+111
-87
official/mnist/convert_to_records.py
official/mnist/convert_to_records.py
+9
-9
official/mnist/mnist.py
official/mnist/mnist.py
+26
-27
official/mnist/mnist_test.py
official/mnist/mnist_test.py
+2
-2
official/resnet/cifar10_main.py
official/resnet/cifar10_main.py
+25
-19
official/resnet/cifar10_test.py
official/resnet/cifar10_test.py
+12
-7
official/resnet/imagenet_main.py
official/resnet/imagenet_main.py
+26
-16
official/resnet/imagenet_test.py
official/resnet/imagenet_test.py
+11
-7
No files found.
official/mnist/convert_to_records.py
View file @
4702de29
...
...
@@ -50,11 +50,11 @@ def _bytes_feature(value):
return
tf
.
train
.
Feature
(
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
value
]))
def
convert_to
(
data
_
set
,
name
):
def
convert_to
(
dataset
,
name
,
directory
):
"""Converts a dataset to TFRecords."""
images
=
data
_
set
.
images
labels
=
data
_
set
.
labels
num_examples
=
data
_
set
.
num_examples
images
=
dataset
.
images
labels
=
dataset
.
labels
num_examples
=
dataset
.
num_examples
if
images
.
shape
[
0
]
!=
num_examples
:
raise
ValueError
(
'Images size %d does not match label size %d.'
%
...
...
@@ -63,7 +63,7 @@ def convert_to(data_set, name):
cols
=
images
.
shape
[
2
]
depth
=
images
.
shape
[
3
]
filename
=
os
.
path
.
join
(
FLAGS
.
directory
,
name
+
'.tfrecords'
)
filename
=
os
.
path
.
join
(
directory
,
name
+
'.tfrecords'
)
print
(
'Writing'
,
filename
)
writer
=
tf
.
python_io
.
TFRecordWriter
(
filename
)
for
index
in
range
(
num_examples
):
...
...
@@ -80,15 +80,15 @@ def convert_to(data_set, name):
def
main
(
unused_argv
):
# Get the data.
data
_
sets
=
mnist
.
read_data_sets
(
FLAGS
.
directory
,
datasets
=
mnist
.
read_data_sets
(
FLAGS
.
directory
,
dtype
=
tf
.
uint8
,
reshape
=
False
,
validation_size
=
FLAGS
.
validation_size
)
# Convert to Examples and write the result to TFRecords.
convert_to
(
data
_
sets
.
train
,
'train'
)
convert_to
(
data
_
sets
.
validation
,
'validation'
)
convert_to
(
data
_
sets
.
test
,
'test'
)
convert_to
(
datasets
.
train
,
'train'
,
FLAGS
.
directory
)
convert_to
(
datasets
.
validation
,
'validation'
,
FLAGS
.
directory
)
convert_to
(
datasets
.
test
,
'test'
,
FLAGS
.
directory
)
if
__name__
==
'__main__'
:
...
...
official/mnist/mnist.py
View file @
4702de29
...
...
@@ -52,7 +52,7 @@ _NUM_IMAGES = {
}
def
input_fn
(
mode
,
batch_size
=
1
):
def
input_fn
(
is_training
,
filename
,
batch_size
=
1
,
num_epochs
=
1
):
"""A simple input_fn using the contrib.data input pipeline."""
def
example_parser
(
serialized_example
):
...
...
@@ -71,21 +71,15 @@ def input_fn(mode, batch_size=1):
label
=
tf
.
cast
(
features
[
'label'
],
tf
.
int32
)
return
image
,
tf
.
one_hot
(
label
,
10
)
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
tfrecords_file
=
os
.
path
.
join
(
FLAGS
.
data_dir
,
'train.tfrecords'
)
else
:
assert
mode
==
tf
.
estimator
.
ModeKeys
.
EVAL
,
'invalid mode'
tfrecords_file
=
os
.
path
.
join
(
FLAGS
.
data_dir
,
'test.tfrecords'
)
dataset
=
tf
.
contrib
.
data
.
TFRecordDataset
([
filename
])
assert
tf
.
gfile
.
Exists
(
tfrecords_file
),
(
'Run convert_to_records.py first to convert the MNIST data to TFRecord '
'file format.'
)
if
is_training
:
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes have better performance. Because MNIST is
# a small dataset, we can easily shuffle the full epoch.
dataset
=
dataset
.
shuffle
(
buffer_size
=
_NUM_IMAGES
[
'train'
])
dataset
=
tf
.
contrib
.
data
.
TFRecordDataset
([
tfrecords_file
])
# For training, repeat the dataset forever
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
dataset
=
dataset
.
repeat
()
dataset
=
dataset
.
repeat
(
num_epochs
)
# Map example_parser over dataset, and batch results by up to batch_size
dataset
=
dataset
.
map
(
...
...
@@ -96,13 +90,12 @@ def input_fn(mode, batch_size=1):
return
images
,
labels
def
mnist_model
(
inputs
,
mode
):
def
mnist_model
(
inputs
,
mode
,
data_format
):
"""Takes the MNIST inputs and mode and outputs a tensor of logits."""
# Input Layer
# Reshape X to 4-D tensor: [batch_size, width, height, channels]
# MNIST images are 28x28 pixels, and have one color channel
inputs
=
tf
.
reshape
(
inputs
,
[
-
1
,
28
,
28
,
1
])
data_format
=
FLAGS
.
data_format
if
data_format
is
None
:
# When running on GPU, transpose the data from channels_last (NHWC) to
...
...
@@ -177,9 +170,9 @@ def mnist_model(inputs, mode):
return
logits
def
mnist_model_fn
(
features
,
labels
,
mode
):
def
mnist_model_fn
(
features
,
labels
,
mode
,
params
):
"""Model function for MNIST."""
logits
=
mnist_model
(
features
,
mode
)
logits
=
mnist_model
(
features
,
mode
,
params
[
'data_format'
]
)
predictions
=
{
'classes'
:
tf
.
argmax
(
input
=
logits
,
axis
=
1
),
...
...
@@ -215,30 +208,36 @@ def mnist_model_fn(features, labels, mode):
def
main
(
unused_argv
):
# Make sure that training and testing data have been converted.
train_file
=
os
.
path
.
join
(
FLAGS
.
data_dir
,
'train.tfrecords'
)
test_file
=
os
.
path
.
join
(
FLAGS
.
data_dir
,
'test.tfrecords'
)
assert
(
tf
.
gfile
.
Exists
(
train_file
)
and
tf
.
gfile
.
Exists
(
test_file
)),
(
'Run convert_to_records.py first to convert the MNIST data to TFRecord '
'file format.'
)
# Create the Estimator
mnist_classifier
=
tf
.
estimator
.
Estimator
(
model_fn
=
mnist_model_fn
,
model_dir
=
FLAGS
.
model_dir
)
model_fn
=
mnist_model_fn
,
model_dir
=
FLAGS
.
model_dir
,
params
=
{
'data_format'
:
FLAGS
.
data_format
})
#
Train the model
#
Set up training hook that logs the training accuracy every 100 steps.
tensors_to_log
=
{
'train_accuracy'
:
'train_accuracy'
}
logging_hook
=
tf
.
train
.
LoggingTensorHook
(
tensors
=
tensors_to_log
,
every_n_iter
=
100
)
batches_per_epoch
=
_NUM_IMAGES
[
'train'
]
/
FLAGS
.
batch_size
# Train the model
mnist_classifier
.
train
(
input_fn
=
lambda
:
input_fn
(
tf
.
estimator
.
ModeKeys
.
TRAIN
,
FLAGS
.
batch_size
),
steps
=
FLAGS
.
train_epochs
*
batches_per
_epoch
,
input_fn
=
lambda
:
input_fn
(
True
,
train_file
,
FLAGS
.
batch_size
,
FLAGS
.
train
_epoch
s
)
,
hooks
=
[
logging_hook
])
# Evaluate the model and print results
eval_results
=
mnist_classifier
.
evaluate
(
input_fn
=
lambda
:
input_fn
(
tf
.
estimator
.
ModeKeys
.
EVAL
))
input_fn
=
lambda
:
input_fn
(
False
,
test_file
,
FLAGS
.
batch_size
))
print
()
print
(
'Evaluation results:
\n
%s'
%
eval_results
)
print
(
'Evaluation results:
\n
\t
%s'
%
eval_results
)
if
__name__
==
'__main__'
:
...
...
official/mnist/mnist_test.py
View file @
4702de29
...
...
@@ -34,7 +34,8 @@ class BaseTest(tf.test.TestCase):
def
mnist_model_fn_helper
(
self
,
mode
):
features
,
labels
=
self
.
input_fn
()
image_count
=
features
.
shape
[
0
]
spec
=
mnist
.
mnist_model_fn
(
features
,
labels
,
mode
)
spec
=
mnist
.
mnist_model_fn
(
features
,
labels
,
mode
,
{
'data_format'
:
'channels_last'
})
predictions
=
spec
.
predictions
self
.
assertAllEqual
(
predictions
[
'probabilities'
].
shape
,
(
image_count
,
10
))
...
...
@@ -65,5 +66,4 @@ class BaseTest(tf.test.TestCase):
if
__name__
==
'__main__'
:
mnist
.
FLAGS
=
mnist
.
parser
.
parse_args
()
tf
.
test
.
main
()
official/resnet/cifar10_main.py
View file @
4702de29
...
...
@@ -71,6 +71,8 @@ _NUM_IMAGES = {
'validation'
:
10000
,
}
_SHUFFLE_BUFFER
=
20000
def
record_dataset
(
filenames
):
"""Returns an input pipeline Dataset from `filenames`."""
...
...
@@ -78,9 +80,9 @@ def record_dataset(filenames):
return
tf
.
contrib
.
data
.
FixedLengthRecordDataset
(
filenames
,
record_bytes
)
def
get_filenames
(
is_training
):
def
get_filenames
(
is_training
,
data_dir
):
"""Returns a list of filenames."""
data_dir
=
os
.
path
.
join
(
FLAGS
.
data_dir
,
'cifar-10-batches-bin'
)
data_dir
=
os
.
path
.
join
(
data_dir
,
'cifar-10-batches-bin'
)
assert
os
.
path
.
exists
(
data_dir
),
(
'Run cifar10_download_and_extract.py first to download and extract the '
...
...
@@ -135,7 +137,7 @@ def train_preprocess_fn(image, label):
return
image
,
label
def
input_fn
(
is_training
,
num_epochs
=
1
):
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
num_epochs
=
1
):
"""Input_fn using the contrib.data input pipeline for CIFAR-10 dataset.
Args:
...
...
@@ -145,42 +147,41 @@ def input_fn(is_training, num_epochs=1):
Returns:
A tuple of images and labels.
"""
dataset
=
record_dataset
(
get_filenames
(
is_training
))
dataset
=
record_dataset
(
get_filenames
(
is_training
,
data_dir
))
dataset
=
dataset
.
map
(
dataset_parser
,
num_threads
=
1
,
output_buffer_size
=
2
*
FLAGS
.
batch_size
)
output_buffer_size
=
2
*
batch_size
)
# For training, preprocess the image and shuffle.
if
is_training
:
dataset
=
dataset
.
map
(
train_preprocess_fn
,
num_threads
=
1
,
output_buffer_size
=
2
*
FLAGS
.
batch_size
)
output_buffer_size
=
2
*
batch_size
)
# Ensure that the capacity is sufficiently large to provide good random
# shuffling.
buffer_size
=
int
(
0.4
*
_NUM_IMAGES
[
'train'
])
dataset
=
dataset
.
shuffle
(
buffer_size
=
buffer_size
)
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes have better performance.
dataset
=
dataset
.
shuffle
(
buffer_size
=
_SHUFFLE_BUFFER
)
# Subtract off the mean and divide by the variance of the pixels.
dataset
=
dataset
.
map
(
lambda
image
,
label
:
(
tf
.
image
.
per_image_standardization
(
image
),
label
),
num_threads
=
1
,
output_buffer_size
=
2
*
FLAGS
.
batch_size
)
output_buffer_size
=
2
*
batch_size
)
dataset
=
dataset
.
repeat
(
num_epochs
)
# Batch results by up to batch_size, and then fetch the tuple from the
# iterator.
iterator
=
dataset
.
batch
(
FLAGS
.
batch_size
).
make_one_shot_iterator
()
iterator
=
dataset
.
batch
(
batch_size
).
make_one_shot_iterator
()
images
,
labels
=
iterator
.
get_next
()
return
images
,
labels
def
cifar10_model_fn
(
features
,
labels
,
mode
):
def
cifar10_model_fn
(
features
,
labels
,
mode
,
params
):
"""Model function for CIFAR-10."""
tf
.
summary
.
image
(
'images'
,
features
,
max_outputs
=
6
)
network
=
resnet_model
.
cifar10_resnet_v2_generator
(
FLAGS
.
resnet_size
,
_NUM_CLASSES
,
FLAGS
.
data_format
)
params
[
'
resnet_size
'
]
,
_NUM_CLASSES
,
params
[
'
data_format
'
]
)
inputs
=
tf
.
reshape
(
features
,
[
-
1
,
_HEIGHT
,
_WIDTH
,
_DEPTH
])
logits
=
network
(
inputs
,
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
)
...
...
@@ -208,8 +209,8 @@ def cifar10_model_fn(features, labels, mode):
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
# Scale the learning rate linearly with the batch size. When the batch size
# is 128, the learning rate should be 0.1.
initial_learning_rate
=
0.1
*
FLAGS
.
batch_size
/
128
batches_per_epoch
=
_NUM_IMAGES
[
'train'
]
/
FLAGS
.
batch_size
initial_learning_rate
=
0.1
*
params
[
'
batch_size
'
]
/
128
batches_per_epoch
=
_NUM_IMAGES
[
'train'
]
/
params
[
'
batch_size
'
]
global_step
=
tf
.
train
.
get_or_create_global_step
()
# Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
...
...
@@ -256,7 +257,12 @@ def main(unused_argv):
# Set up a RunConfig to only save checkpoints once per training cycle.
run_config
=
tf
.
estimator
.
RunConfig
().
replace
(
save_checkpoints_secs
=
1e9
)
cifar_classifier
=
tf
.
estimator
.
Estimator
(
model_fn
=
cifar10_model_fn
,
model_dir
=
FLAGS
.
model_dir
,
config
=
run_config
)
model_fn
=
cifar10_model_fn
,
model_dir
=
FLAGS
.
model_dir
,
config
=
run_config
,
params
=
{
'resnet_size'
:
FLAGS
.
resnet_size
,
'data_format'
:
FLAGS
.
data_format
,
'batch_size'
:
FLAGS
.
batch_size
,
})
for
_
in
range
(
FLAGS
.
train_epochs
//
FLAGS
.
epochs_per_eval
):
tensors_to_log
=
{
...
...
@@ -270,12 +276,12 @@ def main(unused_argv):
cifar_classifier
.
train
(
input_fn
=
lambda
:
input_fn
(
is_training
=
True
,
num_epochs
=
FLAGS
.
epochs_per_eval
),
True
,
FLAGS
.
data_dir
,
FLAGS
.
batch_size
,
FLAGS
.
epochs_per_eval
),
hooks
=
[
logging_hook
])
# Evaluate the model and print results
eval_results
=
cifar_classifier
.
evaluate
(
input_fn
=
lambda
:
input_fn
(
is_training
=
Fals
e
))
input_fn
=
lambda
:
input_fn
(
False
,
FLAGS
.
data_dir
,
FLAGS
.
batch_siz
e
))
print
(
eval_results
)
...
...
official/resnet/cifar10_test.py
View file @
4702de29
...
...
@@ -26,6 +26,8 @@ import cifar10_main
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
ERROR
)
_BATCH_SIZE
=
128
class
BaseTest
(
tf
.
test
.
TestCase
):
...
...
@@ -58,20 +60,25 @@ class BaseTest(tf.test.TestCase):
self
.
assertAllEqual
(
pixel
,
np
.
array
([
0
,
1
,
2
]))
def
input_fn
(
self
):
features
=
tf
.
random_uniform
([
FLAGS
.
batch_size
,
32
,
32
,
3
])
features
=
tf
.
random_uniform
([
_BATCH_SIZE
,
32
,
32
,
3
])
labels
=
tf
.
random_uniform
(
[
FLAGS
.
batch_size
],
maxval
=
9
,
dtype
=
tf
.
int32
)
[
_BATCH_SIZE
],
maxval
=
9
,
dtype
=
tf
.
int32
)
return
features
,
tf
.
one_hot
(
labels
,
10
)
def
cifar10_model_fn_helper
(
self
,
mode
):
features
,
labels
=
self
.
input_fn
()
spec
=
cifar10_main
.
cifar10_model_fn
(
features
,
labels
,
mode
)
spec
=
cifar10_main
.
cifar10_model_fn
(
features
,
labels
,
mode
,
{
'resnet_size'
:
32
,
'data_format'
:
'channels_last'
,
'batch_size'
:
_BATCH_SIZE
,
})
predictions
=
spec
.
predictions
self
.
assertAllEqual
(
predictions
[
'probabilities'
].
shape
,
(
FLAGS
.
batch_size
,
10
))
(
_BATCH_SIZE
,
10
))
self
.
assertEqual
(
predictions
[
'probabilities'
].
dtype
,
tf
.
float32
)
self
.
assertAllEqual
(
predictions
[
'classes'
].
shape
,
(
FLAGS
.
batch_size
,))
self
.
assertAllEqual
(
predictions
[
'classes'
].
shape
,
(
_BATCH_SIZE
,))
self
.
assertEqual
(
predictions
[
'classes'
].
dtype
,
tf
.
int64
)
if
mode
!=
tf
.
estimator
.
ModeKeys
.
PREDICT
:
...
...
@@ -97,6 +104,4 @@ class BaseTest(tf.test.TestCase):
if
__name__
==
'__main__'
:
cifar10_main
.
FLAGS
=
cifar10_main
.
parser
.
parse_args
()
FLAGS
=
cifar10_main
.
FLAGS
tf
.
test
.
main
()
official/resnet/imagenet_main.py
View file @
4702de29
...
...
@@ -73,16 +73,18 @@ _NUM_IMAGES = {
'validation'
:
50000
,
}
_SHUFFLE_BUFFER
=
1500
def
filenames
(
is_training
):
def
filenames
(
is_training
,
data_dir
):
"""Return filenames for dataset."""
if
is_training
:
return
[
os
.
path
.
join
(
FLAGS
.
data_dir
,
'train-%05d-of-01024'
%
i
)
os
.
path
.
join
(
data_dir
,
'train-%05d-of-01024'
%
i
)
for
i
in
range
(
0
,
1024
)]
else
:
return
[
os
.
path
.
join
(
FLAGS
.
data_dir
,
'validation-%05d-of-00128'
%
i
)
os
.
path
.
join
(
data_dir
,
'validation-%05d-of-00128'
%
i
)
for
i
in
range
(
0
,
128
)]
...
...
@@ -129,9 +131,11 @@ def dataset_parser(value, is_training):
return
image
,
tf
.
one_hot
(
label
,
_LABEL_CLASSES
)
def
input_fn
(
is_training
,
num_epochs
=
1
):
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
num_epochs
=
1
):
"""Input function which provides batches for train or eval."""
dataset
=
tf
.
contrib
.
data
.
Dataset
.
from_tensor_slices
(
filenames
(
is_training
))
dataset
=
tf
.
contrib
.
data
.
Dataset
.
from_tensor_slices
(
filenames
(
is_training
,
data_dir
))
if
is_training
:
dataset
=
dataset
.
shuffle
(
buffer_size
=
1024
)
dataset
=
dataset
.
flat_map
(
tf
.
contrib
.
data
.
TFRecordDataset
)
...
...
@@ -141,23 +145,24 @@ def input_fn(is_training, num_epochs=1):
dataset
=
dataset
.
map
(
lambda
value
:
dataset_parser
(
value
,
is_training
),
num_threads
=
5
,
output_buffer_size
=
FLAGS
.
batch_size
)
output_buffer_size
=
batch_size
)
if
is_training
:
buffer_size
=
1250
+
2
*
FLAGS
.
batch_size
dataset
=
dataset
.
shuffle
(
buffer_size
=
buffer_size
)
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes have better performance.
dataset
=
dataset
.
shuffle
(
buffer_size
=
_SHUFFLE_BUFFER
)
iterator
=
dataset
.
batch
(
FLAGS
.
batch_size
).
make_one_shot_iterator
()
iterator
=
dataset
.
batch
(
batch_size
).
make_one_shot_iterator
()
images
,
labels
=
iterator
.
get_next
()
return
images
,
labels
def
resnet_model_fn
(
features
,
labels
,
mode
):
def
resnet_model_fn
(
features
,
labels
,
mode
,
params
):
"""Our model_fn for ResNet to be used with our Estimator."""
tf
.
summary
.
image
(
'images'
,
features
,
max_outputs
=
6
)
network
=
resnet_model
.
imagenet_resnet_v2
(
FLAGS
.
resnet_size
,
_LABEL_CLASSES
,
FLAGS
.
data_format
)
params
[
'
resnet_size
'
]
,
_LABEL_CLASSES
,
params
[
'
data_format
'
]
)
logits
=
network
(
inputs
=
features
,
is_training
=
(
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
))
...
...
@@ -185,8 +190,8 @@ def resnet_model_fn(features, labels, mode):
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
# Scale the learning rate linearly with the batch size. When the batch size is
# 256, the learning rate should be 0.1.
initial_learning_rate
=
0.1
*
FLAGS
.
batch_size
/
256
batches_per_epoch
=
_NUM_IMAGES
[
'train'
]
/
FLAGS
.
batch_size
initial_learning_rate
=
0.1
*
params
[
'
batch_size
'
]
/
256
batches_per_epoch
=
_NUM_IMAGES
[
'train'
]
/
params
[
'
batch_size
'
]
global_step
=
tf
.
train
.
get_or_create_global_step
()
# Multiply the learning rate by 0.1 at 30, 60, 80, and 90 epochs.
...
...
@@ -235,7 +240,12 @@ def main(unused_argv):
# Set up a RunConfig to only save checkpoints once per training cycle.
run_config
=
tf
.
estimator
.
RunConfig
().
replace
(
save_checkpoints_secs
=
1e9
)
resnet_classifier
=
tf
.
estimator
.
Estimator
(
model_fn
=
resnet_model_fn
,
model_dir
=
FLAGS
.
model_dir
,
config
=
run_config
)
model_fn
=
resnet_model_fn
,
model_dir
=
FLAGS
.
model_dir
,
config
=
run_config
,
params
=
{
'resnet_size'
:
FLAGS
.
resnet_size
,
'data_format'
:
FLAGS
.
data_format
,
'batch_size'
:
FLAGS
.
batch_size
,
})
for
_
in
range
(
FLAGS
.
train_epochs
//
FLAGS
.
epochs_per_eval
):
tensors_to_log
=
{
...
...
@@ -250,12 +260,12 @@ def main(unused_argv):
print
(
'Starting a training cycle.'
)
resnet_classifier
.
train
(
input_fn
=
lambda
:
input_fn
(
is_training
=
True
,
num_epochs
=
FLAGS
.
epochs_per_eval
),
True
,
FLAGS
.
data_dir
,
FLAGS
.
batch_size
,
FLAGS
.
epochs_per_eval
),
hooks
=
[
logging_hook
])
print
(
'Starting to evaluate.'
)
eval_results
=
resnet_classifier
.
evaluate
(
input_fn
=
lambda
:
input_fn
(
is_training
=
Fals
e
))
input_fn
=
lambda
:
input_fn
(
False
,
FLAGS
.
data_dir
,
FLAGS
.
batch_siz
e
))
print
(
eval_results
)
...
...
official/resnet/imagenet_test.py
View file @
4702de29
...
...
@@ -26,6 +26,7 @@ import resnet_model
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
ERROR
)
_BATCH_SIZE
=
32
_LABEL_CLASSES
=
1001
...
...
@@ -125,10 +126,10 @@ class BaseTest(tf.test.TestCase):
def
input_fn
(
self
):
"""Provides random features and labels."""
features
=
tf
.
random_uniform
([
FLAGS
.
batch_size
,
224
,
224
,
3
])
features
=
tf
.
random_uniform
([
_BATCH_SIZE
,
224
,
224
,
3
])
labels
=
tf
.
one_hot
(
tf
.
random_uniform
(
[
FLAGS
.
batch_size
],
maxval
=
_LABEL_CLASSES
-
1
,
[
_BATCH_SIZE
],
maxval
=
_LABEL_CLASSES
-
1
,
dtype
=
tf
.
int32
),
_LABEL_CLASSES
)
...
...
@@ -139,13 +140,18 @@ class BaseTest(tf.test.TestCase):
tf
.
train
.
create_global_step
()
features
,
labels
=
self
.
input_fn
()
spec
=
imagenet_main
.
resnet_model_fn
(
features
,
labels
,
mode
)
spec
=
imagenet_main
.
resnet_model_fn
(
features
,
labels
,
mode
,
{
'resnet_size'
:
50
,
'data_format'
:
'channels_last'
,
'batch_size'
:
_BATCH_SIZE
,
})
predictions
=
spec
.
predictions
self
.
assertAllEqual
(
predictions
[
'probabilities'
].
shape
,
(
FLAGS
.
batch_size
,
_LABEL_CLASSES
))
(
_BATCH_SIZE
,
_LABEL_CLASSES
))
self
.
assertEqual
(
predictions
[
'probabilities'
].
dtype
,
tf
.
float32
)
self
.
assertAllEqual
(
predictions
[
'classes'
].
shape
,
(
FLAGS
.
batch_size
,))
self
.
assertAllEqual
(
predictions
[
'classes'
].
shape
,
(
_BATCH_SIZE
,))
self
.
assertEqual
(
predictions
[
'classes'
].
dtype
,
tf
.
int64
)
if
mode
!=
tf
.
estimator
.
ModeKeys
.
PREDICT
:
...
...
@@ -171,6 +177,4 @@ class BaseTest(tf.test.TestCase):
if
__name__
==
'__main__'
:
imagenet_main
.
FLAGS
=
imagenet_main
.
parser
.
parse_args
()
FLAGS
=
imagenet_main
.
FLAGS
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment