Commit 57f839bf authored by Sergio Guadarrama, committed by Martin Wicke

Make the split of the batch outside the GPUs (#51)

Tentative fix for issue #47
parent c348081a
@@ -221,27 +221,22 @@ def train(dataset):
     # Label 0 is reserved for an (unused) background class.
     num_classes = dataset.num_classes() + 1
+    # Split the batch of images and labels for towers.
+    images_splits = tf.split(0, FLAGS.num_gpus, images)
+    labels_splits = tf.split(0, FLAGS.num_gpus, labels)
     # Calculate the gradients for each model tower.
     tower_grads = []
     for i in xrange(FLAGS.num_gpus):
       with tf.device('/gpu:%d' % i):
         with tf.name_scope('%s_%d' % (inception.TOWER_NAME, i)) as scope:
-          # Split the batch of images and labels.
-          batch_start = split_batch_size * i
-          images_batch = tf.slice(images,
-                                  begin=[batch_start, 0, 0, 0],
-                                  size=[split_batch_size, -1, -1, -1])
-          labels_batch = tf.slice(labels,
-                                  begin=[batch_start],
-                                  size=[split_batch_size])
           # Force all Variables to reside on the CPU.
           with slim.arg_scope([slim.variables.variable], device='/cpu:0'):
             # Calculate the loss for one tower of the ImageNet model. This
             # function constructs the entire ImageNet model but shares the
             # variables across all towers.
-            loss = _tower_loss(images_batch, labels_batch, num_classes, scope)
+            loss = _tower_loss(images_splits[i], labels_splits[i], num_classes,
+                               scope)
           # Reuse variables for the next tower.
           tf.get_variable_scope().reuse_variables()
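For context, here is a minimal, self-contained sketch of the pattern the commit switches to: split the batch once, outside the per-GPU device scopes, and hand each tower its pre-computed slice. It follows the same Python 2 / TF 0.x-era API the diff uses (xrange and the tf.split(split_dim, num_split, value) signature). NUM_GPUS, BATCH_SIZE, NUM_CLASSES, the placeholder shapes, and the _toy_tower_loss model are illustrative assumptions rather than code from this repository, and the real script additionally pins variables to the CPU via slim.arg_scope, which is omitted here.

import tensorflow as tf

NUM_GPUS = 2      # illustrative; the real script reads FLAGS.num_gpus
BATCH_SIZE = 32   # illustrative; must be divisible by NUM_GPUS
NUM_CLASSES = 10  # illustrative stand-in for dataset.num_classes() + 1


def _toy_tower_loss(images, labels, scope):
  """Stand-in for the real _tower_loss: one linear layer plus softmax loss.

  The scope argument is accepted only to mirror the real call signature.
  """
  dim = images.get_shape()[1].value
  weights = tf.get_variable('weights', [dim, NUM_CLASSES])
  biases = tf.get_variable('biases', [NUM_CLASSES])
  logits = tf.matmul(images, weights) + biases
  losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels)
  return tf.reduce_mean(losses)


images = tf.placeholder(tf.float32, [BATCH_SIZE, 2048], name='images')
labels = tf.placeholder(tf.int32, [BATCH_SIZE], name='labels')

# Split the whole batch once, on the default device, before entering the
# per-GPU device scopes -- the same move the diff above makes.
images_splits = tf.split(0, NUM_GPUS, images)
labels_splits = tf.split(0, NUM_GPUS, labels)

tower_losses = []
for i in xrange(NUM_GPUS):
  with tf.device('/gpu:%d' % i):
    with tf.name_scope('tower_%d' % i) as scope:
      # Each tower only consumes its own pre-computed slice of the batch.
      tower_losses.append(
          _toy_tower_loss(images_splits[i], labels_splits[i], scope))
      # Share model variables across towers, as in the training script.
      tf.get_variable_scope().reuse_variables()

total_loss = tf.add_n(tower_losses) / NUM_GPUS

Doing the split before the loop means the split ops are constructed once, outside any tf.device('/gpu:%d') scope, instead of a separate slice being built under each GPU's device assignment, which appears to be the behaviour the tentative fix for issue #47 is aiming for.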