Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
5ddd7e55
Commit
5ddd7e55
authored
Oct 03, 2017
by
Neal Wu
Committed by
GitHub
Oct 03, 2017
Browse files
Fix lint issues in the official models (#2487)
parent
385e669e
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
53 additions
and
48 deletions
+53
-48
official/mnist/convert_to_records.py
official/mnist/convert_to_records.py
+2
-2
official/mnist/mnist.py
official/mnist/mnist.py
+10
-8
official/mnist/mnist_test.py
official/mnist/mnist_test.py
+0
-1
official/resnet/cifar10_download_and_extract.py
official/resnet/cifar10_download_and_extract.py
+2
-2
official/resnet/cifar10_main.py
official/resnet/cifar10_main.py
+12
-9
official/resnet/cifar10_test.py
official/resnet/cifar10_test.py
+4
-7
official/resnet/imagenet_main.py
official/resnet/imagenet_main.py
+3
-2
official/resnet/imagenet_test.py
official/resnet/imagenet_test.py
+3
-5
official/resnet/resnet_model.py
official/resnet/resnet_model.py
+17
-12
No files found.
official/mnist/convert_to_records.py
View file @
5ddd7e55
...
...
@@ -93,5 +93,5 @@ def main(unused_argv):
if
__name__
==
'__main__'
:
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
FLAGS
=
parser
.
parse_args
()
tf
.
app
.
run
()
FLAGS
,
unparsed
=
parser
.
parse_
known_
args
()
tf
.
app
.
run
(
main
=
main
,
argv
=
[
sys
.
argv
[
0
]]
+
unparsed
)
official/mnist/mnist.py
View file @
5ddd7e55
...
...
@@ -19,8 +19,8 @@ from __future__ import print_function
import
argparse
import
os
import
sys
import
numpy
as
np
import
tensorflow
as
tf
parser
=
argparse
.
ArgumentParser
()
...
...
@@ -42,7 +42,7 @@ parser.add_argument('--steps', type=int, default=20000,
def
input_fn
(
mode
,
batch_size
=
1
):
"""A simple input_fn using the contrib.data input pipeline."""
def
parser
(
serialized_example
):
def
example_
parser
(
serialized_example
):
"""Parses a single tf.Example into image and label tensors."""
features
=
tf
.
parse_single_example
(
serialized_example
,
...
...
@@ -64,8 +64,9 @@ def input_fn(mode, batch_size=1):
assert
mode
==
tf
.
estimator
.
ModeKeys
.
EVAL
,
'invalid mode'
tfrecords_file
=
os
.
path
.
join
(
FLAGS
.
data_dir
,
'test.tfrecords'
)
assert
tf
.
gfile
.
Exists
(
tfrecords_file
),
(
'Run convert_to_records.py first to '
'convert the MNIST data to TFRecord file format.'
)
assert
tf
.
gfile
.
Exists
(
tfrecords_file
),
(
'Run convert_to_records.py first to convert the MNIST data to TFRecord '
'file format.'
)
dataset
=
tf
.
contrib
.
data
.
TFRecordDataset
([
tfrecords_file
])
...
...
@@ -73,8 +74,9 @@ def input_fn(mode, batch_size=1):
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
dataset
=
dataset
.
repeat
()
# Map the parser over dataset, and batch results by up to batch_size
dataset
=
dataset
.
map
(
parser
,
num_threads
=
1
,
output_buffer_size
=
batch_size
)
# Map example_parser over dataset, and batch results by up to batch_size
dataset
=
dataset
.
map
(
example_parser
,
num_threads
=
1
,
output_buffer_size
=
batch_size
)
dataset
=
dataset
.
batch
(
batch_size
)
images
,
labels
=
dataset
.
make_one_shot_iterator
().
get_next
()
...
...
@@ -223,5 +225,5 @@ def main(unused_argv):
if
__name__
==
'__main__'
:
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
FLAGS
=
parser
.
parse_args
()
tf
.
app
.
run
()
FLAGS
,
unparsed
=
parser
.
parse_
known_
args
()
tf
.
app
.
run
(
main
=
main
,
argv
=
[
sys
.
argv
[
0
]]
+
unparsed
)
official/mnist/mnist_test.py
View file @
5ddd7e55
...
...
@@ -17,7 +17,6 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
tensorflow
as
tf
import
mnist
...
...
official/resnet/cifar10_download_and_extract.py
View file @
5ddd7e55
...
...
@@ -46,8 +46,8 @@ def main(unused_argv):
if
not
os
.
path
.
exists
(
filepath
):
def
_progress
(
count
,
block_size
,
total_size
):
sys
.
stdout
.
write
(
'
\r
>> Downloading %s %.1f%%'
%
(
filename
,
100.0
*
count
*
block_size
/
total_size
))
sys
.
stdout
.
write
(
'
\r
>> Downloading %s %.1f%%'
%
(
filename
,
100.0
*
count
*
block_size
/
total_size
))
sys
.
stdout
.
flush
()
filepath
,
_
=
urllib
.
request
.
urlretrieve
(
DATA_URL
,
filepath
,
_progress
)
...
...
official/resnet/cifar10_main.py
View file @
5ddd7e55
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the CIFAR-10 dataset."""
from
__future__
import
absolute_import
from
__future__
import
division
...
...
@@ -19,9 +20,7 @@ from __future__ import print_function
import
argparse
import
os
import
sys
import
numpy
as
np
import
tensorflow
as
tf
import
resnet_model
...
...
@@ -75,12 +74,13 @@ def record_dataset(filenames):
return
tf
.
contrib
.
data
.
FixedLengthRecordDataset
(
filenames
,
record_bytes
)
def
filenames
(
mode
):
def
get_
filenames
(
mode
):
"""Returns a list of filenames based on 'mode'."""
data_dir
=
os
.
path
.
join
(
FLAGS
.
data_dir
,
'cifar-10-batches-bin'
)
assert
os
.
path
.
exists
(
data_dir
),
(
'Run cifar10_download_and_extract.py first '
'to download and extract the CIFAR-10 data.'
)
assert
os
.
path
.
exists
(
data_dir
),
(
'Run cifar10_download_and_extract.py first to download and extract the '
'CIFAR-10 data.'
)
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
return
[
...
...
@@ -137,10 +137,13 @@ def input_fn(mode, batch_size):
"""Input_fn using the contrib.data input pipeline for CIFAR-10 dataset.
Args:
mode: Standard names for model modes
(
tf.estimator
s
.ModeKeys
)
.
mode: Standard names for model modes
from
tf.estimator.ModeKeys.
batch_size: The number of samples per batch of input requested.
Returns:
A tuple of images and labels.
"""
dataset
=
record_dataset
(
filenames
(
mode
))
dataset
=
record_dataset
(
get_
filenames
(
mode
))
# For training repeat forever.
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
...
...
@@ -227,7 +230,7 @@ def cifar10_model_fn(features, labels, mode):
else
:
train_op
=
None
accuracy
=
tf
.
metrics
.
accuracy
(
accuracy
=
tf
.
metrics
.
accuracy
(
tf
.
argmax
(
labels
,
axis
=
1
),
predictions
[
'classes'
])
metrics
=
{
'accuracy'
:
accuracy
}
...
...
@@ -250,7 +253,7 @@ def main(unused_argv):
cifar_classifier
=
tf
.
estimator
.
Estimator
(
model_fn
=
cifar10_model_fn
,
model_dir
=
FLAGS
.
model_dir
)
for
cycle
in
range
(
FLAGS
.
train_steps
//
FLAGS
.
steps_per_eval
):
for
_
in
range
(
FLAGS
.
train_steps
//
FLAGS
.
steps_per_eval
):
tensors_to_log
=
{
'learning_rate'
:
'learning_rate'
,
'cross_entropy'
:
'cross_entropy'
,
...
...
official/resnet/cifar10_test.py
View file @
5ddd7e55
...
...
@@ -17,9 +17,6 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
sys
from
tempfile
import
mkdtemp
from
tempfile
import
mkstemp
import
numpy
as
np
...
...
@@ -36,13 +33,13 @@ class BaseTest(tf.test.TestCase):
fake_data
=
bytearray
()
fake_data
.
append
(
7
)
for
i
in
xrange
(
3
):
for
j
in
xrange
(
1024
):
for
_
in
xrange
(
1024
):
fake_data
.
append
(
i
)
_
,
filename
=
mkstemp
(
dir
=
self
.
get_temp_dir
())
file
=
open
(
filename
,
'wb'
)
file
.
write
(
fake_data
)
file
.
close
()
data_
file
=
open
(
filename
,
'wb'
)
data_
file
.
write
(
fake_data
)
data_
file
.
close
()
fake_dataset
=
cifar10_main
.
record_dataset
(
filename
)
fake_dataset
=
fake_dataset
.
map
(
cifar10_main
.
dataset_parser
)
...
...
official/resnet/imagenet_main.py
View file @
5ddd7e55
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset."""
from
__future__
import
absolute_import
from
__future__
import
division
...
...
@@ -163,7 +164,7 @@ def input_fn(is_training):
def
resnet_model_fn
(
features
,
labels
,
mode
):
"""
Our model_fn for ResNet to be used with our Estimator."""
"""Our model_fn for ResNet to be used with our Estimator."""
tf
.
summary
.
image
(
'images'
,
features
,
max_outputs
=
6
)
logits
=
network
(
...
...
@@ -239,7 +240,7 @@ def main(unused_argv):
resnet_classifier
=
tf
.
estimator
.
Estimator
(
model_fn
=
resnet_model_fn
,
model_dir
=
FLAGS
.
model_dir
)
for
cycle
in
range
(
FLAGS
.
train_steps
//
FLAGS
.
steps_per_eval
):
for
_
in
range
(
FLAGS
.
train_steps
//
FLAGS
.
steps_per_eval
):
tensors_to_log
=
{
'learning_rate'
:
'learning_rate'
,
'cross_entropy'
:
'cross_entropy'
,
...
...
official/resnet/imagenet_test.py
View file @
5ddd7e55
...
...
@@ -34,11 +34,9 @@ class BaseTest(tf.test.TestCase):
def
tensor_shapes_helper
(
self
,
resnet_size
,
with_gpu
=
False
):
"""Checks the tensor shapes after each phase of the ResNet model."""
def
reshape
(
shape
):
"""Returns the expected dimensions depending on if gpu is being used.
If a GPU is used for the test, the shape is returned (already in NCHW
form). When GPU is not used, the shape is converted to NHWC.
"""
"""Returns the expected dimensions depending on if a GPU is being used."""
# If a GPU is used for the test, the shape is returned (already in NCHW
# form). When GPU is not used, the shape is converted to NHWC.
if
with_gpu
:
return
shape
return
shape
[
0
],
shape
[
2
],
shape
[
3
],
shape
[
1
]
...
...
official/resnet/resnet_model.py
View file @
5ddd7e55
...
...
@@ -40,8 +40,8 @@ _BATCH_NORM_EPSILON = 1e-5
def
batch_norm_relu
(
inputs
,
is_training
,
data_format
):
"""Performs a batch normalization followed by a ReLU."""
# We set fused=True for a significant performance boost.
#
See
https://www.tensorflow.org/performance/performance_guide#common_fused_ops
# We set fused=True for a significant performance boost.
See
# https://www.tensorflow.org/performance/performance_guide#common_fused_ops
inputs
=
tf
.
layers
.
batch_normalization
(
inputs
=
inputs
,
axis
=
1
if
data_format
==
'channels_first'
else
3
,
momentum
=
_BATCH_NORM_DECAY
,
epsilon
=
_BATCH_NORM_EPSILON
,
center
=
True
,
...
...
@@ -78,11 +78,9 @@ def fixed_padding(inputs, kernel_size, data_format):
def
conv2d_fixed_padding
(
inputs
,
filters
,
kernel_size
,
strides
,
data_format
):
"""Strided 2-D convolution with explicit padding.
The padding is consistent and is based only on `kernel_size`, not on the
dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).
"""
"""Strided 2-D convolution with explicit padding."""
# The padding is consistent and is based only on `kernel_size`, not on the
# dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).
if
strides
>
1
:
inputs
=
fixed_padding
(
inputs
,
kernel_size
,
data_format
)
...
...
@@ -210,7 +208,7 @@ def block_layer(inputs, filters, block_fn, blocks, strides, is_training, name,
inputs
=
block_fn
(
inputs
,
filters
,
is_training
,
projection_shortcut
,
strides
,
data_format
)
for
i
in
range
(
1
,
blocks
):
for
_
in
range
(
1
,
blocks
):
inputs
=
block_fn
(
inputs
,
filters
,
is_training
,
None
,
1
,
data_format
)
return
tf
.
identity
(
inputs
,
name
)
...
...
@@ -228,6 +226,9 @@ def cifar10_resnet_v2_generator(resnet_size, num_classes, data_format=None):
Returns:
The model function that takes in `inputs` and `is_training` and
returns the output tensor of the ResNet model.
Raises:
ValueError: If `resnet_size` is invalid.
"""
if
resnet_size
%
6
!=
2
:
raise
ValueError
(
'resnet_size must be 6n + 2:'
,
resnet_size
)
...
...
@@ -235,13 +236,15 @@ def cifar10_resnet_v2_generator(resnet_size, num_classes, data_format=None):
num_blocks
=
(
resnet_size
-
2
)
//
6
if
data_format
is
None
:
data_format
=
'channels_first'
if
tf
.
test
.
is_built_with_cuda
()
else
'channels_last'
data_format
=
(
'channels_first'
if
tf
.
test
.
is_built_with_cuda
()
else
'channels_last'
)
def
model
(
inputs
,
is_training
):
"""Constructs the ResNet model given the inputs."""
if
data_format
==
'channels_first'
:
# Convert from channels_last (NHWC) to channels_first (NCHW). This
# provides a large performance boost on GPU.
#
See
https://www.tensorflow.org/performance/performance_guide#data_formats
# provides a large performance boost on GPU.
See
# https://www.tensorflow.org/performance/performance_guide#data_formats
inputs
=
tf
.
transpose
(
inputs
,
[
0
,
3
,
1
,
2
])
inputs
=
conv2d_fixed_padding
(
...
...
@@ -294,9 +297,11 @@ def imagenet_resnet_v2_generator(block_fn, layers, num_classes,
returns the output tensor of the ResNet model.
"""
if
data_format
is
None
:
data_format
=
'channels_first'
if
tf
.
test
.
is_built_with_cuda
()
else
'channels_last'
data_format
=
(
'channels_first'
if
tf
.
test
.
is_built_with_cuda
()
else
'channels_last'
)
def
model
(
inputs
,
is_training
):
"""Constructs the ResNet model given the inputs."""
if
data_format
==
'channels_first'
:
# Convert from channels_last (NHWC) to channels_first (NCHW). This
# provides a large performance boost on GPU.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment