Commit 09a32f32 authored by derekjchow, committed by Sergio Guadarrama

Update slim/ (#2307)

parent 42f507f5
@@ -62,6 +62,7 @@ from tensorflow.python.platform import gfile
 from datasets import dataset_factory
 from nets import nets_factory

 slim = tf.contrib.slim

+
 tf.app.flags.DEFINE_string(
...
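For context, these TF 1.x flags are defined once at module scope and read back through tf.app.flags.FLAGS after tf.app.run() parses argv. A minimal, self-contained sketch (the flag name here is illustrative, not taken from this diff):

    import tensorflow as tf

    tf.app.flags.DEFINE_string('checkpoint_path', '/tmp/model.ckpt',
                               'Path to a checkpoint to evaluate.')
    FLAGS = tf.app.flags.FLAGS

    def main(_):
      # Flag values are parsed from argv by tf.app.run() before main runs.
      print(FLAGS.checkpoint_path)

    if __name__ == '__main__':
      tf.app.run()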
@@ -19,20 +19,19 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import tensorflow as tf

 from nets import nets_factory

-slim = tf.contrib.slim


 class NetworksTest(tf.test.TestCase):

-  def testGetNetworkFn(self):
+  def testGetNetworkFnFirstHalf(self):
     batch_size = 5
     num_classes = 1000
-    for net in nets_factory.networks_map:
-      with self.test_session():
+    for net in nets_factory.networks_map.keys()[:10]:
+      with tf.Graph().as_default() as g, self.test_session(g):
         net_fn = nets_factory.get_network_fn(net, num_classes)
         # Most networks use 224 as their default_image_size
         image_size = getattr(net_fn, 'default_image_size', 224)
@@ -43,19 +42,20 @@ class NetworksTest(tf.test.TestCase):
         self.assertEqual(logits.get_shape().as_list()[0], batch_size)
         self.assertEqual(logits.get_shape().as_list()[-1], num_classes)

-  def testGetNetworkFnArgScope(self):
-    batch_size = 5
-    num_classes = 10
-    net = 'cifarnet'
-    with self.test_session(use_gpu=True):
-      net_fn = nets_factory.get_network_fn(net, num_classes)
-      image_size = getattr(net_fn, 'default_image_size', 224)
-      with slim.arg_scope([slim.model_variable, slim.variable],
-                          device='/CPU:0'):
-        inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
-        net_fn(inputs)
-      weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'CifarNet/conv1')[0]
-      self.assertDeviceEqual('/CPU:0', weights.device)
+  def testGetNetworkFnSecondHalf(self):
+    batch_size = 5
+    num_classes = 1000
+    for net in nets_factory.networks_map.keys()[10:]:
+      with tf.Graph().as_default() as g, self.test_session(g):
+        net_fn = nets_factory.get_network_fn(net, num_classes)
+        # Most networks use 224 as their default_image_size
+        image_size = getattr(net_fn, 'default_image_size', 224)
+        inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
+        logits, end_points = net_fn(inputs)
+        self.assertTrue(isinstance(logits, tf.Tensor))
+        self.assertTrue(isinstance(end_points, dict))
+        self.assertEqual(logits.get_shape().as_list()[0], batch_size)
+        self.assertEqual(logits.get_shape().as_list()[-1], num_classes)

 if __name__ == '__main__':
   tf.test.main()
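The renamed tests iterate over every entry in nets_factory.networks_map; outside the test harness the factory is used the same way. A minimal sketch, assuming a slim checkout on the Python path and 'resnet_v1_50' as one of the registered network names:

    import tensorflow as tf
    from nets import nets_factory

    net_fn = nets_factory.get_network_fn(
        'resnet_v1_50', num_classes=1000, is_training=False)
    # Networks advertise their preferred input resolution, defaulting to 224.
    image_size = getattr(net_fn, 'default_image_size', 224)
    images = tf.random_uniform((1, image_size, image_size, 3))
    logits, end_points = net_fn(images)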
@@ -212,7 +212,7 @@ def preprocess_for_train(image, height, width, bbox,
   num_resize_cases = 1 if fast_mode else 4
   distorted_image = apply_with_random_selector(
       distorted_image,
-      lambda x, method: tf.image.resize_images(x, [height, width], method=method),
+      lambda x, method: tf.image.resize_images(x, [height, width], method),
       num_cases=num_resize_cases)

   tf.summary.image('cropped_resized_image',
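apply_with_random_selector picks one of num_cases transformations per example via a random selector tensor. For reference, the helper defined earlier in this same file looks roughly like this (a sketch from the slim sources, not part of this diff):

    import tensorflow as tf
    from tensorflow.python.ops import control_flow_ops

    def apply_with_random_selector(x, func, num_cases):
      """Computes func(x, sel) with sel sampled from [0...num_cases-1]."""
      sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
      # Only the branch whose case index matches sel receives the real input;
      # merge returns the single tensor produced by that branch.
      return control_flow_ops.merge([
          func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
          for case in range(num_cases)])[0]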
@@ -248,7 +248,7 @@ def preprocess_for_eval(image, height, width,
     image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
       [0, 1], otherwise it would converted to tf.float32 assuming that the range
       is [0, MAX], where MAX is largest positive representable number for
-      int(8/16/32) data type (see `tf.image.convert_image_dtype` for details)
+      int(8/16/32) data type (see `tf.image.convert_image_dtype` for details).
     height: integer
     width: integer
     central_fraction: Optional Float, fraction of the image to crop.
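The [0, MAX] convention comes from tf.image.convert_image_dtype, which divides by the dtype's maximum when converting integer images to float. A small sketch:

    import tensorflow as tf

    # A uint8 image with all pixels at 255 (MAX for uint8).
    image_uint8 = tf.constant(255, shape=[224, 224, 3], dtype=tf.uint8)
    # Conversion rescales to [0, 1], so every value becomes 1.0.
    image_float = tf.image.convert_image_dtype(image_uint8, dtype=tf.float32)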
@@ -282,7 +282,11 @@ def preprocess_image(image, height, width,
   """Pre-process one image for training or evaluation.

   Args:
-    image: 3-D Tensor [height, width, channels] with the image.
+    image: 3-D Tensor [height, width, channels] with the image. If dtype is
+      tf.float32 then the range should be [0, 1], otherwise it would converted
+      to tf.float32 assuming that the range is [0, MAX], where MAX is largest
+      positive representable number for int(8/16/32) data type (see
+      `tf.image.convert_image_dtype` for details).
     height: integer, image expected height.
     width: integer, image expected width.
     is_training: Boolean. If true it would transform an image for train,
...
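The expanded docstring matches how preprocess_image is typically called. A minimal sketch, assuming a decoded image tensor and the slim preprocessing package on the path:

    import tensorflow as tf
    from preprocessing import inception_preprocessing

    raw_image = tf.zeros([480, 640, 3], dtype=tf.uint8)  # stand-in input
    # is_training selects between the train and eval transformation paths.
    processed = inception_preprocessing.preprocess_image(
        raw_image, height=224, width=224, is_training=False)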
@@ -12,7 +12,7 @@
 # of the model, and the input image size, which can be 224, 192, 160, or 128
 # pixels, and affects the amount of computation needed, and the latency.
 # Here's an example generating a frozen model from pretrained weights:
-# 
+#

 set -e

@@ -20,7 +20,7 @@ print_usage () {
   echo "Creates a frozen mobilenet model suitable for mobile use"
   echo "Usage:"
   echo "$0 <mobilenet version> <input size> [checkpoint path]"
 }

 MOBILENET_VERSION=$1
 IMAGE_SIZE=$2
...
@@ -117,6 +117,8 @@ tf.app.flags.DEFINE_float(
     'momentum', 0.9,
     'The momentum for the MomentumOptimizer and RMSPropOptimizer.')

+tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')
+
 tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')

 #######################
@@ -301,7 +303,7 @@ def _configure_optimizer(learning_rate):
     optimizer = tf.train.RMSPropOptimizer(
         learning_rate,
         decay=FLAGS.rmsprop_decay,
-        momentum=FLAGS.momentum,
+        momentum=FLAGS.rmsprop_momentum,
         epsilon=FLAGS.opt_epsilon)
   elif FLAGS.optimizer == 'sgd':
     optimizer = tf.train.GradientDescentOptimizer(learning_rate)
@@ -309,6 +311,7 @@ def _configure_optimizer(learning_rate):
     raise ValueError('Optimizer [%s] was not recognized', FLAGS.optimizer)
   return optimizer

+
 def _get_init_fn():
   """Returns a function run by the chief worker to warm-start the training.
@@ -450,20 +453,19 @@ def main(_):
   ####################
   def clone_fn(batch_queue):
     """Allows data parallelism by creating multiple clones of network_fn."""
-    with tf.device(deploy_config.inputs_device()):
-      images, labels = batch_queue.dequeue()
+    images, labels = batch_queue.dequeue()
     logits, end_points = network_fn(images)

     #############################
     # Specify the loss function #
     #############################
     if 'AuxLogits' in end_points:
-      tf.losses.softmax_cross_entropy(
-          logits=end_points['AuxLogits'], onehot_labels=labels,
-          label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
-    tf.losses.softmax_cross_entropy(
-        logits=logits, onehot_labels=labels,
-        label_smoothing=FLAGS.label_smoothing, weights=1.0)
+      slim.losses.softmax_cross_entropy(
+          end_points['AuxLogits'], labels,
+          label_smoothing=FLAGS.label_smoothing, weights=0.4,
+          scope='aux_loss')
+    slim.losses.softmax_cross_entropy(
+        logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
     return end_points

   # Gather initial summaries.
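Both loss variants apply the same label smoothing: one-hot targets are blended toward a uniform distribution before the cross-entropy is computed. The effect, as a sketch:

    import tensorflow as tf

    num_classes = 5
    label_smoothing = 0.1
    onehot_labels = tf.one_hot([2], depth=num_classes)
    # The true class keeps 1.0 * (1 - 0.1) + 0.1 / 5 = 0.92;
    # every other class gets 0.1 / 5 = 0.02.
    smoothed = (onehot_labels * (1.0 - label_smoothing)
                + label_smoothing / num_classes)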
@@ -515,10 +517,9 @@ def main(_):
     optimizer = tf.train.SyncReplicasOptimizer(
         opt=optimizer,
         replicas_to_aggregate=FLAGS.replicas_to_aggregate,
+        total_num_replicas=FLAGS.worker_replicas,
         variable_averages=variable_averages,
-        variables_to_average=moving_average_variables,
-        replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
-        total_num_replicas=FLAGS.worker_replicas)
+        variables_to_average=moving_average_variables)
   elif FLAGS.moving_average_decay:
     # Update ops executed locally by trainer.
     update_ops.append(variable_averages.apply(moving_average_variables))
...
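The updated call matches the newer SyncReplicasOptimizer API, which no longer takes a replica_id argument and instead coordinates replicas through a session hook. A sketch of the corresponding usage (values illustrative):

    import tensorflow as tf

    base_opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
    sync_opt = tf.train.SyncReplicasOptimizer(
        opt=base_opt,
        replicas_to_aggregate=4,
        total_num_replicas=4)
    # The hook manages the token queue and sync barriers when passed to a
    # MonitoredTrainingSession.
    sync_hook = sync_opt.make_session_run_hook(is_chief=True)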