Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9e8fd6d9
Commit
9e8fd6d9
authored
Jun 08, 2017
by
Toby Boyd
Browse files
Fixed typo and multi-gpu processing same batch on each gpu
parent
c3e2ae5e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
2 deletions
+6
-2
tutorials/image/cifar10/cifar10_multi_gpu_train.py
tutorials/image/cifar10/cifar10_multi_gpu_train.py
+5
-1
tutorials/image/cifar10/cifar10_train.py
tutorials/image/cifar10/cifar10_train.py
+1
-1
No files found.
tutorials/image/cifar10/cifar10_multi_gpu_train.py
View file @
9e8fd6d9
...
...
@@ -138,6 +138,7 @@ def average_gradients(tower_grads):
def
train
():
print
(
FLAGS
.
batch_size
)
"""Train CIFAR-10 for a number of steps."""
with
tf
.
Graph
().
as_default
(),
tf
.
device
(
'/cpu:0'
):
# Create a variable to count the number of train() calls. This equals the
...
...
@@ -163,13 +164,16 @@ def train():
# Get images and labels for CIFAR-10.
images
,
labels
=
cifar10
.
distorted_inputs
()
batch_queue
=
tf
.
contrib
.
slim
.
prefetch_queue
.
prefetch_queue
(
[
images
,
labels
],
capacity
=
2
*
FLAGS
.
num_gpus
)
# Calculate the gradients for each model tower.
tower_grads
=
[]
with
tf
.
variable_scope
(
tf
.
get_variable_scope
()):
for
i
in
xrange
(
FLAGS
.
num_gpus
):
with
tf
.
device
(
'/gpu:%d'
%
i
):
with
tf
.
name_scope
(
'%s_%d'
%
(
cifar10
.
TOWER_NAME
,
i
))
as
scope
:
# Dequeues one batch for the GPU
images
,
labels
=
batch_queue
.
dequeue
()
# Calculate the loss for one tower of the CIFAR model. This function
# constructs the entire CIFAR model but shares the variables across
# all towers.
...
...
tutorials/image/cifar10/cifar10_train.py
View file @
9e8fd6d9
...
...
@@ -64,7 +64,7 @@ def train():
# Get images and labels for CIFAR-10.
# Force input pipeline to CPU:0 to avoid operations sometimes ending up on
# GPU and resulting in a slow down.
with
tf
.
device
(
'/
CPU
:0'
):
with
tf
.
device
(
'/
cpu
:0'
):
images
,
labels
=
cifar10
.
distorted_inputs
()
# Build a Graph that computes the logits predictions from the
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment