Commit 09a32f32 authored by derekjchow, committed by Sergio Guadarrama

Update slim/ (#2307)

parent 42f507f5
@@ -62,6 +62,7 @@ from tensorflow.python.platform import gfile
from datasets import dataset_factory
from nets import nets_factory
slim = tf.contrib.slim
tf.app.flags.DEFINE_string(
......
@@ -19,20 +19,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import nets_factory
slim = tf.contrib.slim
class NetworksTest(tf.test.TestCase):
-  def testGetNetworkFn(self):
+  def testGetNetworkFnFirstHalf(self):
    batch_size = 5
    num_classes = 1000
-    for net in nets_factory.networks_map:
-      with self.test_session():
+    for net in nets_factory.networks_map.keys()[:10]:
+      with tf.Graph().as_default() as g, self.test_session(g):
        net_fn = nets_factory.get_network_fn(net, num_classes)
        # Most networks use 224 as their default_image_size
        image_size = getattr(net_fn, 'default_image_size', 224)
@@ -43,19 +42,20 @@ class NetworksTest(tf.test.TestCase):
        self.assertEqual(logits.get_shape().as_list()[0], batch_size)
        self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
-  def testGetNetworkFnArgScope(self):
+  def testGetNetworkFnSecondHalf(self):
    batch_size = 5
-    num_classes = 10
-    net = 'cifarnet'
-    with self.test_session(use_gpu=True):
+    num_classes = 1000
+    for net in nets_factory.networks_map.keys()[10:]:
+      with tf.Graph().as_default() as g, self.test_session(g):
        net_fn = nets_factory.get_network_fn(net, num_classes)
        # Most networks use 224 as their default_image_size
        image_size = getattr(net_fn, 'default_image_size', 224)
-        with slim.arg_scope([slim.model_variable, slim.variable],
-                            device='/CPU:0'):
        inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
-        net_fn(inputs)
-        weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'CifarNet/conv1')[0]
-        self.assertDeviceEqual('/CPU:0', weights.device)
+        logits, end_points = net_fn(inputs)
+        self.assertTrue(isinstance(logits, tf.Tensor))
+        self.assertTrue(isinstance(end_points, dict))
+        self.assertEqual(logits.get_shape().as_list()[0], batch_size)
+        self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
if __name__ == '__main__':
  tf.test.main()
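A side note on the new iteration: slicing `networks_map.keys()[:10]` works under Python 2, where `dict.keys()` returns a list; a Python 3 port would need `list(nets_factory.networks_map)[:10]`. The fresh `tf.Graph().as_default()` per network is what keeps any single test graph from accumulating every slim model. A minimal sketch of the same pattern outside the test harness, assuming only the `nets_factory` module used above:

import tensorflow as tf
from nets import nets_factory

def build_each_network(names, num_classes=1000, batch_size=5):
  for name in names:
    # A fresh graph per network, as in the split tests above, so no single
    # graph holds all of the slim models at once.
    with tf.Graph().as_default():
      net_fn = nets_factory.get_network_fn(name, num_classes)
      image_size = getattr(net_fn, 'default_image_size', 224)
      images = tf.random_uniform((batch_size, image_size, image_size, 3))
      logits, _ = net_fn(images)
      assert logits.get_shape().as_list()[0] == batch_size
      assert logits.get_shape().as_list()[-1] == num_classes

# Python 3-safe equivalent of networks_map.keys()[:10] / [10:]:
names = list(nets_factory.networks_map)
build_each_network(names[:10])
build_each_network(names[10:])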
@@ -212,7 +212,7 @@ def preprocess_for_train(image, height, width, bbox,
  num_resize_cases = 1 if fast_mode else 4
  distorted_image = apply_with_random_selector(
      distorted_image,
-      lambda x, method: tf.image.resize_images(x, [height, width], method=method),
+      lambda x, method: tf.image.resize_images(x, [height, width], method),
      num_cases=num_resize_cases)
  tf.summary.image('cropped_resized_image',
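For readers without the surrounding file: `apply_with_random_selector` draws a random case index per image and applies the lambda with that case, so in non-fast mode each training image is resized with one of the four TF 1.x interpolation methods (0 bilinear, 1 nearest-neighbor, 2 bicubic, 3 area). A rough sketch of the idea, not the file's exact implementation:

import tensorflow as tf

def apply_with_random_selector_sketch(x, func, num_cases):
  # Pick a random branch index, then evaluate func(x, case) for that branch
  # only; exclusive=True because exactly one predicate is true per draw.
  sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
  return tf.case(
      {tf.equal(sel, case): (lambda case=case: func(x, case))
       for case in range(num_cases)},
      exclusive=True)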
@@ -248,7 +248,7 @@ def preprocess_for_eval(image, height, width,
    image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
      [0, 1], otherwise it would converted to tf.float32 assuming that the range
      is [0, MAX], where MAX is largest positive representable number for
-      int(8/16/32) data type (see `tf.image.convert_image_dtype` for details)
+      int(8/16/32) data type (see `tf.image.convert_image_dtype` for details).
    height: integer
    width: integer
    central_fraction: Optional Float, fraction of the image to crop.
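The `[0, MAX]` wording in these docstrings refers to `tf.image.convert_image_dtype`, which scales integer images by the input dtype's maximum when converting to float, for example:

import tensorflow as tf

image_u8 = tf.constant([[[0, 128, 255]]], dtype=tf.uint8)
# uint8 MAX is 255, so the float image lands in [0, 1]:
image_f32 = tf.image.convert_image_dtype(image_u8, tf.float32)
# -> approximately [[[0.0, 0.502, 1.0]]]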
@@ -282,7 +282,11 @@ def preprocess_image(image, height, width,
  """Pre-process one image for training or evaluation.
  Args:
-    image: 3-D Tensor [height, width, channels] with the image.
+    image: 3-D Tensor [height, width, channels] with the image. If dtype is
+      tf.float32 then the range should be [0, 1], otherwise it would converted
+      to tf.float32 assuming that the range is [0, MAX], where MAX is largest
+      positive representable number for int(8/16/32) data type (see
+      `tf.image.convert_image_dtype` for details).
    height: integer, image expected height.
    width: integer, image expected width.
    is_training: Boolean. If true it would transform an image for train,
......
@@ -117,6 +117,8 @@ tf.app.flags.DEFINE_float(
    'momentum', 0.9,
    'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
+tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')
+
tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
#######################
@@ -301,7 +303,7 @@ def _configure_optimizer(learning_rate):
    optimizer = tf.train.RMSPropOptimizer(
        learning_rate,
        decay=FLAGS.rmsprop_decay,
-        momentum=FLAGS.momentum,
+        momentum=FLAGS.rmsprop_momentum,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'sgd':
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
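The substance of this hunk: RMSProp previously borrowed the shared `momentum` flag, so tuning momentum for the MomentumOptimizer silently changed RMSProp runs as well; the new `rmsprop_momentum` flag decouples the two. A minimal sketch with the flag values inlined as constants (defaults assumed from the definitions above):

import tensorflow as tf

rmsprop_decay = 0.9     # --rmsprop_decay
rmsprop_momentum = 0.9  # --rmsprop_momentum (new, independent of --momentum)
opt_epsilon = 1.0       # --opt_epsilon (assumed default)

optimizer = tf.train.RMSPropOptimizer(
    0.01,  # learning_rate
    decay=rmsprop_decay,
    momentum=rmsprop_momentum,
    epsilon=opt_epsilon)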
@@ -309,6 +311,7 @@ def _configure_optimizer(learning_rate):
    raise ValueError('Optimizer [%s] was not recognized', FLAGS.optimizer)
  return optimizer
+
def _get_init_fn():
  """Returns a function run by the chief worker to warm-start the training.
@@ -450,7 +453,6 @@ def main(_):
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
-      with tf.device(deploy_config.inputs_device()):
      images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)
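For context, `clone_fn` is handed to slim's `model_deploy.create_clones`, which calls it once per clone inside that clone's device scope, so dropping the explicit `inputs_device()` block presumably leaves placement of the dequeue to those scopes. A simplified sketch of the call site (the real code builds `deploy_config` and `batch_queue` earlier in main()):

from deployment import model_deploy

# create_clones invokes clone_fn once per configured clone.
clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
first_clone_scope = deploy_config.clone_scope(0)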
@@ -458,12 +460,12 @@ def main(_):
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
-        tf.losses.softmax_cross_entropy(
-            logits=end_points['AuxLogits'], onehot_labels=labels,
-            label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
-      tf.losses.softmax_cross_entropy(
-          logits=logits, onehot_labels=labels,
-          label_smoothing=FLAGS.label_smoothing, weights=1.0)
+        slim.losses.softmax_cross_entropy(
+            end_points['AuxLogits'], labels,
+            label_smoothing=FLAGS.label_smoothing, weights=0.4,
+            scope='aux_loss')
+      slim.losses.softmax_cross_entropy(
+          logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
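One gotcha when reading this hunk: the two APIs flip argument order. `tf.losses.softmax_cross_entropy(onehot_labels, logits, ...)` takes labels first, while `slim.losses.softmax_cross_entropy(logits, onehot_labels, ...)` takes logits first; both register their result in the `tf.GraphKeys.LOSSES` collection that the training loop later totals. A minimal sketch:

import tensorflow as tf

slim = tf.contrib.slim

logits = tf.random_normal([5, 10])
onehot_labels = tf.one_hot(tf.zeros([5], dtype=tf.int32), depth=10)

# slim variant, as used above: logits first, labels second.
slim.losses.softmax_cross_entropy(
    logits, onehot_labels, label_smoothing=0.1, weights=1.0)

# Everything registered in GraphKeys.LOSSES is picked up here.
total_loss = tf.losses.get_total_loss(add_regularization_losses=False)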
@@ -515,10 +517,9 @@ def main(_):
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
+          total_num_replicas=FLAGS.worker_replicas,
          variable_averages=variable_averages,
-          variables_to_average=moving_average_variables,
-          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
-          total_num_replicas=FLAGS.worker_replicas)
+          variables_to_average=moving_average_variables)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))
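For background: the dropped `replica_id` keyword belonged to the older pre-1.0 SyncReplicasOptimizer API; the TF 1.x constructor is roughly `SyncReplicasOptimizer(opt, replicas_to_aggregate, total_num_replicas=None, variable_averages=None, variables_to_average=None, ...)`, which is why `total_num_replicas` now appears among the keywords rather than last. A standalone sketch with illustrative constants:

import tensorflow as tf

# Wrap a base optimizer so gradients from 4 workers are aggregated before
# each update (values here are assumed for illustration).
sync_opt = tf.train.SyncReplicasOptimizer(
    opt=tf.train.GradientDescentOptimizer(0.01),
    replicas_to_aggregate=4,
    total_num_replicas=4)
# Outside slim, each worker would also install the sync hook:
sync_hook = sync_opt.make_session_run_hook(is_chief=True)

In this script the wrapped optimizer is instead handed to slim.learning.train via its sync_optimizer argument.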
......