Commit 91c7b91f authored by Yuefeng Zhou, committed by Sergio Guadarrama

Variables defined in ExponentialMovingAverage need not be shared. (#778)

* Variables defined in ExponentialMovingAverage need not be shared.

* Address comments.
parent 9bec9e8f
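
For context on why the diff below is shaped this way: calling tf.get_variable_scope().reuse_variables() inside the tower loop switches variable reuse on for the root scope for the remainder of graph construction, so the shadow variables that tf.train.ExponentialMovingAverage creates after the loop would also be forced to already exist. Passing a reuse_variables flag into _tower_loss and applying it only around the inception.inference call confines reuse to the model variables. Below is a minimal sketch of the pattern in the TF 1.x graph API; tower_model and the toy shapes are hypothetical stand-ins for inception.inference and the real inputs, not code from this commit.

import tensorflow as tf

def tower_model(images):
  # Hypothetical stand-in for inception.inference(): creates (or reuses)
  # the model variables via tf.get_variable().
  w = tf.get_variable('w', shape=[4, 2])
  return tf.matmul(images, w)

def tower_loss(images, labels, reuse_variables=None):
  # Reuse is scoped to the model-building call only, mirroring the change
  # to _tower_loss() in this commit.
  with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
    logits = tower_model(images)
  return tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                     logits=logits))

images = tf.ones([8, 4])
labels = tf.zeros([8], dtype=tf.int64)

reuse_variables = None  # the first tower creates the variables
for i in range(2):
  with tf.name_scope('tower_%d' % i):
    loss = tower_loss(images, labels, reuse_variables)
  reuse_variables = True  # later towers share the model variables

# Because reuse was never switched on for the root scope itself,
# ExponentialMovingAverage can still create fresh shadow variables here.
ema = tf.train.ExponentialMovingAverage(0.9999)
maintain_averages_op = ema.apply(tf.trainable_variables())

The key difference between the two forms: reuse_variables() on the current scope is sticky for the rest of graph construction, whereas the with tf.variable_scope(..., reuse=...) block restores the previous reuse setting on exit.
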
@@ -79,7 +79,7 @@ RMSPROP_MOMENTUM = 0.9             # Momentum in RMSProp.
 RMSPROP_EPSILON = 1.0              # Epsilon term for RMSProp.
 
 
-def _tower_loss(images, labels, num_classes, scope):
+def _tower_loss(images, labels, num_classes, scope, reuse_variables=None):
   """Calculate the total loss on a single tower running the ImageNet model.
 
   We perform 'batch splitting'. This means that we cut up a batch across
@@ -103,9 +103,10 @@ def _tower_loss(images, labels, num_classes, scope):
   restore_logits = not FLAGS.fine_tune
 
   # Build inference Graph.
-  logits = inception.inference(images, num_classes, for_training=True,
-                               restore_logits=restore_logits,
-                               scope=scope)
+  with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
+    logits = inception.inference(images, num_classes, for_training=True,
+                                 restore_logits=restore_logits,
+                                 scope=scope)
 
   # Build the portion of the Graph calculating the losses. Note that we will
   # assemble the total_loss using a custom function below.
@@ -220,13 +221,14 @@ def train(dataset):
     # Number of classes in the Dataset label set plus 1.
     # Label 0 is reserved for an (unused) background class.
     num_classes = dataset.num_classes() + 1
 
     # Split the batch of images and labels for towers.
     images_splits = tf.split(0, FLAGS.num_gpus, images)
     labels_splits = tf.split(0, FLAGS.num_gpus, labels)
 
     # Calculate the gradients for each model tower.
     tower_grads = []
+    reuse_variables = None
     for i in xrange(FLAGS.num_gpus):
       with tf.device('/gpu:%d' % i):
         with tf.name_scope('%s_%d' % (inception.TOWER_NAME, i)) as scope:
@@ -236,10 +238,10 @@ def train(dataset):
             # function constructs the entire ImageNet model but shares the
             # variables across all towers.
             loss = _tower_loss(images_splits[i], labels_splits[i], num_classes,
-                               scope)
+                               scope, reuse_variables)
 
           # Reuse variables for the next tower.
-          tf.get_variable_scope().reuse_variables()
+          reuse_variables = True
 
           # Retain the summaries from the final tower.
           summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)