OpenDAS / apex

Commit 9661dbd7 authored May 14, 2018 by Michael Carilli

Updating distributed example

parent 789afd89
Showing 2 changed files with 16 additions and 12 deletions:

examples/distributed/main.py              +15 -11
examples/distributed/run_distributed.sh    +1 -1
examples/distributed/main.py

@@ -6,6 +6,7 @@ import torch.nn.functional as F
 import torch.optim as optim
 from torchvision import datasets, transforms
 from torch.autograd import Variable
+from apex.fp16_utils import to_python_float

 #=====START: ADDED FOR DISTRIBUTED======
 '''Add custom module for distributed'''
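The new import, to_python_float, is apex's helper for extracting a Python scalar from a loss tensor; it papers over the PyTorch 0.4 change that made losses 0-dimensional tensors, where the old loss.data[0] indexing is deprecated. A rough sketch of the idea (an illustrative stand-in, not necessarily apex's actual source):

import torch

def to_python_float_sketch(t):
    # 0-dim tensors (PyTorch >= 0.4) expose .item(); older versions
    # still allow indexing the first element with [0].
    if hasattr(t, 'item'):
        return t.item()
    return t[0]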
@@ -83,8 +84,10 @@ if args.distributed:
     torch.cuda.set_device(args.rank % torch.cuda.device_count())
     '''Initialize distributed communication'''
-    dist.init_process_group(args.dist_backend, init_method=args.dist_url,
-                            world_size=args.world_size)
+    dist.init_process_group(args.dist_backend, init_method=args.dist_url,
+                            world_size=args.world_size, rank=args.rank)
 #=====END: ADDED FOR DISTRIBUTED======
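The substantive change here is passing rank=args.rank explicitly to init_process_group, so each spawned process identifies itself rather than relying on the init method to assign ranks. A minimal self-contained sketch of the same pattern (the helper name and argument defaults are illustrative assumptions, not part of the commit):

import torch
import torch.distributed as dist

def init_distributed(rank, world_size,
                     backend='nccl', url='tcp://127.0.0.1:23456'):
    # One process per GPU: fold the global rank onto the local device set.
    torch.cuda.set_device(rank % torch.cuda.device_count())
    # rank= is the addition in this commit; every process now announces
    # its own identity to the process group.
    dist.init_process_group(backend, init_method=url,
                            world_size=world_size, rank=rank)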
@@ -174,18 +177,19 @@ def train(epoch):
         if batch_idx % args.log_interval == 0:
             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                 epoch, batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), loss.data[0]))
+                100. * batch_idx / len(train_loader), to_python_float(loss.data)))

 def test():
     model.eval()
     test_loss = 0
     correct = 0
     for data, target in test_loader:
-        if args.cuda:
-            data, target = data.cuda(), target.cuda()
-        data, target = Variable(data, volatile=True), Variable(target)
-        output = model(data)
-        test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss
+        with torch.no_grad():
+            if args.cuda:
+                data, target = data.cuda(), target.cuda()
+            data, target = Variable(data), Variable(target)
+            output = model(data)
+            test_loss += to_python_float(F.nll_loss(output, target, size_average=False).data) # sum up batch loss
         pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
         correct += pred.eq(target.data.view_as(pred)).cpu().sum()
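These edits are the PyTorch 0.4 migration: volatile=True Variables are replaced by the torch.no_grad() context manager, and .data[0] on a now 0-dimensional loss is replaced by to_python_float. For reference, a hedged sketch of the same evaluation loop written directly against the post-0.4 API, with no Variable wrapper at all (model and test_loader are assumed to exist; this is not part of the commit):

import torch
import torch.nn.functional as F

def evaluate(model, test_loader, device):
    model.eval()
    test_loss, correct = 0.0, 0
    with torch.no_grad():  # supersedes volatile=True
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # reduction='sum' is the modern spelling of size_average=False;
            # .item() is the modern spelling of .data[0] for 0-dim tensors
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    return test_loss / len(test_loader.dataset), correct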
examples/distributed/run_distributed.sh

-python -m apex.parallel.multiproc main.py
+export CUDA_VISIBLE_DEVICES=0,1; python -m apex.parallel.multiproc main.py
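The launch script now pins the example to an explicit GPU subset via CUDA_VISIBLE_DEVICES before invoking apex.parallel.multiproc, which runs one copy of main.py per visible device. A hedged Python sketch of what a launcher in this style does (an illustration, not apex's actual multiproc source; the flag names mirror main.py's args but are assumptions here):

import subprocess
import sys

import torch

def launch(script='main.py'):
    # device_count() respects CUDA_VISIBLE_DEVICES, so exporting 0,1
    # yields a world size of 2.
    world_size = torch.cuda.device_count()
    # Spawn one child process per GPU, handing each its own rank.
    procs = [subprocess.Popen([sys.executable, script,
                               '--world-size', str(world_size),
                               '--rank', str(rank)])
             for rank in range(world_size)]
    for p in procs:
        p.wait()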