"packaging/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "7e43f29d5269858729509c0de83a124c4a6ee650"
Commit 596c9e23 authored by Neal Wu's avatar Neal Wu Committed by GitHub
Browse files

Merge pull request #911 from wookayin/cifar10

Fix bugs and API usage on cifar10 and cifar10_multi_gpu_train
parents eb62b917 9d96e9fe
@@ -162,25 +162,26 @@ def train():
     # Calculate the gradients for each model tower.
     tower_grads = []
-    for i in xrange(FLAGS.num_gpus):
-      with tf.device('/gpu:%d' % i):
-        with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
-          # Calculate the loss for one tower of the CIFAR model. This function
-          # constructs the entire CIFAR model but shares the variables across
-          # all towers.
-          loss = tower_loss(scope)
-
-          # Reuse variables for the next tower.
-          tf.get_variable_scope().reuse_variables()
-
-          # Retain the summaries from the final tower.
-          summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
-
-          # Calculate the gradients for the batch of data on this CIFAR tower.
-          grads = opt.compute_gradients(loss)
-
-          # Keep track of the gradients across all towers.
-          tower_grads.append(grads)
+    with tf.variable_scope(tf.get_variable_scope()):
+      for i in xrange(FLAGS.num_gpus):
+        with tf.device('/gpu:%d' % i):
+          with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
+            # Calculate the loss for one tower of the CIFAR model. This function
+            # constructs the entire CIFAR model but shares the variables across
+            # all towers.
+            loss = tower_loss(scope)
+
+            # Reuse variables for the next tower.
+            tf.get_variable_scope().reuse_variables()
+
+            # Retain the summaries from the final tower.
+            summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
+
+            # Calculate the gradients for the batch of data on this CIFAR tower.
+            grads = opt.compute_gradients(loss)
+
+            # Keep track of the gradients across all towers.
+            tower_grads.append(grads)
     # We must calculate the mean of each gradient. Note that this is the
     # synchronization point across all towers.
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment