ModelZoo / ResNet50_tensorflow · Commits · 7cfb6bbd

Unverified commit 7cfb6bbd, authored Mar 20, 2018 by Karmel Allison, committed by GitHub Mar 20, 2018.

Glint everything (#3654)

* Glint everything
* Adding rcfile and pylinting
* Extra newline
* Few last lints
Parent: adfd5a3a
Changes: 27

Showing 20 changed files with 183 additions and 128 deletions.
official/__init__.py                              +0  -14
official/mnist/dataset.py                         +5  -3
official/mnist/mnist.py                           +21 -12
official/mnist/mnist_eager.py                     +10 -9
official/mnist/mnist_eager_test.py                +3  -2
official/mnist/mnist_test.py                      +5  -2
official/mnist/mnist_tpu.py                       +3  -2
official/resnet/cifar10_download_and_extract.py   +1  -1
official/resnet/cifar10_main.py                   +13 -7
official/resnet/cifar10_test.py                   +9  -5
official/resnet/imagenet_main.py                  +17 -4
official/resnet/imagenet_preprocessing.py         +5  -2
official/resnet/imagenet_test.py                  +6  -7
official/resnet/resnet_model.py                   +46 -11
official/resnet/resnet_run_loop.py                +25 -19
official/utils/__init__.py                        +0  -14
official/utils/arg_parsers/parsers.py             +3  -3
official/utils/logging/hooks_helper.py            +5  -6
official/utils/logging/hooks_helper_test.py       +1  -1
official/utils/logging/hooks_test.py              +5  -4
official/__init__.py  (+0 -14)

Per the +0 -14 stats, the diff deletes the file's entire content, its 14-line Apache 2.0 license header, leaving an empty __init__.py:

-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
\ No newline at end of file
official/mnist/dataset.py  (+5 -3)

@@ -17,9 +17,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import gzip
 import os
 import shutil
-import gzip

 import numpy as np
 from six.moves import urllib

@@ -36,7 +36,7 @@ def check_image_file_header(filename):
   """Validate that filename corresponds to images for the MNIST dataset."""
   with tf.gfile.Open(filename, 'rb') as f:
     magic = read32(f)
-    num_images = read32(f)
+    read32(f)  # num_images, unused
     rows = read32(f)
     cols = read32(f)
     if magic != 2051:

@@ -52,7 +52,7 @@ def check_labels_file_header(filename):
   """Validate that filename corresponds to labels for the MNIST dataset."""
   with tf.gfile.Open(filename, 'rb') as f:
     magic = read32(f)
-    num_items = read32(f)
+    read32(f)  # num_items, unused
     if magic != 2049:
       raise ValueError('Invalid magic number %d in MNIST file %s' %
                        (magic, f.name))

@@ -77,6 +77,8 @@ def download(directory, filename):

 def dataset(directory, images_file, labels_file):
+  """Download and parse MNIST dataset."""
+
   images_file = download(directory, images_file)
   labels_file = download(directory, labels_file)
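For context on the read32 calls above: MNIST headers are big-endian unsigned 32-bit integers, and the magic numbers 2051 and 2049 identify image and label files respectively. Below is a minimal, TF-free sketch of the image-header check; read32 here is a hypothetical struct-based stand-in for the module's numpy-based helper.

import io
import struct

def read32(bytestream):
    # MNIST header fields are big-endian unsigned 32-bit integers.
    return struct.unpack('>I', bytestream.read(4))[0]

# Fake images-file header: magic 2051, 10 images of 28x28 pixels.
header = struct.pack('>IIII', 2051, 10, 28, 28)
with io.BytesIO(header) as f:
    magic = read32(f)
    read32(f)  # num_images, unused
    rows = read32(f)
    cols = read32(f)
assert magic == 2051 and (rows, cols) == (28, 28)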
official/mnist/mnist.py  (+21 -12)

@@ -20,7 +20,7 @@ from __future__ import print_function
 import argparse
 import sys

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.mnist import dataset
 from official.utils.arg_parsers import parsers

@@ -28,6 +28,7 @@ from official.utils.logging import hooks_helper
 LEARNING_RATE = 1e-4

+
 class Model(tf.keras.Model):
   """Model to recognize digits in the MNIST dataset.

@@ -145,31 +146,36 @@ def model_fn(features, labels, mode, params):

 def validate_batch_size_for_multi_gpu(batch_size):
-  """For multi-gpu, batch-size must be a multiple of the number of
-  available GPUs.
+  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

   Note that this should eventually be handled by replicate_model_fn
   directly. Multi-GPU support is currently experimental, however,
   so doing the work here until that feature is in place.
+
+  Args:
+    batch_size: the number of examples processed in each training batch.
+
+  Raises:
+    ValueError: if no GPUs are found, or selected batch_size is invalid.
   """
-  from tensorflow.python.client import device_lib
+  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top

   local_device_protos = device_lib.list_local_devices()
   num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
   if not num_gpus:
     raise ValueError('Multi-GPU mode was specified, but no GPUs '
                      'were found. To use CPU, run without --multi_gpu.')

   remainder = batch_size % num_gpus
   if remainder:
     err = ('When running with multiple GPUs, batch size '
            'must be a multiple of the number of available GPUs. '
            'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
           ).format(num_gpus, batch_size, batch_size - remainder)
     raise ValueError(err)


-def main(unused_argv):
+def main(_):
   model_function = model_fn

   if FLAGS.multi_gpu:

@@ -195,6 +201,8 @@ def main(unused_argv):
   # Set up training and evaluation input functions.
   def train_input_fn():
+    """Prepare data for training."""
+
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes use less memory. MNIST is a small
     # enough dataset that we can easily shuffle the full epoch.

@@ -215,7 +223,7 @@ def main(unused_argv):
       FLAGS.hooks, batch_size=FLAGS.batch_size)

   # Train and evaluate model.
-  for n in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
+  for _ in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
     mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
     eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
     print('\nEvaluation results:\n\t%s\n' % eval_results)

@@ -231,10 +239,11 @@ def main(unused_argv):
 class MNISTArgParser(argparse.ArgumentParser):
+  """Argument parser for running MNIST model."""

   def __init__(self):
     super(MNISTArgParser, self).__init__(parents=[
         parsers.BaseParser(),
         parsers.ImageModelParser()])

     self.add_argument(
         '--export_dir',
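The batch-size check reworded above reduces to plain modular arithmetic; a standalone sketch of the same logic, as a hypothetical function that needs no TensorFlow:

def check_batch_size(batch_size, num_gpus):
    # Mirrors validate_batch_size_for_multi_gpu: each GPU must receive an
    # equal share of every training batch.
    if not num_gpus:
        raise ValueError('Multi-GPU mode was specified, but no GPUs '
                         'were found. To use CPU, run without --multi_gpu.')
    remainder = batch_size % num_gpus
    if remainder:
        raise ValueError(
            'Found {} GPUs with a batch size of {}; try --batch_size={} '
            'instead.'.format(num_gpus, batch_size, batch_size - remainder))

check_batch_size(256, 4)  # OK: 64 examples per GPU
# check_batch_size(100, 3) would raise and suggest --batch_size=99.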
official/mnist/mnist_eager.py  (+10 -9)

@@ -31,11 +31,11 @@ import os
 import sys
 import time

-import tensorflow as tf
-import tensorflow.contrib.eager as tfe
+import tensorflow as tf  # pylint: disable=g-bad-import-order
+import tensorflow.contrib.eager as tfe  # pylint: disable=g-bad-import-order

+from official.mnist import dataset as mnist_dataset
 from official.mnist import mnist
-from official.mnist import dataset
 from official.utils.arg_parsers import parsers

 FLAGS = None

@@ -110,9 +110,9 @@ def main(_):
   print('Using device %s, and data format %s.' % (device, data_format))

   # Load the datasets
-  train_ds = dataset.train(FLAGS.data_dir).shuffle(60000).batch(
+  train_ds = mnist_dataset.train(FLAGS.data_dir).shuffle(60000).batch(
       FLAGS.batch_size)
-  test_ds = dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)
+  test_ds = mnist_dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)

   # Create the model and optimizer
   model = mnist.Model(data_format)

@@ -159,12 +159,13 @@ def main(_):
 class MNISTEagerArgParser(argparse.ArgumentParser):
-  """Argument parser for running MNIST model with eager trainng loop."""
+  """Argument parser for running MNIST model with eager training loop."""

   def __init__(self):
     super(MNISTEagerArgParser, self).__init__(parents=[
         parsers.BaseParser(epochs_between_evals=False, multi_gpu=False,
                            hooks=False),
         parsers.ImageModelParser()])

     self.add_argument(
         '--log_interval', '-li',
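The shuffle(60000) call kept above sizes the shuffle buffer to the full MNIST training set, which is what the "shuffle the full epoch" comment elsewhere in this commit refers to. A rough pure-Python approximation of tf.data's streaming shuffle illustrates why the buffer size matters:

import random

def buffered_shuffle(items, buffer_size):
    # Approximates tf.data.Dataset.shuffle: keep a bounded buffer and emit a
    # random buffered element as each new one streams in. A buffer as large
    # as the dataset degenerates to a uniform full shuffle.
    buf, out = [], []
    for item in items:
        buf.append(item)
        if len(buf) > buffer_size:
            out.append(buf.pop(random.randrange(len(buf))))
    while buf:
        out.append(buf.pop(random.randrange(len(buf))))
    return out

print(buffered_shuffle(range(10), 3))   # only locally shuffled
print(buffered_shuffle(range(10), 10))  # uniform full shuffle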
official/mnist/mnist_eager_test.py  (+3 -2)

@@ -17,8 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import tensorflow as tf
-import tensorflow.contrib.eager as tfe
+import tensorflow as tf  # pylint: disable=g-bad-import-order
+import tensorflow.contrib.eager as tfe  # pylint: disable=g-bad-import-order

 from official.mnist import mnist
 from official.mnist import mnist_eager

@@ -60,6 +60,7 @@ def evaluate(defun=False):

 class MNISTTest(tf.test.TestCase):
+  """Run tests for MNIST eager loop."""

   def test_train(self):
     train(defun=False)
official/mnist/mnist_test.py  (+5 -2)

@@ -17,9 +17,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import tensorflow as tf
+import time
+
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.mnist import mnist

 BATCH_SIZE = 100

@@ -42,6 +43,7 @@ def make_estimator():

 class Tests(tf.test.TestCase):
+  """Run tests for MNIST model."""

   def test_mnist(self):
     classifier = make_estimator()

@@ -57,7 +59,7 @@ class Tests(tf.test.TestCase):
     input_fn = lambda: tf.random_uniform([3, 784])
     predictions_generator = classifier.predict(input_fn)
-    for i in range(3):
+    for _ in range(3):
       predictions = next(predictions_generator)
       self.assertEqual(predictions['probabilities'].shape, (10,))
       self.assertEqual(predictions['classes'].shape, ())

@@ -103,6 +105,7 @@ class Tests(tf.test.TestCase):

 class Benchmarks(tf.test.Benchmark):
+  """Simple speed benchmarking for MNIST."""

   def benchmark_train_step_time(self):
     classifier = make_estimator()
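Estimator.predict returns a lazy generator rather than a list, which is why the test above draws three predictions with next(); a TF-free sketch of that pattern, using a hypothetical stand-in generator:

def fake_predict():
    # Stand-in for classifier.predict(input_fn): yields one dict per example.
    while True:
        yield {'probabilities': [0.1] * 10, 'classes': 3}

predictions_generator = fake_predict()
for _ in range(3):
    predictions = next(predictions_generator)
    assert len(predictions['probabilities']) == 10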
official/mnist/mnist_tpu.py  (+3 -2)

@@ -23,7 +23,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.mnist import dataset
 from official.mnist import mnist

@@ -132,7 +133,7 @@ def main(argv):
   tf.logging.set_verbosity(tf.logging.INFO)

   tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
       FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

   run_config = tf.contrib.tpu.RunConfig(
       cluster=tpu_cluster_resolver,
official/resnet/cifar10_download_and_extract.py  (+1 -1)

@@ -36,7 +36,7 @@ parser.add_argument(
     help='Directory to download data and extract the tarball')


-def main(unused_argv):
+def main(_):
   """Download and extract the tarball from Alex's website."""
   if not os.path.exists(FLAGS.data_dir):
     os.makedirs(FLAGS.data_dir)
official/resnet/cifar10_main.py  (+13 -7)

@@ -21,7 +21,7 @@ from __future__ import print_function
 import os
 import sys

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.resnet import resnet_model
 from official.resnet import resnet_run_loop

@@ -127,22 +127,25 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
   num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']

   return resnet_run_loop.process_record_dataset(
       dataset, is_training, batch_size, _NUM_IMAGES['train'],
       parse_record, num_epochs, num_parallel_calls,
       examples_per_epoch=num_images, multi_gpu=multi_gpu)


 def get_synth_input_fn():
   return resnet_run_loop.get_synth_input_fn(
       _HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES)


 ###############################################################################
 # Running the model
 ###############################################################################
 class Cifar10Model(resnet_model.Model):
+  """Model class with appropriate defaults for CIFAR-10 data."""

   def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
                version=resnet_model.DEFAULT_VERSION):
     """These are the parameters that work for CIFAR-10 data.

     Args:

@@ -153,6 +156,9 @@ class Cifar10Model(resnet_model.Model):
         enables users to extend the same model to their own datasets.
       version: Integer representing which version of the ResNet network to use.
         See README for details. Valid values: [1, 2]
+
+    Raises:
+      ValueError: if invalid resnet_size is chosen
     """
     if resnet_size % 6 != 2:
       raise ValueError('resnet_size must be 6n + 2:', resnet_size)

@@ -195,7 +201,7 @@ def cifar10_model_fn(features, labels, mode, params):
   # for the CIFAR-10 dataset, perhaps because the regularization prevents
   # overfitting on the small data set. We therefore include all vars when
   # regularizing and computing loss during training.
-  def loss_filter_fn(name):
+  def loss_filter_fn(_):
     return True

   return resnet_run_loop.resnet_model_fn(features, labels, mode, Cifar10Model,
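The resnet_size % 6 != 2 guard reflects the CIFAR-10 ResNet layer count: three stages of n blocks with two convolutions each, plus the initial convolution and the final dense layer, i.e. 6n + 2 layers. A small sketch of the arithmetic, as a hypothetical helper:

def blocks_per_stage(resnet_size):
    # CIFAR-10 ResNets have 6n + 2 layers; recover n from the layer count.
    if resnet_size % 6 != 2:
        raise ValueError('resnet_size must be 6n + 2:', resnet_size)
    return (resnet_size - 2) // 6

assert blocks_per_stage(32) == 5  # ResNet-32: five blocks per stage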
official/resnet/cifar10_test.py  (+9 -5)

@@ -20,7 +20,7 @@ from __future__ import print_function
 from tempfile import mkstemp

 import numpy as np
-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.resnet import cifar10_main
 from official.utils.testing import integration

@@ -34,6 +34,8 @@ _NUM_CHANNELS = 3

 class BaseTest(tf.test.TestCase):
+  """Tests for the Cifar10 version of Resnet.
+  """

   def tearDown(self):
     super(BaseTest, self).tearDown()

@@ -52,7 +54,7 @@ class BaseTest(tf.test.TestCase):
     data_file.close()

     fake_dataset = tf.data.FixedLengthRecordDataset(
-        filename, cifar10_main._RECORD_BYTES)
+        filename, cifar10_main._RECORD_BYTES)  # pylint: disable=protected-access
     fake_dataset = fake_dataset.map(
         lambda val: cifar10_main.parse_record(val, False))
     image, label = fake_dataset.make_one_shot_iterator().get_next()

@@ -133,9 +135,11 @@ class BaseTest(tf.test.TestCase):
     num_classes = 246
     for version in (1, 2):
-      model = cifar10_main.Cifar10Model(32, data_format='channels_last',
-                                        num_classes=num_classes, version=version)
-      fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
+      model = cifar10_main.Cifar10Model(
+          32, data_format='channels_last', num_classes=num_classes,
+          version=version)
+      fake_input = tf.random_uniform(
+          [batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
       output = model(fake_input, training=True)
       self.assertAllEqual(output.shape, (batch_size, num_classes))
official/resnet/imagenet_main.py  (+17 -4)

@@ -21,7 +21,7 @@ from __future__ import print_function
 import os
 import sys

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.resnet import imagenet_preprocessing
 from official.resnet import resnet_model

@@ -157,6 +157,7 @@ def parse_record(raw_record, is_training):
 def input_fn(is_training, data_dir, batch_size, num_epochs=1,
              num_parallel_calls=1, multi_gpu=False):
   """Input function which provides batches for train or eval.
+
   Args:
     is_training: A boolean denoting whether the input is for training.
     data_dir: The directory containing the input data.

@@ -192,16 +193,17 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
 def get_synth_input_fn():
   return resnet_run_loop.get_synth_input_fn(
       _DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS, _NUM_CLASSES)


 ###############################################################################
 # Running the model
 ###############################################################################
 class ImagenetModel(resnet_model.Model):
+  """Model class with appropriate defaults for Imagenet data."""

   def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
                version=resnet_model.DEFAULT_VERSION):
     """These are the parameters that work for Imagenet data.

     Args:

@@ -241,9 +243,20 @@ class ImagenetModel(resnet_model.Model):

 def _get_block_sizes(resnet_size):
-  """The number of block layers used for the Resnet model varies according
+  """Retrieve the size of each block_layer in the ResNet model.
+
+  The number of block layers used for the Resnet model varies according
   to the size of the model. This helper grabs the layer set we want, throwing
   an error if a non-standard size has been selected.
+
+  Args:
+    resnet_size: The number of convolutional layers needed in the model.
+
+  Returns:
+    A list of block sizes to use in building the model.
+
+  Raises:
+    KeyError: if invalid resnet_size is received.
   """
   choices = {
       18: [2, 2, 2, 2],
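_get_block_sizes now documents that it raises KeyError because it is a plain dict lookup. A minimal sketch of the pattern; only the 18-layer entry is visible in the hunk above, and the 34- and 50-layer rows here are the standard configurations from the ResNet paper:

_BLOCK_SIZES = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
}

def get_block_sizes(resnet_size):
    # Direct lookup: an unsupported size raises KeyError, matching the
    # "Raises: KeyError" contract added to the docstring.
    return _BLOCK_SIZES[resnet_size]

assert get_block_sizes(50) == [3, 4, 6, 3]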
official/resnet/imagenet_preprocessing.py  (+5 -2)

@@ -204,8 +204,10 @@ def _aspect_preserving_resize(image, resize_min):

 def _resize_image(image, height, width):
-  """Simple wrapper around tf.resize_images to make sure we use the same
-  `method` and other details each time.
+  """Simple wrapper around tf.resize_images.
+
+  This is primarily to make sure we use the same `ResizeMethod` and other
+  details each time.

   Args:
     image: A 3-D image `Tensor`.

@@ -220,6 +222,7 @@ def _resize_image(image, height, width):
       image, [height, width], method=tf.image.ResizeMethod.BILINEAR,
       align_corners=False)


 def preprocess_image(image_buffer, bbox, output_height, output_width,
                      num_channels, is_training=False):
   """Preprocesses the given image.
official/resnet/imagenet_test.py  (+6 -7)

@@ -19,7 +19,7 @@ from __future__ import print_function
 import unittest

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.resnet import imagenet_main
 from official.utils.testing import integration

@@ -39,9 +39,8 @@ class BaseTest(tf.test.TestCase):
   def tensor_shapes_helper(self, resnet_size, version, with_gpu=False):
     """Checks the tensor shapes after each phase of the ResNet model."""

     def reshape(shape):
-      """Returns the expected dimensions depending on if a
-      GPU is being used.
-      """
+      """Returns the expected dimensions depending on if a GPU is being used."""

       # If a GPU is used for the test, the shape is returned (already in NCHW
       # form). When GPU is not used, the shape is converted to NHWC.
       if with_gpu:

@@ -240,8 +239,9 @@ class BaseTest(tf.test.TestCase):
     num_classes = 246
     for version in (1, 2):
-      model = imagenet_main.ImagenetModel(50, data_format='channels_last',
-                                          num_classes=num_classes, version=version)
+      model = imagenet_main.ImagenetModel(
+          50, data_format='channels_last', num_classes=num_classes,
+          version=version)
       fake_input = tf.random_uniform([batch_size, 224, 224, 3])
       output = model(fake_input, training=True)

@@ -285,4 +285,3 @@ class BaseTest(tf.test.TestCase):
 if __name__ == '__main__':
   tf.test.main()
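The reshape() helper whose docstring is condensed above converts expected shapes between NCHW (used on GPU) and NHWC; the conversion itself is just an axis permutation, sketched here as a hypothetical standalone function:

def to_nhwc(shape):
    # NCHW -> NHWC: move channels from axis 1 to the last axis.
    batch, channels, height, width = shape
    return (batch, height, width, channels)

assert to_nhwc((8, 3, 224, 224)) == (8, 224, 224, 3)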
official/resnet/resnet_model.py  (+46 -11)

@@ -99,7 +99,8 @@ def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format):
 ################################################################################
 def _building_block_v1(inputs, filters, training, projection_shortcut, strides,
                        data_format):
-  """
+  """A single block for ResNet v1, without a bottleneck.
+
   Convolution then batch normalization then ReLU as described by:
     Deep Residual Learning for Image Recognition
     https://arxiv.org/pdf/1512.03385.pdf

@@ -118,7 +119,7 @@ def _building_block_v1(inputs, filters, training, projection_shortcut, strides,
     data_format: The input format ('channels_last' or 'channels_first').

   Returns:
-    The output tensor of the block.
+    The output tensor of the block; shape should match inputs.
   """
   shortcut = inputs

@@ -145,7 +146,8 @@ def _building_block_v1(inputs, filters, training, projection_shortcut, strides,
 def _building_block_v2(inputs, filters, training, projection_shortcut, strides,
                        data_format):
-  """
+  """A single block for ResNet v2, without a bottleneck.
+
   Batch normalization then ReLu then convolution as described by:
     Identity Mappings in Deep Residual Networks
     https://arxiv.org/pdf/1603.05027.pdf

@@ -164,7 +166,7 @@ def _building_block_v2(inputs, filters, training, projection_shortcut, strides,
     data_format: The input format ('channels_last' or 'channels_first').

   Returns:
-    The output tensor of the block.
+    The output tensor of the block; shape should match inputs.
   """
   shortcut = inputs
   inputs = batch_norm(inputs, training, data_format)

@@ -190,13 +192,29 @@ def _building_block_v2(inputs, filters, training, projection_shortcut, strides,
 def _bottleneck_block_v1(inputs, filters, training, projection_shortcut,
                          strides, data_format):
-  """
+  """A single block for ResNet v1, with a bottleneck.
+
+  Similar to _building_block_v1(), except using the "bottleneck" blocks
+  described in:
   Convolution then batch normalization then ReLU as described by:
     Deep Residual Learning for Image Recognition
     https://arxiv.org/pdf/1512.03385.pdf
   by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.
+
+  Args:
+    inputs: A tensor of size [batch, channels, height_in, width_in] or
+      [batch, height_in, width_in, channels] depending on data_format.
+    filters: The number of filters for the convolutions.
+    training: A Boolean for whether the model is in training or inference
+      mode. Needed for batch normalization.
+    projection_shortcut: The function to use for projection shortcuts
+      (typically a 1x1 convolution when downsampling the input).
+    strides: The block's stride. If greater than 1, this block will ultimately
+      downsample the input.
+    data_format: The input format ('channels_last' or 'channels_first').
+
+  Returns:
+    The output tensor of the block; shape should match inputs.
   """
   shortcut = inputs

@@ -229,7 +247,8 @@ def _bottleneck_block_v1(inputs, filters, training, projection_shortcut,
 def _bottleneck_block_v2(inputs, filters, training, projection_shortcut,
                          strides, data_format):
-  """
+  """A single block for ResNet v2, without a bottleneck.
+
   Similar to _building_block_v2(), except using the "bottleneck" blocks
   described in:
   Convolution then batch normalization then ReLU as described by:

@@ -237,11 +256,26 @@ def _bottleneck_block_v2(inputs, filters, training, projection_shortcut,
     https://arxiv.org/pdf/1512.03385.pdf
   by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.

-  adapted to the ordering conventions of:
+  Adapted to the ordering conventions of:
   Batch normalization then ReLu then convolution as described by:
     Identity Mappings in Deep Residual Networks
     https://arxiv.org/pdf/1603.05027.pdf
   by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.
+
+  Args:
+    inputs: A tensor of size [batch, channels, height_in, width_in] or
+      [batch, height_in, width_in, channels] depending on data_format.
+    filters: The number of filters for the convolutions.
+    training: A Boolean for whether the model is in training or inference
+      mode. Needed for batch normalization.
+    projection_shortcut: The function to use for projection shortcuts
+      (typically a 1x1 convolution when downsampling the input).
+    strides: The block's stride. If greater than 1, this block will ultimately
+      downsample the input.
+    data_format: The input format ('channels_last' or 'channels_first').
+
+  Returns:
+    The output tensor of the block; shape should match inputs.
   """
   shortcut = inputs
   inputs = batch_norm(inputs, training, data_format)

@@ -313,8 +347,7 @@ def block_layer(inputs, filters, bottleneck, block_fn, blocks, strides,
 class Model(object):
-  """Base class for building the Resnet Model.
-  """
+  """Base class for building the Resnet Model."""

   def __init__(self, resnet_size, bottleneck, num_classes, num_filters,
                kernel_size,

@@ -348,6 +381,9 @@ class Model(object):
         See README for details. Valid values: [1, 2]
       data_format: Input format ('channels_last', 'channels_first', or None).
         If set to None, the format is dependent on whether a GPU is available.
+
+    Raises:
+      ValueError: if invalid version is selected.
     """
     self.resnet_size = resnet_size

@@ -358,7 +394,7 @@ class Model(object):
     self.resnet_version = version
     if version not in (1, 2):
       raise ValueError(
-          "Resnet version should be 1 or 2. See README for citations.")
+          'Resnet version should be 1 or 2. See README for citations.')

     self.bottleneck = bottleneck
     if bottleneck:

@@ -435,4 +471,3 @@ class Model(object):
     inputs = tf.layers.dense(inputs=inputs, units=self.num_classes)
     inputs = tf.identity(inputs, 'final_dense')
     return inputs
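The docstrings added above distinguish the v1 ordering (convolution, then batch normalization, then ReLU) from the v2 "pre-activation" ordering (batch normalization, then ReLU, then convolution). A simplified sketch of a v2 block using the same TF 1.x layers API as this file; unlike the real block, it assumes an identity shortcut, so the input must already have `filters` channels:

import tensorflow as tf  # TF 1.x layers API, as used in resnet_model.py

def building_block_v2_sketch(inputs, filters, training, data_format):
    # Pre-activation ordering from "Identity Mappings in Deep Residual
    # Networks": batch norm and ReLU come before the convolution, and the
    # shortcut is taken from the raw input.
    axis = 1 if data_format == 'channels_first' else 3
    shortcut = inputs
    inputs = tf.layers.batch_normalization(inputs, axis=axis, training=training)
    inputs = tf.nn.relu(inputs)
    inputs = tf.layers.conv2d(
        inputs, filters=filters, kernel_size=3, padding='same',
        use_bias=False, data_format=data_format)
    return inputs + shortcut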
official/resnet/resnet_run_loop.py  (+25 -19)

@@ -26,11 +26,11 @@ from __future__ import print_function
 import argparse
 import os

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

-from official.utils.arg_parsers import parsers  # pylint: disable=g-bad-import-order
-from official.utils.logging import hooks_helper
 from official.resnet import resnet_model
+from official.utils.arg_parsers import parsers
+from official.utils.logging import hooks_helper


 ################################################################################

@@ -39,8 +39,7 @@ from official.resnet import resnet_model
 def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
                            parse_record_fn, num_epochs=1, num_parallel_calls=1,
                            examples_per_epoch=0, multi_gpu=False):
-  """Given a Dataset with raw records, parse each record into images and labels,
-  and return an iterator over the records.
+  """Given a Dataset with raw records, return an iterator over the records.

   Args:
     dataset: A Dataset representing raw records

@@ -121,7 +120,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
     An input_fn that can be used in place of a real one to return a dataset
     that can be used for iteration.
   """
-  def input_fn(is_training, data_dir, batch_size, *args):
+  def input_fn(is_training, data_dir, batch_size, *args):  # pylint: disable=unused-argument
     images = tf.zeros((batch_size, height, width, num_channels), tf.float32)
     labels = tf.zeros((batch_size, num_classes), tf.int32)
     return tf.data.Dataset.from_tensors((images, labels)).repeat()

@@ -231,9 +230,9 @@ def resnet_model_fn(features, labels, mode, model_class,
   # If no loss_filter_fn is passed, assume we want the default behavior,
   # which is that batch_normalization variables are excluded from loss.
-  if not loss_filter_fn:
-    def loss_filter_fn(name):
-      return 'batch_normalization' not in name
+  def exclude_batch_norm(name):
+    return 'batch_normalization' not in name
+  loss_filter_fn = loss_filter_fn or exclude_batch_norm

   # Add weight decay to the loss.
   loss = cross_entropy + weight_decay * tf.add_n(

@@ -279,31 +278,38 @@ def resnet_model_fn(features, labels, mode, model_class,

 def validate_batch_size_for_multi_gpu(batch_size):
-  """For multi-gpu, batch-size must be a multiple of the number of
-  available GPUs.
+  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

   Note that this should eventually be handled by replicate_model_fn
   directly. Multi-GPU support is currently experimental, however,
   so doing the work here until that feature is in place.
+
+  Args:
+    batch_size: the number of examples processed in each training batch.
+
+  Raises:
+    ValueError: if no GPUs are found, or selected batch_size is invalid.
   """
-  from tensorflow.python.client import device_lib
+  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top

   local_device_protos = device_lib.list_local_devices()
   num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
   if not num_gpus:
     raise ValueError('Multi-GPU mode was specified, but no GPUs '
                      'were found. To use CPU, run without --multi_gpu.')

   remainder = batch_size % num_gpus
   if remainder:
     err = ('When running with multiple GPUs, batch size '
            'must be a multiple of the number of available GPUs. '
            'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
           ).format(num_gpus, batch_size, batch_size - remainder)
     raise ValueError(err)


 def resnet_main(flags, model_function, input_function):
   """Shared main loop for ResNet Models."""

   # Using the Winograd non-fused algorithms provides a small performance boost.
   os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

@@ -340,8 +346,8 @@ def resnet_main(flags, model_function, input_function):
     })

   for _ in range(flags.train_epochs // flags.epochs_between_evals):
     train_hooks = hooks_helper.get_train_hooks(
         flags.hooks, batch_size=flags.batch_size)

     print('Starting a training cycle.')

@@ -384,7 +390,7 @@ class ResnetArgParser(argparse.ArgumentParser):
     self.add_argument(
         '--version', '-v', type=int, choices=[1, 2],
         default=resnet_model.DEFAULT_VERSION,
-        help="Version of ResNet. (1 or 2) See README.md for details."
+        help='Version of ResNet. (1 or 2) See README.md for details.'
     )

     self.add_argument(
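The loss-filter hunk above replaces a conditional redefinition with the "or" default idiom; a standalone sketch:

def make_loss_filter(loss_filter_fn=None):
    # Fall back to excluding batch-normalization variables from weight decay
    # when the caller supplies no filter, as in resnet_model_fn above.
    def exclude_batch_norm(name):
        return 'batch_normalization' not in name
    return loss_filter_fn or exclude_batch_norm

loss_filter_fn = make_loss_filter()
assert loss_filter_fn('conv2d/kernel')
assert not loss_filter_fn('batch_normalization/gamma')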
official/utils/__init__.py  (+0 -14)

As with official/__init__.py above, the diff deletes the file's entire content, its 14-line Apache 2.0 license header, leaving an empty __init__.py:

-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
\ No newline at end of file
official/utils/arg_parsers/parsers.py  (+3 -3)

@@ -46,11 +46,11 @@ Notes about add_argument():
   The metavar variable determines how the flag will appear in help text. If
   not specified, the convention is to use name.upper(). Thus rather than:

-    --application_specific_arg APPLICATION_SPECIFIC_ARG, -asa APPLICATION_SPECIFIC_ARG
+    --app_specific_arg APP_SPECIFIC_ARG, -asa APP_SPECIFIC_ARG

   if metavar="<ASA>" is set, the user sees:

-    --application_specific_arg <ASA>, -asa <ASA>
+    --app_specific_arg <ASA>, -asa <ASA>
 """

@@ -216,7 +216,7 @@ class ImageModelParser(argparse.ArgumentParser):
     self.add_argument(
         "--data_format", "-df",
         default=None,
-        choices=['channels_first', 'channels_last'],
+        choices=["channels_first", "channels_last"],
         help="A flag to override the data format used in the model. "
              "channels_first provides a performance boost on GPU but is not "
             "always compatible with CPU. If left unspecified, the data "
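The metavar note corrected above is easy to see with a toy parser (hypothetical flags):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--app_specific_arg', '-asa')          # default metavar
parser.add_argument('--other_arg', '-oa', metavar='<OA>')  # custom metavar
parser.print_help()
# Without metavar: --app_specific_arg APP_SPECIFIC_ARG, -asa APP_SPECIFIC_ARG
# With metavar:    --other_arg <OA>, -oa <OA>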
official/utils/logging/hooks_helper.py  (+5 -6)

@@ -24,7 +24,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.utils.logging import hooks

@@ -40,7 +40,7 @@ def get_train_hooks(name_list, **kwargs):
     name_list: a list of strings to name desired hook classes. Allowed:
       LoggingTensorHook, ProfilerHook, ExamplesPerSecondHook, which are defined
       as keys in HOOKS
-    kwargs: a dictionary of arguments to the hooks.
+    **kwargs: a dictionary of arguments to the hooks.

   Returns:
     list of instantiated hooks, ready to be used in a classifier.train call.

@@ -71,7 +71,7 @@ def get_logging_tensor_hook(every_n_iter=100, tensors_to_log=None, **kwargs):  #
       steps taken on the current worker.
     tensors_to_log: List of tensor names or dictionary mapping labels to tensor
       names. If not set, log _TENSORS_TO_LOG by default.
-    kwargs: a dictionary of arguments to LoggingTensorHook.
+    **kwargs: a dictionary of arguments to LoggingTensorHook.

   Returns:
     Returns a LoggingTensorHook with a standard set of tensors that will be

@@ -90,7 +90,7 @@ def get_profiler_hook(save_steps=1000, **kwargs):  # pylint: disable=unused-argu
   Args:
     save_steps: `int`, print profile traces every N steps.
-    kwargs: a dictionary of arguments to ProfilerHook.
+    **kwargs: a dictionary of arguments to ProfilerHook.

   Returns:
     Returns a ProfilerHook that writes out timelines that can be loaded into

@@ -111,7 +111,7 @@ def get_examples_per_second_hook(every_n_steps=100,
     batch_size: `int`, total batch size used to calculate examples/second from
       global time.
     warm_steps: skip this number of steps before logging and running average.
-    kwargs: a dictionary of arguments to ExamplesPerSecondHook.
+    **kwargs: a dictionary of arguments to ExamplesPerSecondHook.

   Returns:
     Returns a ProfilerHook that writes out timelines that can be loaded into

@@ -128,4 +128,3 @@ HOOKS = {
     'profilerhook': get_profiler_hook,
     'examplespersecondhook': get_examples_per_second_hook,
 }
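The kwargs to **kwargs docstring fixes above reflect how the HOOKS table dispatches keyword arguments to each hook factory; a minimal sketch with a hypothetical stand-in factory:

def make_logging_hook(every_n_iter=100, **kwargs):  # pylint: disable=unused-argument
    # Stand-in; the real factory returns a tf.train.LoggingTensorHook.
    return ('LoggingTensorHook', every_n_iter)

HOOKS = {'loggingtensorhook': make_logging_hook}

def get_train_hooks(name_list, **kwargs):
    # Normalize each requested name and dispatch through HOOKS; extra
    # **kwargs flow through to every factory.
    return [HOOKS[name.strip().lower()](**kwargs) for name in name_list]

print(get_train_hooks(['LoggingTensorHook'], every_n_iter=50))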
official/utils/logging/hooks_helper_test.py  (+1 -1)

@@ -21,7 +21,7 @@ from __future__ import print_function
 import unittest

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.utils.logging import hooks_helper
official/utils/logging/hooks_test.py  (+5 -4)

@@ -21,9 +21,9 @@ from __future__ import print_function
 import time

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

-from tensorflow.python.training import monitored_session  # pylint: disable=g-bad-import-order
+from tensorflow.python.training import monitored_session

 from official.utils.logging import hooks

@@ -31,6 +31,7 @@ tf.logging.set_verbosity(tf.logging.ERROR)

 class ExamplesPerSecondHookTest(tf.test.TestCase):
+  """Tests for the ExamplesPerSecondHook."""

   def setUp(self):
     """Mock out logging calls to verify if correct info is being monitored."""

@@ -71,7 +72,7 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
         every_n_steps=every_n_steps, warm_steps=warm_steps)
     hook.begin()

-    mon_sess = monitored_session._HookedSession(sess, [hook])
+    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
     sess.run(tf.global_variables_initializer())
     self.logged_message = ''

@@ -120,7 +121,7 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
         every_n_steps=None, every_n_secs=every_n_secs)
     hook.begin()

-    mon_sess = monitored_session._HookedSession(sess, [hook])
+    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
     sess.run(tf.global_variables_initializer())
     self.logged_message = ''