OpenDAS / apex · Commit 378ce1e1 (unverified)

Authored Jun 11, 2018 by mcarilli; committed by GitHub on Jun 11, 2018.

Merge pull request #9 from NVIDIA/imagenet_fix

DDP fix, imagenet speedup

Parents: fb075b86, 06ee98c2

Showing 2 changed files with 20 additions and 22 deletions:

  apex/parallel/distributed.py   +16 -14
  examples/imagenet/main.py       +4  -8
apex/parallel/distributed.py

@@ -79,17 +79,19 @@ class DistributedDataParallel(Module):
     def create_hooks(self):
         #all reduce gradient hook
         def allreduce_params():
-            if(self.needs_reduction):
-                self.needs_reduction = False
-                self.needs_refresh = False
-            else:
+            if not self.needs_reduction:
+                return
+            self.needs_reduction = False
+
+            #parameter ordering refresh
+            if self.needs_refresh and not self.shared_param:
+                t_record = torch.cuda.IntTensor(self.record)
+                dist.broadcast(t_record, 0)
+                self.record = [int(entry) for entry in t_record]
+                self.needs_refresh = False
+
             grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
             flat_dist_call(grads, dist.all_reduce)
-            t_record = torch.cuda.IntTensor(self.record)
-            dist.broadcast(t_record, 0)
-            self.record = [int(entry) for entry in t_record]

         def flush_buckets():
             if not self.needs_reduction:
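Note on the hunk above: flat_dist_call(grads, dist.all_reduce) issues one collective over all gradients rather than one call per parameter. A minimal sketch of that flattened all-reduce pattern, written against plain torch.distributed and assuming a process group is already initialized (the helper name flat_allreduce and the explicit averaging step are illustrative, not apex's exact implementation):

    import torch
    import torch.distributed as dist
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    def flat_allreduce(tensors):
        # Pack all gradient tensors into one contiguous buffer so a single
        # all_reduce call is issued instead of one call per parameter.
        flat = _flatten_dense_tensors(tensors)
        dist.all_reduce(flat)
        # all_reduce sums across ranks by default; divide to get the average.
        flat /= dist.get_world_size()
        # Copy the reduced values back into the original gradient tensors.
        for t, reduced in zip(tensors, _unflatten_dense_tensors(flat, tensors)):
            t.copy_(reduced)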
@@ -184,12 +186,12 @@ class DistributedDataParallel(Module):
         #Force needs_refresh to True if there are shared params
         #this will force it to always, only call flush_buckets which is safe
         #for shared parameters in the model.
-        if self.shared_param:
-            self.param_refs = []
-        self.needs_refresh = True if not self.param_refs else any([param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)])
+        if not self.param_refs or self.shared_param:
+            self.needs_refresh = True
+        else:
+            self.needs_refresh = any([param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)])

         if self.needs_refresh:
             self.record = []
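The second hunk replaces the one-line conditional expression with an explicit if/else, but the refresh test itself is unchanged: the current parameter list is compared against self.param_refs by object identity (is not), not by value, so a refresh is triggered only when the actual Parameter objects differ or have been reordered. A standalone toy illustration of that check (variable names here are hypothetical):

    import torch

    old_refs = [torch.nn.Parameter(torch.zeros(2)) for _ in range(3)]

    # Same objects in the same order: no refresh needed.
    same_list = list(old_refs)
    print(any(p1 is not p2 for p1, p2 in zip(same_list, old_refs)))    # False

    # Rebuilt parameters with equal values but different objects: refresh.
    rebuilt = [torch.nn.Parameter(p.data.clone()) for p in old_refs]
    print(any(p1 is not p2 for p1, p2 in zip(rebuilt, old_refs)))      # True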
examples/imagenet/main.py

@@ -131,7 +131,8 @@ def main():
     if args.fp16:
         model = network_to_half(model)
     if args.distributed:
-        model = DDP(model)
+        #shared param turns off bucketing in DDP, for lower latency runs this can improve perf
+        model = DDP(model, shared_param=True)

     global model_params, master_params
     if args.fp16:
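The example now passes shared_param=True when constructing apex's DDP wrapper; per the in-diff comment, this turns off gradient bucketing, which can help low-latency runs and is the safe mode when parameters are shared. A minimal usage sketch under the usual one-process-per-GPU launch (the resnet50 model and the --local_rank plumbing via torch.distributed.launch stand in for this script's own setup):

    import argparse
    import torch
    import torch.distributed as dist
    import torchvision.models as models
    from apex.parallel import DistributedDataParallel as DDP

    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0)  # filled in by torch.distributed.launch
    args = parser.parse_args()

    dist.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(args.local_rank)

    model = models.resnet50().cuda()
    # shared_param=True disables gradient bucketing in apex DDP, as noted in the diff.
    model = DDP(model, shared_param=True)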
@@ -189,19 +190,14 @@ def main():
         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
         num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)

-    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                     std=[0.229, 0.224, 0.225])
-
     val_loader = torch.utils.data.DataLoader(
         datasets.ImageFolder(valdir, transforms.Compose([
             transforms.Resize(val_size),
             transforms.CenterCrop(crop_size),
-            transforms.ToTensor(),
-            normalize,
         ])),
         batch_size=args.batch_size, shuffle=False,
-        num_workers=args.workers, pin_memory=True)
+        num_workers=args.workers, pin_memory=True,
+        collate_fn=fast_collate)

     if args.evaluate:
         validate(val_loader, model, criterion)
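The validation pipeline above drops ToTensor/Normalize from the CPU-side transforms and instead hands the DataLoader collate_fn=fast_collate, so batches stay as uint8 and the float conversion and normalization can be done later on the GPU. fast_collate itself is defined elsewhere in main.py and is not part of this diff; the sketch below shows the general shape of such a collate function (the body here is an assumption for illustration, not the exact apex code):

    import numpy as np
    import torch

    def fast_collate_sketch(batch):
        # batch: list of (PIL image, class index) pairs from ImageFolder.
        imgs = [img for img, _ in batch]
        targets = torch.tensor([target for _, target in batch], dtype=torch.int64)
        w, h = imgs[0].size
        # Keep the whole batch as uint8 NCHW; normalization is deferred to the GPU.
        tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
        for i, img in enumerate(imgs):
            arr = np.asarray(img, dtype=np.uint8)
            if arr.ndim < 3:
                arr = np.expand_dims(arr, axis=-1)   # grayscale -> HWC with C=1
            arr = np.rollaxis(arr, 2)                # HWC -> CHW
            tensor[i] += torch.from_numpy(arr)
        return tensor, targets

At iteration time the uint8 batch would then be moved to the GPU, cast to float (or half for --fp16), and normalized with the mean/std values that were removed from the CPU transform above.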