Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a5c4fd06
Commit
a5c4fd06
authored
Aug 26, 2016
by
nathansilberman
Committed by
Martin Wicke
Aug 26, 2016
Browse files
Initial tf-slim checkin (#349)
parent
a5594334
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
1045 additions
and
0 deletions
+1045
-0
slim/models/vgg_preprocessing.py
slim/models/vgg_preprocessing.py
+370
-0
slim/nets/lenet.py
slim/nets/lenet.py
+92
-0
slim/scripts/train_lenet_on_mnist.sh
slim/scripts/train_lenet_on_mnist.sh
+43
-0
slim/train.py
slim/train.py
+540
-0
No files found.
slim/models/vgg_preprocessing.py
0 → 100644
View file @
a5c4fd06
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides utilities to preprocess images.
The preprocessing steps for VGG were introduced in the following technical
report:
Very Deep Convolutional Networks For Large-Scale Image Recognition
Karen Simonyan and Andrew Zisserman
arXiv technical report, 2015
PDF: http://arxiv.org/pdf/1409.1556.pdf
ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf
CC-BY-4.0
More information can be obtained from the VGG website:
www.robots.ox.ac.uk/~vgg/research/very_deep/
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
tensorflow.python.ops
import
control_flow_ops
slim = tf.contrib.slim

# Per-channel values subtracted from the R, G and B channels by
# `preprocess_for_train` and `preprocess_for_eval`.
# NOTE(review): presumably the ImageNet channel means used by the VGG
# authors -- confirm against the reference in the module docstring.
_R_MEAN = 123.68
_G_MEAN = 116.78
_B_MEAN = 103.94

# Default lower/upper bounds for the smallest image side used by the
# aspect-preserving resize (`preprocess_for_train`, `preprocess_image`).
_RESIZE_SIDE_MIN = 256
_RESIZE_SIDE_MAX = 512
def _crop(image, offset_height, offset_width, crop_height, crop_width):
  """Extracts a `crop_height` x `crop_width` window from `image`.

  The static size of `image` does not need to be known, but the image is
  required (via a runtime assertion) to be of rank 3.

  Args:
    image: an image of shape [height, width, channels].
    offset_height: a scalar tensor giving the row at which the crop starts.
    offset_width: a scalar tensor giving the column at which the crop starts.
    crop_height: the height of the cropped image.
    crop_width: the width of the cropped image.

  Returns:
    The cropped image, reshaped to [crop_height, crop_width, channels].

  Raises:
    InvalidArgumentError: (when the graph is run) if the rank is not 3 or if
      the image dimensions are less than the crop size.
  """
  input_shape = tf.shape(image)

  # Runtime check: the input must be a rank-3 tensor.
  is_rank_3 = tf.equal(tf.rank(image), 3)
  check_rank = tf.Assert(is_rank_3, ['Rank of image must be equal to 3.'])

  # The output keeps the channel count of the input. Tying the shape to the
  # rank assertion makes sure the assertion runs before the shape is used.
  target_shape = tf.pack([crop_height, crop_width, input_shape[2]])
  target_shape = control_flow_ops.with_dependencies([check_rank], target_shape)

  # Runtime check: the crop window must not be larger than the image.
  tall_enough = tf.greater_equal(input_shape[0], crop_height)
  wide_enough = tf.greater_equal(input_shape[1], crop_width)
  check_size = tf.Assert(
      tf.logical_and(tall_enough, wide_enough),
      ['Crop size greater than the image size.'])

  begin = tf.to_int32(tf.pack([offset_height, offset_width, 0]))

  # tf.slice is used rather than crop_to_bounding_box because it accepts
  # tensors for the crop size.
  window = tf.slice(image, begin, target_shape)
  window = control_flow_ops.with_dependencies([check_size], window)
  return tf.reshape(window, target_shape)
def _random_crop(image_list, crop_height, crop_width):
  """Applies one random crop, identically, to every image in a list.

  All images receive the same crop window, which is useful when several
  aligned inputs of the same spatial size must stay in register, e.g.:

    image, depths, normals = _random_crop([image, depths, normals], 120, 150)

  Args:
    image_list: a list of image tensors of the same height and width but
      possibly a varying number of channels.
    crop_height: the new height.
    crop_width: the new width.

  Returns:
    A list with the cropped images, in the same order as `image_list`.

  Raises:
    ValueError: if `image_list` is empty.
    InvalidArgumentError: (when the graph is run) if an image is not of rank
      3, if the images differ in height or width, or if the images are
      smaller than the crop dimensions.
  """
  if not image_list:
    raise ValueError('Empty image_list.')

  # Compute the rank assertions.
  rank_assertions = []
  for image in image_list:
    image_rank = tf.rank(image)
    rank_assertions.append(tf.Assert(
        tf.equal(image_rank, 3),
        ['Wrong rank for tensor %s [expected] [actual]',
         image.name, 3, image_rank]))

  # The first image defines the reference height and width.
  image_shape = control_flow_ops.with_dependencies(
      [rank_assertions[0]],
      tf.shape(image_list[0]))
  image_height = image_shape[0]
  image_width = image_shape[1]
  crop_size_assert = tf.Assert(
      tf.logical_and(
          tf.greater_equal(image_height, crop_height),
          tf.greater_equal(image_width, crop_width)),
      ['Crop size greater than the image size.'])

  asserts = [rank_assertions[0], crop_size_assert]

  # Every other image must match the reference height and width.
  for image, rank_assertion in zip(image_list[1:], rank_assertions[1:]):
    asserts.append(rank_assertion)
    shape = control_flow_ops.with_dependencies(
        [rank_assertion], tf.shape(image))
    height = shape[0]
    width = shape[1]

    # The data is listed as (name, expected, actual) to match the message.
    height_assert = tf.Assert(
        tf.equal(height, image_height),
        ['Wrong height for tensor %s [expected][actual]',
         image.name, image_height, height])
    width_assert = tf.Assert(
        tf.equal(width, image_width),
        ['Wrong width for tensor %s [expected][actual]',
         image.name, image_width, width])
    asserts.extend([height_assert, width_assert])

  # Create a random bounding box.
  #
  # Use tf.random_uniform and not numpy.random.rand as doing the former would
  # generate random numbers at graph eval time, unlike the latter which
  # generates random numbers at graph definition time.
  max_offset_height = control_flow_ops.with_dependencies(
      asserts, tf.reshape(image_height - crop_height + 1, []))
  max_offset_width = control_flow_ops.with_dependencies(
      asserts, tf.reshape(image_width - crop_width + 1, []))
  offset_height = tf.random_uniform(
      [], maxval=max_offset_height, dtype=tf.int32)
  offset_width = tf.random_uniform(
      [], maxval=max_offset_width, dtype=tf.int32)

  return [_crop(image, offset_height, offset_width,
                crop_height, crop_width) for image in image_list]
def _central_crop(image_list, crop_height, crop_width):
  """Crops the central `crop_height` x `crop_width` region of each image.

  Args:
    image_list: a list of image tensors of the same dimension but possibly
      varying channel.
    crop_height: the height of the image following the crop.
    crop_width: the width of the image following the crop.

  Returns:
    The list of cropped images, in the same order as `image_list`.
  """

  def _crop_center(img):
    """Crops a single image around its centre."""
    img_height = tf.shape(img)[0]
    img_width = tf.shape(img)[1]
    # The module enables `from __future__ import division`, so `/` is true
    # division here; `_crop` casts the offsets to int32.
    top = (img_height - crop_height) / 2
    left = (img_width - crop_width) / 2
    return _crop(img, top, left, crop_height, crop_width)

  return [_crop_center(img) for img in image_list]
def _mean_image_subtraction(image, means):
  """Centers an image by subtracting a per-channel mean.

  For example:
    means = [123.68, 116.779, 103.939]
    image = _mean_image_subtraction(image, means)

  The static rank of `image` must be known.

  Args:
    image: a tensor of size [height, width, C].
    means: a C-vector of values to subtract from each channel.

  Returns:
    The centered image.

  Raises:
    ValueError: If the rank of `image` is unknown, if `image` has a rank other
      than three or if the number of channels in `image` doesn't match the
      number of values in `means`.
  """
  static_shape = image.get_shape()
  if static_shape.ndims != 3:
    raise ValueError('Input must be of size [height, width, C>0]')

  num_channels = static_shape.as_list()[-1]
  if len(means) != num_channels:
    raise ValueError('len(means) must match the number of channels')

  # Split along the channel axis, shift each channel, then stitch them back.
  per_channel = tf.split(2, num_channels, image)
  centered = [channel - mean for channel, mean in zip(per_channel, means)]
  return tf.concat(2, centered)
def _smallest_size_at_least(height, width, smallest_side):
  """Computes new shape with the smallest side equal to `smallest_side`.

  Computes new shape with the smallest side equal to `smallest_side` while
  preserving the original aspect ratio.

  Args:
    height: an int32 scalar tensor indicating the current height.
    width: an int32 scalar tensor indicating the current width.
    smallest_side: A python integer or scalar `Tensor` indicating the size of
      the smallest side after resize.

  Returns:
    new_height: an int32 scalar tensor indicating the new height.
    new_width: an int32 scalar tensor indicating the new width.
  """
  smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32)

  height = tf.to_float(height)
  width = tf.to_float(width)
  smallest_side = tf.to_float(smallest_side)

  # Scale so that the shorter of the two sides becomes `smallest_side`.
  scale = tf.cond(tf.greater(height, width),
                  lambda: smallest_side / width,
                  lambda: smallest_side / height)

  # Round before casting: `side * scale` can land just below the intended
  # integer because of float error (e.g. 255.99998), and a truncating cast
  # would then produce a side one pixel smaller than `smallest_side`.
  new_height = tf.to_int32(tf.round(height * scale))
  new_width = tf.to_int32(tf.round(width * scale))
  return new_height, new_width
def _aspect_preserving_resize(image, smallest_side):
  """Resize images preserving the original aspect ratio.

  Args:
    image: A 3-D image `Tensor`.
    smallest_side: A python integer or scalar `Tensor` indicating the size of
      the smallest side after resize.

  Returns:
    resized_image: A 3-D tensor containing the resized image, with its static
      shape set to [None, None, 3].
  """
  smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32)

  shape = tf.shape(image)
  height = shape[0]
  width = shape[1]
  new_height, new_width = _smallest_size_at_least(height, width, smallest_side)

  # A leading axis is added before the call to resize_bilinear and removed
  # again afterwards.
  image = tf.expand_dims(image, 0)
  resized_image = tf.image.resize_bilinear(image, [new_height, new_width],
                                           align_corners=False)
  # Only drop the leading axis added above. Squeezing every size-1 dimension
  # would also collapse a height or width that happens to be 1.
  resized_image = tf.squeeze(resized_image, [0])
  resized_image.set_shape([None, None, 3])
  return resized_image
def preprocess_for_train(image,
                         output_height,
                         output_width,
                         resize_side_min=_RESIZE_SIDE_MIN,
                         resize_side_max=_RESIZE_SIDE_MAX):
  """Builds the training-time preprocessing pipeline for one image.

  The image is resized (preserving aspect ratio) so that its smallest side
  equals a value sampled from [`resize_side_min`, `resize_side_max`], then
  randomly cropped, converted to float, randomly flipped left/right and
  finally mean-centered per channel.

  Args:
    image: A `Tensor` representing an image of arbitrary size.
    output_height: The height of the image after preprocessing.
    output_width: The width of the image after preprocessing.
    resize_side_min: The lower bound for the smallest side of the image for
      aspect-preserving resizing.
    resize_side_max: The upper bound for the smallest side of the image for
      aspect-preserving resizing.

  Returns:
    A preprocessed image.
  """
  # `maxval` is passed as `resize_side_max + 1` so the upper bound can be
  # drawn.
  sampled_side = tf.random_uniform(
      [], minval=resize_side_min, maxval=resize_side_max + 1, dtype=tf.int32)

  resized = _aspect_preserving_resize(image, sampled_side)

  cropped = _random_crop([resized], output_height, output_width)[0]
  cropped.set_shape([output_height, output_width, 3])

  flipped = tf.image.random_flip_left_right(tf.to_float(cropped))
  channel_means = [_R_MEAN, _G_MEAN, _B_MEAN]
  return _mean_image_subtraction(flipped, channel_means)
def preprocess_for_eval(image, output_height, output_width, resize_side):
  """Builds the evaluation-time preprocessing pipeline for one image.

  The image is resized (preserving aspect ratio) so that its smallest side
  equals `resize_side`, centrally cropped, converted to float and
  mean-centered per channel.

  Args:
    image: A `Tensor` representing an image of arbitrary size.
    output_height: The height of the image after preprocessing.
    output_width: The width of the image after preprocessing.
    resize_side: The smallest side of the image for aspect-preserving resizing.

  Returns:
    A preprocessed image.
  """
  resized = _aspect_preserving_resize(image, resize_side)

  cropped = _central_crop([resized], output_height, output_width)[0]
  cropped.set_shape([output_height, output_width, 3])

  channel_means = [_R_MEAN, _G_MEAN, _B_MEAN]
  return _mean_image_subtraction(tf.to_float(cropped), channel_means)
def preprocess_image(image, output_height, output_width, is_training=False,
                     resize_side_min=_RESIZE_SIDE_MIN,
                     resize_side_max=_RESIZE_SIDE_MAX):
  """Dispatches to the training or evaluation preprocessing pipeline.

  Args:
    image: A `Tensor` representing an image of arbitrary size.
    output_height: The height of the image after preprocessing.
    output_width: The width of the image after preprocessing.
    is_training: `True` if we're preprocessing the image for training and
      `False` otherwise.
    resize_side_min: The lower bound for the smallest side of the image for
      aspect-preserving resizing. If `is_training` is `False`, then this value
      is used for rescaling.
    resize_side_max: The upper bound for the smallest side of the image for
      aspect-preserving resizing. If `is_training` is `False`, this value is
      ignored. Otherwise, the resize side is sampled from
      [resize_side_min, resize_side_max].

  Returns:
    A preprocessed image.
  """
  if not is_training:
    # Evaluation resizes to a fixed side: the lower bound.
    return preprocess_for_eval(image, output_height, output_width,
                               resize_side_min)

  return preprocess_for_train(image, output_height, output_width,
                              resize_side_min, resize_side_max)
slim/nets/lenet.py
0 → 100644
View file @
a5c4fd06
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains a variant of the LeNet model definition."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
slim
=
tf
.
contrib
.
slim
def lenet(images, num_classes=10, is_training=False,
          dropout_keep_prob=0.5,
          prediction_fn=slim.softmax,
          scope='LeNet'):
  """Builds a variant of the LeNet model.

  The returned logits are unnormalized, i.e. they fall in the interval
  (-infinity, infinity). To turn them into a probability distribution over
  the classes, apply a softmax:

        logits, end_points = lenet(images, is_training=False)
        probabilities = tf.nn.softmax(logits)
        predictions = tf.argmax(logits, 1)

  Args:
    images: A batch of `Tensors` of size [batch_size, height, width, channels].
    num_classes: the number of classes in the dataset.
    is_training: specifies whether or not we're currently training the model.
      This variable will determine the behaviour of the dropout layer.
    dropout_keep_prob: the percentage of activation values that are retained.
    prediction_fn: a function to get predictions out of logits.
    scope: Optional variable_scope.

  Returns:
    logits: the pre-softmax activations, a tensor of size
      [batch_size, `num_classes`]
    end_points: a dictionary with the keys 'Flatten', 'Logits' and
      'Predictions' mapping to the corresponding activations.
  """
  end_points = {}

  with tf.variable_scope(scope, 'LeNet', [images, num_classes]):
    # Two convolution + max-pooling stages.
    conv1 = slim.conv2d(images, 32, [5, 5], scope='conv1')
    pool1 = slim.max_pool2d(conv1, [2, 2], 2, scope='pool1')
    conv2 = slim.conv2d(pool1, 64, [5, 5], scope='conv2')
    pool2 = slim.max_pool2d(conv2, [2, 2], 2, scope='pool2')

    flattened = slim.flatten(pool2)
    end_points['Flatten'] = flattened

    # Fully connected head; dropout follows `is_training`.
    hidden = slim.fully_connected(flattened, 1024, scope='fc3')
    hidden = slim.dropout(hidden, dropout_keep_prob, is_training=is_training,
                          scope='dropout3')
    # No activation on the last layer: it produces the raw logits.
    logits = slim.fully_connected(hidden, num_classes, activation_fn=None,
                                  scope='fc4')

  end_points['Logits'] = logits
  end_points['Predictions'] = prediction_fn(logits, scope='Predictions')

  return logits, end_points
lenet.default_image_size = 28
def lenet_arg_scope(weight_decay=0.0):
  """Builds the default argument scope for the LeNet model.

  The scope applies to `slim.conv2d` and `slim.fully_connected` and sets an
  L2 weight regularizer, a truncated-normal weight initializer (stddev 0.1)
  and a ReLU activation.

  Args:
    weight_decay: The weight decay to use for regularizing the model.

  Returns:
    An `arg_scope` to use for the LeNet model.
  """
  layer_defaults = dict(
      weights_regularizer=slim.l2_regularizer(weight_decay),
      weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
      activation_fn=tf.nn.relu)
  scoped_ops = [slim.conv2d, slim.fully_connected]
  with slim.arg_scope(scoped_ops, **layer_defaults) as lenet_scope:
    return lenet_scope
slim/scripts/train_lenet_on_mnist.sh
0 → 100644
View file @
a5c4fd06
#!/bin/bash
#
# Trains the LeNet model on MNIST, then evaluates the resulting checkpoint.
#
# Before running this script, make sure you've followed the instructions for
# downloading and converting the MNIST dataset.
# See slim/datasets/download_and_convert_mnist.py.
#
# Usage:
# ./slim/scripts/train_lenet_on_mnist.sh

# Compile the training and evaluation binaries
bazel build slim:train
bazel build slim:eval

# Where the checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/lenet-model

# Where the dataset was saved to.
DATASET_DIR=/tmp/mnist

# Run training.
# Every line of the command except the last must end in a backslash;
# otherwise the remaining flags are run as a separate command.
./bazel-bin/slim/train \
  --train_dir="${TRAIN_DIR}" \
  --dataset_name=mnist \
  --dataset_split_name=train \
  --dataset_dir="${DATASET_DIR}" \
  --model_name=lenet \
  --preprocessing_name=lenet \
  --max_number_of_steps=20000 \
  --learning_rate=0.01 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --optimizer=sgd \
  --learning_rate_decay_factor=1.0 \
  --weight_decay=0

# Run evaluation.
# The binary is built by bazel above, so it lives under bazel-bin.
./bazel-bin/slim/eval \
  --checkpoint_path="${TRAIN_DIR}" \
  --eval_dir="${TRAIN_DIR}" \
  --dataset_name=mnist \
  --dataset_split_name=test \
  --dataset_dir="${DATASET_DIR}" \
  --model_name=lenet
slim/train.py
0 → 100644
View file @
a5c4fd06
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic training script that trains a given model a specified dataset."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
tensorflow.python.ops
import
control_flow_ops
from
slim.datasets
import
dataset_factory
from
slim.models
import
model_deploy
from
slim.models
import
model_factory
from
slim.models
import
preprocessing_factory
slim = tf.contrib.slim

# Command-line flags for the training script. They are read through the
# module-level `FLAGS` object defined at the end of this section.

tf.app.flags.DEFINE_string(
    'master', '', 'The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(
    'train_dir', '/tmp/tfmodel/',
    'Directory where checkpoints and event logs are written to.')

tf.app.flags.DEFINE_integer('num_clones', 1,
                            'Number of model clones to deploy.')

tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
                            'Use CPUs to deploy clones.')

tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.')

tf.app.flags.DEFINE_integer(
    'num_ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

tf.app.flags.DEFINE_integer(
    'num_readers', 4,
    'The number of parallel readers that read data from the dataset.')

tf.app.flags.DEFINE_integer(
    'num_preprocessing_threads', 4,
    'The number of threads used to create the batches.')

tf.app.flags.DEFINE_integer(
    'log_every_n_steps', 5,
    'The frequency with which logs are printed.')

tf.app.flags.DEFINE_integer(
    'save_summaries_secs', 600,
    'The frequency with which summaries are saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'save_interval_secs', 600,
    'The frequency with which the model is saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'task', 0, 'Task id of the replica running the training.')

######################
# Optimization Flags #
######################

tf.app.flags.DEFINE_float(
    'weight_decay', 0.00004, 'The weight decay on the model weights.')

tf.app.flags.DEFINE_string(
    'optimizer', 'rmsprop',
    'The name of the optimizer, one of "adadelta", "adagrad", "adam", '
    '"ftrl", "momentum", "sgd" or "rmsprop".')

tf.app.flags.DEFINE_float(
    'adadelta_rho', 0.95,
    'The decay rate for adadelta.')

tf.app.flags.DEFINE_float(
    'adagrad_initial_accumulator_value', 0.1,
    'Starting value for the AdaGrad accumulators.')

tf.app.flags.DEFINE_float(
    'adam_beta1', 0.9,
    'The exponential decay rate for the 1st moment estimates.')

tf.app.flags.DEFINE_float(
    'adam_beta2', 0.999,
    'The exponential decay rate for the 2nd moment estimates.')

tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')

tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
                          'The learning rate power.')

tf.app.flags.DEFINE_float(
    'ftrl_initial_accumulator_value', 0.1,
    'Starting value for the FTRL accumulators.')

tf.app.flags.DEFINE_float(
    'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')

tf.app.flags.DEFINE_float(
    'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')

# Only the "momentum" optimizer reads this flag; RMSProp has its own
# `rmsprop_momentum` flag below.
tf.app.flags.DEFINE_float(
    'momentum', 0.9,
    'The momentum for the MomentumOptimizer.')

tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')

tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')

#######################
# Learning Rate Flags #
#######################

tf.app.flags.DEFINE_string(
    'learning_rate_decay_type',
    'exponential',
    'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
    ' or "polynomial"')

tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')

tf.app.flags.DEFINE_float(
    'end_learning_rate', 0.0001,
    'The minimal end learning rate used by a polynomial decay learning rate.')

tf.app.flags.DEFINE_float(
    'label_smoothing', 0.0, 'The amount of label smoothing.')

tf.app.flags.DEFINE_float(
    'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')

tf.app.flags.DEFINE_float(
    'num_epochs_per_decay', 2.0,
    'Number of epochs after which learning rate decays.')

tf.app.flags.DEFINE_bool(
    'sync_replicas', False,
    'Whether or not to synchronize the replicas during training.')

tf.app.flags.DEFINE_integer(
    'replicas_to_aggregate', 1,
    'The Number of gradients to collect before updating params.')

tf.app.flags.DEFINE_float(
    'moving_average_decay', None,
    'The decay to use for the moving average. '
    'If left as None, then moving averages are not used.')

#######################
# Dataset Flags #
#######################

tf.app.flags.DEFINE_string(
    'dataset_name', 'imagenet', 'The name of the dataset to load.')

tf.app.flags.DEFINE_string(
    'dataset_split_name', 'train', 'The name of the train/test split.')

tf.app.flags.DEFINE_string(
    'dataset_dir', None, 'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to train.')

tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')

tf.app.flags.DEFINE_integer(
    'batch_size', 32, 'The number of samples in each batch.')

tf.app.flags.DEFINE_integer(
    'train_image_size', None, 'Train image size')

tf.app.flags.DEFINE_integer('max_number_of_steps', None,
                            'The maximum number of training steps.')

#####################
# Fine-Tuning Flags #
#####################

tf.app.flags.DEFINE_string(
    'checkpoint_path', None,
    'The path to a checkpoint from which to fine-tune.')

# Variables whose names start with one of these scopes are NOT restored from
# the checkpoint (see `_get_init_fn`).
tf.app.flags.DEFINE_string(
    'checkpoint_exclude_scopes', None,
    'Comma-separated list of scopes of variables to exclude when restoring '
    'from a checkpoint.')

FLAGS = tf.app.flags.FLAGS
def _configure_learning_rate(num_samples_per_epoch, global_step):
  """Configures the learning rate.

  The schedule is selected by `FLAGS.learning_rate_decay_type`.

  Args:
    num_samples_per_epoch: The number of samples in each epoch of training.
    global_step: The global_step tensor.

  Returns:
    A `Tensor` representing the learning rate.

  Raises:
    ValueError: if `FLAGS.learning_rate_decay_type` is not one of
      "exponential", "fixed" or "polynomial".
  """
  # Number of steps between decays: steps per epoch times epochs per decay.
  decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
                    FLAGS.num_epochs_per_decay)
  if FLAGS.sync_replicas:
    decay_steps /= FLAGS.replicas_to_aggregate

  if FLAGS.learning_rate_decay_type == 'exponential':
    return tf.train.exponential_decay(FLAGS.learning_rate,
                                      global_step,
                                      decay_steps,
                                      FLAGS.learning_rate_decay_factor,
                                      staircase=True,
                                      name='exponential_decay_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'fixed':
    return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'polynomial':
    return tf.train.polynomial_decay(FLAGS.learning_rate,
                                     global_step,
                                     decay_steps,
                                     FLAGS.end_learning_rate,
                                     power=1.0,
                                     cycle=False,
                                     name='polynomial_decay_learning_rate')
  else:
    # Format the message with `%` so the offending value appears in the text.
    raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                     FLAGS.learning_rate_decay_type)
def _configure_optimizer(learning_rate):
  """Configures the optimizer used for training.

  The optimizer is selected by `FLAGS.optimizer`; its hyper-parameters come
  from the corresponding optimization flags.

  Args:
    learning_rate: A scalar or `Tensor` learning rate.

  Returns:
    An instance of an optimizer.

  Raises:
    ValueError: if FLAGS.optimizer is not recognized.
  """
  if FLAGS.optimizer == 'adadelta':
    optimizer = tf.train.AdadeltaOptimizer(
        learning_rate,
        rho=FLAGS.adadelta_rho,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'adagrad':
    optimizer = tf.train.AdagradOptimizer(
        learning_rate,
        initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
  elif FLAGS.optimizer == 'adam':
    optimizer = tf.train.AdamOptimizer(
        learning_rate,
        beta1=FLAGS.adam_beta1,
        beta2=FLAGS.adam_beta2,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'ftrl':
    optimizer = tf.train.FtrlOptimizer(
        learning_rate,
        learning_rate_power=FLAGS.ftrl_learning_rate_power,
        initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
        l1_regularization_strength=FLAGS.ftrl_l1,
        l2_regularization_strength=FLAGS.ftrl_l2)
  elif FLAGS.optimizer == 'momentum':
    optimizer = tf.train.MomentumOptimizer(
        learning_rate,
        momentum=FLAGS.momentum,
        name='Momentum')
  elif FLAGS.optimizer == 'rmsprop':
    optimizer = tf.train.RMSPropOptimizer(
        learning_rate,
        decay=FLAGS.rmsprop_decay,
        momentum=FLAGS.rmsprop_momentum,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'sgd':
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  else:
    # Format the message with `%` so the offending value appears in the text.
    raise ValueError('Optimizer [%s] was not recognized' % FLAGS.optimizer)
  return optimizer
def _add_variables_summaries(learning_rate):
  """Creates summaries for the model variables and the learning rate.

  Args:
    learning_rate: the learning rate to report as a scalar summary.

  Returns:
    A list with one histogram summary per model variable followed by the
    scalar summary of the learning rate.
  """
  summaries = [
      tf.histogram_summary(model_variable.op.name, model_variable)
      for model_variable in slim.get_model_variables()
  ]
  summaries.append(
      tf.scalar_summary('training/Learning Rate', learning_rate))
  return summaries
def _get_init_fn():
  """Returns a function run by the chief worker to warm-start the training.

  Note that the init_fn is only run when initializing the model during the very
  first global step.

  Returns:
    An init function run by the supervisor, or `None` when
    `FLAGS.checkpoint_path` is not set or when `FLAGS.train_dir` already
    contains a checkpoint.
  """
  if FLAGS.checkpoint_path is None:
    return None

  # A checkpoint already in train_dir takes precedence, so tell the user
  # that --checkpoint_path is being ignored.
  if tf.train.latest_checkpoint(FLAGS.train_dir):
    tf.logging.info(
        'Ignoring --checkpoint_path because a checkpoint already exists in %s'
        % FLAGS.train_dir)
    return None

  excluded_prefixes = []
  if FLAGS.checkpoint_exclude_scopes:
    excluded_prefixes = [
        prefix.strip()
        for prefix in FLAGS.checkpoint_exclude_scopes.split(',')
    ]

  # TODO(sguada) variables.filter_variables()
  # Keep every model variable whose name starts with none of the excluded
  # scopes.
  variables_to_restore = [
      variable for variable in slim.get_model_variables()
      if not any(variable.op.name.startswith(prefix)
                 for prefix in excluded_prefixes)
  ]

  return slim.assign_from_checkpoint_fn(
      FLAGS.checkpoint_path,
      variables_to_restore)
def
main
(
_
):
if
not
FLAGS
.
dataset_dir
:
raise
ValueError
(
'You must supply the dataset directory with --dataset_dir'
)
with
tf
.
Graph
().
as_default
():
######################
# Config model_deploy#
######################
deploy_config
=
model_deploy
.
DeploymentConfig
(
num_clones
=
FLAGS
.
num_clones
,
clone_on_cpu
=
FLAGS
.
clone_on_cpu
,
replica_id
=
FLAGS
.
task
,
num_replicas
=
FLAGS
.
worker_replicas
,
num_ps_tasks
=
FLAGS
.
num_ps_tasks
)
# Create global_step
with
tf
.
device
(
deploy_config
.
variables_device
()):
global_step
=
slim
.
create_global_step
()
######################
# Select the dataset #
######################
dataset
=
dataset_factory
.
get_dataset
(
FLAGS
.
dataset_name
,
FLAGS
.
dataset_split_name
,
FLAGS
.
dataset_dir
)
####################
# Select the model #
####################
model_fn
=
model_factory
.
get_model
(
FLAGS
.
model_name
,
num_classes
=
(
dataset
.
num_classes
-
FLAGS
.
labels_offset
),
weight_decay
=
FLAGS
.
weight_decay
,
is_training
=
True
)
#####################################
# Select the preprocessing function #
#####################################
preprocessing_name
=
FLAGS
.
preprocessing_name
or
FLAGS
.
model_name
image_preprocessing_fn
=
preprocessing_factory
.
get_preprocessing
(
preprocessing_name
,
is_training
=
True
)
##############################################################
# Create a dataset provider that loads data from the dataset #
##############################################################
with
tf
.
device
(
deploy_config
.
inputs_device
()):
provider
=
slim
.
dataset_data_provider
.
DatasetDataProvider
(
dataset
,
num_readers
=
FLAGS
.
num_readers
,
common_queue_capacity
=
20
*
FLAGS
.
batch_size
,
common_queue_min
=
10
*
FLAGS
.
batch_size
)
[
image
,
label
]
=
provider
.
get
([
'image'
,
'label'
])
label
-=
FLAGS
.
labels_offset
if
FLAGS
.
train_image_size
is
None
:
train_image_size
=
model_fn
.
default_image_size
else
:
train_image_size
=
FLAGS
.
train_image_size
image
=
image_preprocessing_fn
(
image
,
train_image_size
,
train_image_size
)
images
,
labels
=
tf
.
train
.
batch
(
[
image
,
label
],
batch_size
=
FLAGS
.
batch_size
,
num_threads
=
FLAGS
.
num_preprocessing_threads
,
capacity
=
5
*
FLAGS
.
batch_size
)
labels
=
slim
.
one_hot_encoding
(
labels
,
dataset
.
num_classes
-
FLAGS
.
labels_offset
)
batch_queue
=
slim
.
prefetch_queue
.
prefetch_queue
(
[
images
,
labels
],
capacity
=
2
*
deploy_config
.
num_clones
)
####################
# Define the model #
####################
def clone_fn(batch_queue):
  """Builds one replica of the model and attaches its classification losses.

  A batch is pulled from `batch_queue`, the images are passed through
  `model_fn`, and a softmax cross-entropy loss (weight 1.0) is created on the
  resulting logits. If the model's end points contain an 'AuxLogits' entry, an
  additional cross-entropy loss with weight 0.4 is created first under the
  'aux_loss' scope. Both losses use `FLAGS.label_smoothing`.

  Args:
    batch_queue: object whose `dequeue()` yields an `(images, labels)` pair.

  Returns:
    None. The results of the loss calls are discarded here; presumably slim
    records them as a side effect so the caller can collect them later
    (TODO confirm against the slim.losses documentation).
  """
  batch_images, batch_labels = batch_queue.dequeue()
  predictions, model_end_points = model_fn(batch_images)

  smoothing = FLAGS.label_smoothing
  cross_entropy = slim.losses.softmax_cross_entropy

  # The auxiliary head is optional: only models that expose 'AuxLogits' get
  # the extra, down-weighted loss term.
  if 'AuxLogits' in model_end_points:
    cross_entropy(
        model_end_points['AuxLogits'],
        batch_labels,
        label_smoothing=smoothing,
        weight=0.4,
        scope='aux_loss')

  # Main classification loss on the primary logits.
  cross_entropy(
      predictions,
      batch_labels,
      label_smoothing=smoothing,
      weight=1.0)
# Gather initial summaries.
summaries
=
set
(
tf
.
get_collection
(
tf
.
GraphKeys
.
SUMMARIES
))
clones
=
model_deploy
.
create_clones
(
deploy_config
,
clone_fn
,
[
batch_queue
])
first_clone_scope
=
deploy_config
.
clone_scope
(
0
)
# Gather update_ops from the first clone. These contain, for example,
# the updates for the batch_norm variables created by model_fn.
update_ops
=
tf
.
get_collection
(
tf
.
GraphKeys
.
UPDATE_OPS
,
first_clone_scope
)
# Add summaries for losses.
for
loss
in
tf
.
get_collection
(
tf
.
GraphKeys
.
LOSSES
,
first_clone_scope
):
tf
.
scalar_summary
(
'losses/%s'
%
loss
.
op
.
name
,
loss
)
# Add summaries for variables.
for
variable
in
slim
.
get_model_variables
():
summaries
.
add
(
tf
.
histogram_summary
(
variable
.
op
.
name
,
variable
))
#################################
# Configure the moving averages #
#################################
if
FLAGS
.
moving_average_decay
:
moving_average_variables
=
slim
.
get_model_variables
()
variable_averages
=
tf
.
train
.
ExponentialMovingAverage
(
FLAGS
.
moving_average_decay
,
global_step
)
else
:
moving_average_variables
,
variable_averages
=
None
,
None
#########################################
# Configure the optimization procedure. #
#########################################
with
tf
.
device
(
deploy_config
.
optimizer_device
()):
learning_rate
=
_configure_learning_rate
(
dataset
.
num_samples
,
global_step
)
optimizer
=
_configure_optimizer
(
learning_rate
)
summaries
.
add
(
tf
.
scalar_summary
(
'learning_rate'
,
learning_rate
,
name
=
'learning_rate'
))
if
FLAGS
.
sync_replicas
:
# If sync_replicas is enabled, the averaging will be done in the chief
# queue runner.
optimizer
=
tf
.
train
.
SyncReplicasOptimizer
(
opt
=
optimizer
,
replicas_to_aggregate
=
FLAGS
.
replicas_to_aggregate
,
variable_averages
=
variable_averages
,
variables_to_average
=
moving_average_variables
,
replica_id
=
tf
.
constant
(
FLAGS
.
task
,
tf
.
int32
,
shape
=
()),
total_num_replicas
=
FLAGS
.
worker_replicas
)
elif
FLAGS
.
moving_average_decay
:
# Update ops executed locally by trainer.
update_ops
.
append
(
variable_averages
.
apply
(
moving_average_variables
))
# TODO(sguada) Refactor into function that takes the clones and optimizer
# and returns a train_tensor and summary_op
total_loss
,
clones_gradients
=
model_deploy
.
optimize_clones
(
clones
,
optimizer
)
# Add total_loss to summary.
summaries
.
add
(
tf
.
scalar_summary
(
'total_loss'
,
total_loss
,
name
=
'total_loss'
))
# Create gradient updates.
grad_updates
=
optimizer
.
apply_gradients
(
clones_gradients
,
global_step
=
global_step
)
update_ops
.
append
(
grad_updates
)
update_op
=
tf
.
group
(
*
update_ops
)
train_tensor
=
control_flow_ops
.
with_dependencies
([
update_op
],
total_loss
,
name
=
'train_op'
)
# Add the summaries from the first clone. These contain the summaries
# created by model_fn and either optimize_clones() or _gather_clone_loss().
summaries
|=
set
(
tf
.
get_collection
(
tf
.
GraphKeys
.
SUMMARIES
,
first_clone_scope
))
# Merge all summaries together.
summary_op
=
tf
.
merge_summary
(
list
(
summaries
),
name
=
'summary_op'
)
###########################
# Kicks off the training. #
###########################
slim
.
learning
.
train
(
train_tensor
,
logdir
=
FLAGS
.
train_dir
,
master
=
FLAGS
.
master
,
is_chief
=
(
FLAGS
.
task
==
0
),
init_fn
=
_get_init_fn
(),
summary_op
=
summary_op
,
number_of_steps
=
FLAGS
.
max_number_of_steps
,
log_every_n_steps
=
FLAGS
.
log_every_n_steps
,
save_summaries_secs
=
FLAGS
.
save_summaries_secs
,
save_interval_secs
=
FLAGS
.
save_interval_secs
,
sync_optimizer
=
optimizer
if
FLAGS
.
sync_replicas
else
None
)
# Script entry point: only start the TensorFlow app runner when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
  tf.app.run()
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment