ModelZoo / ResNet50_tensorflow · Commits

Commit a6758929 (unverified)
Authored Feb 02, 2018 by Karmel Allison; committed by GitHub on Feb 02, 2018

    Merge Resnet files (#3301)

Parent: 6c874e17
Showing 4 changed files with 255 additions and 272 deletions:

  official/resnet/cifar10_main.py     +13   −15
  official/resnet/imagenet_main.py    +14   −15
  official/resnet/resnet.py           +228  −1
  official/resnet/resnet_shared.py    +0    −241
official/resnet/cifar10_main.py

```diff
@@ -23,8 +23,7 @@ import sys

 import tensorflow as tf

-import resnet_model
-import resnet_shared
+import resnet

 _HEIGHT = 32
 _WIDTH = 32
@@ -152,8 +151,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
 ###############################################################################
 # Running the model
 ###############################################################################
-class Cifar10Model(resnet_model.Model):
+class Cifar10Model(resnet.Model):

   def __init__(self, resnet_size, data_format=None):
     """These are the parameters that work for CIFAR-10 data.
     """
@@ -172,7 +170,7 @@ class Cifar10Model(resnet_model.Model):
         first_pool_stride=None,
         second_pool_size=8,
         second_pool_stride=1,
-        block_fn=resnet_model.building_block,
+        block_fn=resnet.building_block,
         block_sizes=[num_blocks] * 3,
         block_strides=[1, 2, 2],
         final_size=64,
@@ -183,7 +181,7 @@ def cifar10_model_fn(features, labels, mode, params):
   """Model function for CIFAR-10."""
   features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])

-  learning_rate_fn = resnet_shared.learning_rate_with_decay(
+  learning_rate_fn = resnet.learning_rate_with_decay(
      batch_size=params['batch_size'], batch_denom=128,
      num_images=_NUM_IMAGES['train'], boundary_epochs=[100, 150, 200],
      decay_rates=[1, 0.1, 0.01, 0.001])
@@ -200,23 +198,23 @@ def cifar10_model_fn(features, labels, mode, params):
   def loss_filter_fn(name):
     return True

-  return resnet_shared.resnet_model_fn(features, labels, mode, Cifar10Model,
+  return resnet.resnet_model_fn(features, labels, mode, Cifar10Model,
                                 resnet_size=params['resnet_size'],
                                 weight_decay=weight_decay,
                                 learning_rate_fn=learning_rate_fn,
                                 momentum=0.9,
                                 data_format=params['data_format'],
                                 loss_filter_fn=loss_filter_fn)


 def main(unused_argv):
-  resnet_shared.resnet_main(FLAGS, cifar10_model_fn, input_fn)
+  resnet.resnet_main(FLAGS, cifar10_model_fn, input_fn)


 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)

-  parser = resnet_shared.ResnetArgParser()
+  parser = resnet.ResnetArgParser()

   # Set defaults that are reasonable for this model.
   parser.set_defaults(data_dir='/tmp/cifar10_data',
                       model_dir='/tmp/cifar10_model',
                       ...
```
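For orientation, the schedule these arguments produce can be worked out by hand from `learning_rate_with_decay` (now defined in resnet.py, below). A minimal pure-Python sketch; the batch size of 128 and the 50,000-image training split are illustrative assumptions, not values fixed by this diff:

```python
# Sketch of the schedule cifar10_model_fn requests from
# resnet.learning_rate_with_decay. Assumed inputs for illustration:
batch_size, batch_denom, num_images = 128, 128, 50000
boundary_epochs = [100, 150, 200]
decay_rates = [1, 0.1, 0.01, 0.001]

# Mirrors the arithmetic inside learning_rate_with_decay:
initial_lr = 0.1 * batch_size / batch_denom   # 0.1 when batch_size == batch_denom
batches_per_epoch = num_images / batch_size   # 390.625 steps per epoch
boundaries = [int(batches_per_epoch * e) for e in boundary_epochs]
vals = [initial_lr * d for d in decay_rates]

print(boundaries)  # [39062, 58593, 78125]
print(vals)        # approximately [0.1, 0.01, 0.001, 0.0001]
```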
official/resnet/imagenet_main.py

```diff
@@ -23,8 +23,7 @@ import sys

 import tensorflow as tf

-import resnet_model
-import resnet_shared
+import resnet
 import vgg_preprocessing

 _DEFAULT_IMAGE_SIZE = 224
@@ -129,17 +128,17 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
 ###############################################################################
 # Running the model
 ###############################################################################
-class ImagenetModel(resnet_model.Model):
+class ImagenetModel(resnet.Model):

   def __init__(self, resnet_size, data_format=None):
     """These are the parameters that work for Imagenet data.
     """
     # For bigger models, we want to use "bottleneck" layers
     if resnet_size < 50:
-      block_fn = resnet_model.building_block
+      block_fn = resnet.building_block
       final_size = 512
     else:
-      block_fn = resnet_model.bottleneck_block
+      block_fn = resnet.bottleneck_block
       final_size = 2048

     super(ImagenetModel, self).__init__(
@@ -184,28 +183,28 @@ def _get_block_sizes(resnet_size):
 def imagenet_model_fn(features, labels, mode, params):
   """Our model_fn for ResNet to be used with our Estimator."""
-  learning_rate_fn = resnet_shared.learning_rate_with_decay(
+  learning_rate_fn = resnet.learning_rate_with_decay(
      batch_size=params['batch_size'], batch_denom=256,
      num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])

-  return resnet_shared.resnet_model_fn(features, labels, mode, ImagenetModel,
+  return resnet.resnet_model_fn(features, labels, mode, ImagenetModel,
                                 resnet_size=params['resnet_size'],
                                 weight_decay=1e-4,
                                 learning_rate_fn=learning_rate_fn,
                                 momentum=0.9,
                                 data_format=params['data_format'],
                                 loss_filter_fn=None)


 def main(unused_argv):
-  resnet_shared.resnet_main(FLAGS, imagenet_model_fn, input_fn)
+  resnet.resnet_main(FLAGS, imagenet_model_fn, input_fn)


 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)

-  parser = resnet_shared.ResnetArgParser(
+  parser = resnet.ResnetArgParser(
      resnet_size_choices=[18, 34, 50, 101, 152, 200])

   FLAGS, unparsed = parser.parse_known_args()
   tf.app.run(argv=[sys.argv[0]] + unparsed)
```
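The `resnet_size < 50` branch above decides both the block function and the width of the final pooled features. A standalone restatement in plain Python, for illustration only (string tags stand in for the real `resnet.building_block` / `resnet.bottleneck_block` graph functions):

```python
def block_config(resnet_size):
    # Mirrors ImagenetModel.__init__: ResNet-18/34 use the plain building
    # block with 512 final channels; ResNet-50 and deeper use bottleneck
    # blocks, whose final 1x1 expansion yields 2048 channels.
    if resnet_size < 50:
        return 'building_block', 512
    return 'bottleneck_block', 2048

assert block_config(34) == ('building_block', 512)
assert block_config(152) == ('bottleneck_block', 2048)
```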
official/resnet/resnet_model.py → official/resnet/resnet.py (renamed)
```diff
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Contains definitions for the preactivation form of Residual Networks.
+"""Contains definitions for the preactivation form of Residual Networks
+(also known as ResNet v2).

 Residual networks (ResNets) were originally proposed in:
 [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
@@ -32,12 +33,18 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import argparse
+import os
+
 import tensorflow as tf

 _BATCH_NORM_DECAY = 0.997
 _BATCH_NORM_EPSILON = 1e-5

+################################################################################
+# Functions building the ResNet model.
+################################################################################
 def batch_norm_relu(inputs, training, data_format):
   """Performs a batch normalization followed by a ReLU."""
   # We set fused=True for a significant performance boost. See
   ...
```
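The hunk cuts off before the body of `batch_norm_relu`. As a sketch only — the exact arguments are an assumption consistent with the constants above and the `fused=True` comment, not part of this diff:

```python
def batch_norm_relu(inputs, training, data_format):
  """Performs a batch normalization followed by a ReLU (sketch)."""
  # fused=True selects the fused batch-norm kernel; the comment in the
  # diff notes this gives a significant performance boost.
  inputs = tf.layers.batch_normalization(
      inputs=inputs,
      axis=1 if data_format == 'channels_first' else 3,
      momentum=_BATCH_NORM_DECAY,
      epsilon=_BATCH_NORM_EPSILON,
      center=True,
      scale=True,
      training=training,
      fused=True)
  return tf.nn.relu(inputs)
```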
```diff
@@ -318,3 +325,223 @@ class Model(object):
     inputs = tf.layers.dense(inputs=inputs, units=self.num_classes)
     inputs = tf.identity(inputs, 'final_dense')
     return inputs
+
+
+################################################################################
+# Functions for running training/eval/validation loops for the model.
+################################################################################
+def learning_rate_with_decay(
+    batch_size, batch_denom, num_images, boundary_epochs, decay_rates):
+  """Get a learning rate that decays step-wise as training progresses.
+
+  Args:
+    batch_size: the number of examples processed in each training batch.
+    batch_denom: this value will be used to scale the base learning rate.
+      `0.1 * batch size` is divided by this number, such that when
+      batch_denom == batch_size, the initial learning rate will be 0.1.
+    num_images: total number of images that will be used for training.
+    boundary_epochs: list of ints representing the epochs at which we
+      decay the learning rate.
+    decay_rates: list of floats representing the decay rates to be used
+      for scaling the learning rate. Should be the same length as
+      boundary_epochs.
+
+  Returns:
+    Returns a function that takes a single argument - the number of batches
+    trained so far (global_step)- and returns the learning rate to be used
+    for training the next batch.
+  """
+  initial_learning_rate = 0.1 * batch_size / batch_denom
+  batches_per_epoch = num_images / batch_size
+
+  # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
+  boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
+  vals = [initial_learning_rate * decay for decay in decay_rates]
+
+  def learning_rate_fn(global_step):
+    global_step = tf.cast(global_step, tf.int32)
+    return tf.train.piecewise_constant(global_step, boundaries, vals)
+
+  return learning_rate_fn
```
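The returned `learning_rate_fn` is a step-indexed lookup. `tf.train.piecewise_constant` behaves like this pure-Python equivalent (an illustrative re-implementation, not the TensorFlow source):

```python
def piecewise_constant(step, boundaries, vals):
    # vals[0] while step <= boundaries[0]; vals[i+1] between boundaries[i]
    # (exclusive) and boundaries[i+1] (inclusive); vals[-1] afterwards.
    for boundary, val in zip(boundaries, vals):
        if step <= boundary:
            return val
    return vals[-1]

boundaries, vals = [39062, 58593, 78125], [0.1, 0.01, 0.001, 0.0001]
assert piecewise_constant(0, boundaries, vals) == 0.1
assert piecewise_constant(60000, boundaries, vals) == 0.001
assert piecewise_constant(90000, boundaries, vals) == 0.0001
```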
The hunk continues with the shared model function:

```diff
+def resnet_model_fn(features, labels, mode, model_class,
+                    resnet_size, weight_decay, learning_rate_fn, momentum,
+                    data_format, loss_filter_fn=None):
+  """Shared functionality for different resnet model_fns.
+
+  Initializes the ResnetModel representing the model layers
+  and uses that model to build the necessary EstimatorSpecs for
+  the `mode` in question. For training, this means building losses,
+  the optimizer, and the train op that get passed into the EstimatorSpec.
+  For evaluation and prediction, the EstimatorSpec is returned without
+  a train op, but with the necessary parameters for the given mode.
+
+  Args:
+    features: tensor representing input images
+    labels: tensor representing class labels for all input images
+    mode: current estimator mode; should be one of
+      `tf.estimator.ModeKeys.TRAIN`, `EVALUATE`, `PREDICT`
+    model_class: a class representing a TensorFlow model that has a __call__
+      function. We assume here that this is a subclass of ResnetModel.
+    resnet_size: A single integer for the size of the ResNet model.
+    weight_decay: weight decay loss rate used to regularize learned variables.
+    learning_rate_fn: function that returns the current learning rate given
+      the current global_step
+    momentum: momentum term used for optimization
+    data_format: Input format ('channels_last', 'channels_first', or None).
+      If set to None, the format is dependent on whether a GPU is available.
+    loss_filter_fn: function that takes a string variable name and returns
+      True if the var should be included in loss calculation, and False
+      otherwise. If None, batch_normalization variables will be excluded
+      from the loss.
+
+  Returns:
+    EstimatorSpec parameterized according to the input params and the
+    current mode.
+  """
+  # Generate a summary node for the images
+  tf.summary.image('images', features, max_outputs=6)
+
+  model = model_class(resnet_size, data_format)
+  logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)
+
+  predictions = {
+      'classes': tf.argmax(logits, axis=1),
+      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
+  }
+
+  if mode == tf.estimator.ModeKeys.PREDICT:
+    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
+
+  # Calculate loss, which includes softmax cross entropy and L2 regularization.
+  cross_entropy = tf.losses.softmax_cross_entropy(
+      logits=logits, onehot_labels=labels)
+
+  # Create a tensor named cross_entropy for logging purposes.
+  tf.identity(cross_entropy, name='cross_entropy')
+  tf.summary.scalar('cross_entropy', cross_entropy)
+
+  # If no loss_filter_fn is passed, assume we want the default behavior,
+  # which is that batch_normalization variables are excluded from loss.
+  if not loss_filter_fn:
+    def loss_filter_fn(name):
+      return 'batch_normalization' not in name
+
+  # Add weight decay to the loss.
+  loss = cross_entropy + weight_decay * tf.add_n(
+      [tf.nn.l2_loss(v) for v in tf.trainable_variables()
+       if loss_filter_fn(v.name)])
+
+  if mode == tf.estimator.ModeKeys.TRAIN:
+    global_step = tf.train.get_or_create_global_step()
+
+    learning_rate = learning_rate_fn(global_step)
+
+    # Create a tensor named learning_rate for logging purposes
+    tf.identity(learning_rate, name='learning_rate')
+    tf.summary.scalar('learning_rate', learning_rate)
+
+    optimizer = tf.train.MomentumOptimizer(
+        learning_rate=learning_rate,
+        momentum=momentum)
+
+    # Batch norm requires update ops to be added as a dependency to train_op
+    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+    with tf.control_dependencies(update_ops):
+      train_op = optimizer.minimize(loss, global_step)
+  else:
+    train_op = None
+
+  accuracy = tf.metrics.accuracy(
+      tf.argmax(labels, axis=1), predictions['classes'])
+  metrics = {'accuracy': accuracy}
+
+  # Create a tensor named train_accuracy for logging purposes
+  tf.identity(accuracy[1], name='train_accuracy')
+  tf.summary.scalar('train_accuracy', accuracy[1])
+
+  return tf.estimator.EstimatorSpec(
+      mode=mode,
+      predictions=predictions,
+      loss=loss,
+      train_op=train_op,
+      eval_metric_ops=metrics)
```
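Note the interaction with the callers above: imagenet_main.py passes `loss_filter_fn=None` and so gets this default, which keeps batch-norm parameters out of the L2 term, while cifar10_main.py passes a filter that always returns True, decaying every trainable variable. A small illustration (the variable names are hypothetical):

```python
def default_loss_filter_fn(name):
    # The default installed by resnet_model_fn when loss_filter_fn is None.
    return 'batch_normalization' not in name

# Hypothetical names of the kind tf.trainable_variables() yields.
names = [
    'conv2d/kernel:0',
    'batch_normalization/gamma:0',
    'batch_normalization/beta:0',
    'dense/kernel:0',
]
print([n for n in names if default_loss_filter_fn(n)])
# ['conv2d/kernel:0', 'dense/kernel:0']
```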
Next the hunk adds the shared train/evaluate driver:

```diff
+def resnet_main(flags, model_function, input_function):
+  # Using the Winograd non-fused algorithms provides a small performance boost.
+  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
+
+  # Set up a RunConfig to only save checkpoints once per training cycle.
+  run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9)
+  classifier = tf.estimator.Estimator(
+      model_fn=model_function, model_dir=flags.model_dir, config=run_config,
+      params={
+          'resnet_size': flags.resnet_size,
+          'data_format': flags.data_format,
+          'batch_size': flags.batch_size,
+      })
+
+  for _ in range(flags.train_epochs // flags.epochs_per_eval):
+    tensors_to_log = {
+        'learning_rate': 'learning_rate',
+        'cross_entropy': 'cross_entropy',
+        'train_accuracy': 'train_accuracy'
+    }
+
+    logging_hook = tf.train.LoggingTensorHook(
+        tensors=tensors_to_log, every_n_iter=100)
+
+    print('Starting a training cycle.')
+    classifier.train(
+        input_fn=lambda: input_function(
+            True, flags.data_dir, flags.batch_size, flags.epochs_per_eval),
+        hooks=[logging_hook])
+
+    print('Starting to evaluate.')
+    # Evaluate the model and print results
+    eval_results = classifier.evaluate(input_fn=lambda: input_function(
+        False, flags.data_dir, flags.batch_size))
+    print(eval_results)
```
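The loop runs `train_epochs // epochs_per_eval` train/evaluate cycles, each `train()` call consuming `epochs_per_eval` epochs of data through the `input_function` lambda. With the parser defaults below, that works out as:

```python
train_epochs, epochs_per_eval = 100, 1   # ResnetArgParser defaults (below)
cycles = train_epochs // epochs_per_eval
print(cycles)                            # 100 train/evaluate cycles
print(cycles * epochs_per_eval)          # 100 epochs trained in total
```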
Finally, the hunk adds the shared argument parser:

```diff
+class ResnetArgParser(argparse.ArgumentParser):
+  """Arguments for configuring and running a Resnet Model.
+  """
+
+  def __init__(self, resnet_size_choices=None):
+    super(ResnetArgParser, self).__init__()
+    self.add_argument(
+        '--data_dir', type=str, default='/tmp/resnet_data',
+        help='The directory where the input data is stored.')
+
+    self.add_argument(
+        '--model_dir', type=str, default='/tmp/resnet_model',
+        help='The directory where the model will be stored.')
+
+    self.add_argument(
+        '--resnet_size', type=int, default=50,
+        choices=resnet_size_choices,
+        help='The size of the ResNet model to use.')
+
+    self.add_argument(
+        '--train_epochs', type=int, default=100,
+        help='The number of epochs to use for training.')
+
+    self.add_argument(
+        '--epochs_per_eval', type=int, default=1,
+        help='The number of training epochs to run between evaluations.')
+
+    self.add_argument(
+        '--batch_size', type=int, default=32,
+        help='Batch size for training and evaluation.')
+
+    self.add_argument(
+        '--data_format', type=str, default=None,
+        choices=['channels_first', 'channels_last'],
+        help='A flag to override the data format used in the model. '
+             'channels_first provides a performance boost on GPU but '
+             'is not always compatible with CPU. If left unspecified, '
+             'the data format will be chosen automatically based on '
+             'whether TensorFlow was built for CPU or GPU.')
```
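Both entry points above follow the same pattern: construct the parser, optionally override defaults, then forward anything unrecognized to `tf.app.run`. A minimal usage sketch, assuming resnet.py is importable (the argument values are illustrative):

```python
import resnet

parser = resnet.ResnetArgParser(resnet_size_choices=[18, 34, 50, 101, 152, 200])
parser.set_defaults(data_dir='/tmp/cifar10_data')  # per-model override
FLAGS, unparsed = parser.parse_known_args(['--batch_size', '64', '--foo'])

print(FLAGS.batch_size)  # 64
print(FLAGS.data_dir)    # /tmp/cifar10_data
print(unparsed)          # ['--foo'], left for tf.app.run to receive
```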
official/resnet/resnet_shared.py (deleted, 100644 → 0; last present at parent 6c874e17)
The file's 241 lines are removed in their entirety: the Apache 2.0 license header, the module docstring """Functions for running Resnet that are shared across datasets.""", its imports (`__future__`, `argparse`, `os`, `tensorflow`), and the definitions of learning_rate_with_decay, resnet_model_fn, resnet_main, and ResnetArgParser. These definitions are identical to the code added to resnet.py above, so the listing is not repeated here.