"...resnet50_tensorflow.git" did not exist on "eb62b9179193fa09c612a7e72b9f4bd09b9d2279"
Commit 9e8fd6d9 authored by Toby Boyd's avatar Toby Boyd
Browse files

Fixed typo and multi-gpu processing same batch on each gpu

parent c3e2ae5e
...@@ -138,6 +138,7 @@ def average_gradients(tower_grads): ...@@ -138,6 +138,7 @@ def average_gradients(tower_grads):
def train(): def train():
print(FLAGS.batch_size)
"""Train CIFAR-10 for a number of steps.""" """Train CIFAR-10 for a number of steps."""
with tf.Graph().as_default(), tf.device('/cpu:0'): with tf.Graph().as_default(), tf.device('/cpu:0'):
# Create a variable to count the number of train() calls. This equals the # Create a variable to count the number of train() calls. This equals the
...@@ -163,13 +164,16 @@ def train(): ...@@ -163,13 +164,16 @@ def train():
# Get images and labels for CIFAR-10. # Get images and labels for CIFAR-10.
images, labels = cifar10.distorted_inputs() images, labels = cifar10.distorted_inputs()
batch_queue = tf.contrib.slim.prefetch_queue.prefetch_queue(
[images, labels], capacity=2 * FLAGS.num_gpus)
# Calculate the gradients for each model tower. # Calculate the gradients for each model tower.
tower_grads = [] tower_grads = []
with tf.variable_scope(tf.get_variable_scope()): with tf.variable_scope(tf.get_variable_scope()):
for i in xrange(FLAGS.num_gpus): for i in xrange(FLAGS.num_gpus):
with tf.device('/gpu:%d' % i): with tf.device('/gpu:%d' % i):
with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope: with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
# Dequeues one batch for the GPU
images, labels = batch_queue.dequeue()
# Calculate the loss for one tower of the CIFAR model. This function # Calculate the loss for one tower of the CIFAR model. This function
# constructs the entire CIFAR model but shares the variables across # constructs the entire CIFAR model but shares the variables across
# all towers. # all towers.
......
...@@ -64,7 +64,7 @@ def train(): ...@@ -64,7 +64,7 @@ def train():
# Get images and labels for CIFAR-10. # Get images and labels for CIFAR-10.
# Force input pipeline to CPU:0 to avoid operations sometimes ending up on # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
# GPU and resulting in a slow down. # GPU and resulting in a slow down.
with tf.device('/CPU:0'): with tf.device('/cpu:0'):
images, labels = cifar10.distorted_inputs() images, labels = cifar10.distorted_inputs()
# Build a Graph that computes the logits predictions from the # Build a Graph that computes the logits predictions from the
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment