"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "69f022781394e6280e4897a125dedb48d760f16e"
Commit 97a87f9c authored by Chris Mattmann's avatar Chris Mattmann Committed by Toby Boyd
Browse files

Fix for TF-models #7216: CIFAR-10 tutorial for multi-GPU fails because full...

Fix for TF-models #7216: CIFAR-10 tutorial for multi-GPU fails because full shape isn't passed to prefetch_queue contributed by mattmann. (#7217)
parent 712f473e
...@@ -163,6 +163,8 @@ def train(): ...@@ -163,6 +163,8 @@ def train():
# Get images and labels for CIFAR-10. # Get images and labels for CIFAR-10.
images, labels = cifar10.distorted_inputs() images, labels = cifar10.distorted_inputs()
images = tf.reshape(images, [cifar10.FLAGS.batch_size, 24, 24, 3])
labels = tf.reshape(labels, [cifar10.FLAGS.batch_size])
batch_queue = tf.contrib.slim.prefetch_queue.prefetch_queue( batch_queue = tf.contrib.slim.prefetch_queue.prefetch_queue(
[images, labels], capacity=2 * FLAGS.num_gpus) [images, labels], capacity=2 * FLAGS.num_gpus)
# Calculate the gradients for each model tower. # Calculate the gradients for each model tower.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment