OpenDAS / apex · Commit 73b62dde

Add reducer class in parallel/distributed. (#37)

Authored Aug 28, 2018 by Christian Sarofeen; committed by mcarilli, Aug 28, 2018
Parent: dc41c5ce
Showing 3 changed files with 55 additions and 3 deletions:

    apex/parallel/__init__.py     +1  -1
    apex/parallel/distributed.py  +49 -0
    examples/imagenet/main.py     +5  -2
apex/parallel/__init__.py @ 73b62dde

-from .distributed import DistributedDataParallel
+from .distributed import DistributedDataParallel, Reducer
apex/parallel/distributed.py @ 73b62dde

@@ -34,6 +34,55 @@ def flat_dist_call(tensors, call, extra_args=None):
         for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
             buf.copy_(synced)
 
+def extract_tensors(maybe_tensor, tensor_list):
+    if torch.is_tensor(maybe_tensor):
+        tensor_list.append(maybe_tensor)
+    else:
+        try:
+            for item in maybe_tensor:
+                extract_tensors(item, tensor_list)
+        except TypeError:
+            return
+
+
+class Reducer(object):
+    """
+    :class:`apex.parallel.Reducer` is a simple class that helps reduce a module's parameters.
+    This class will not automatically reduce parameters in a module for the user; instead it
+    allows the user to call Reducer(module).reduce(), which immediately reduces all parameters.
+    :class:`apex.parallel.Reducer` is designed to work with the launch utility script
+    ``apex.parallel.multiproc.py`` or the launch utility script ``torch.distributed.launch``
+    with --nproc_per_node <= the number of GPUs per node.
+    When used with these launchers, :class:`apex.parallel.multiproc.py`
+    assumes a 1:1 mapping of processes to GPUs.
+
+    Args:
+        module_or_grads_list: Either a network definition being run in multi-gpu/distributed
+            mode, or an iterable of gradients to be reduced. If a list of gradients is passed
+            in, the user must manually sync parameters with broadcast or another means. If a
+            module is passed in, its parameters will be broadcast from rank 0.
+    """
+    def __init__(self, module_or_grads_list):
+        if isinstance(module_or_grads_list, Module):
+            self.module = module_or_grads_list
+            flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,))
+        else:
+            self.module = None
+            self.grads = []
+            extract_tensors(module_or_grads_list, self.grads)
+
+    def reduce(self):
+        if self.module:
+            grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
+            flat_dist_call(grads, dist.all_reduce)
+        else:
+            flat_dist_call(self.grads, dist.all_reduce)
+
+
 class DistributedDataParallel(Module):
     """
     :class:`apex.parallel.DistributedDataParallel` is a module wrapper that enables
 ...
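For reference, the new class can be constructed either from a module or from an iterable of gradient tensors, matching the two branches of __init__ above. The following minimal sketch (not part of the commit) illustrates both modes; the process-group setup, placeholder model, and dummy data are assumptions for illustration, and the script is expected to be launched with one process per GPU via apex.parallel.multiproc.py or torch.distributed.launch.

import torch
import torch.distributed as dist
from apex.parallel import Reducer

# Assumed launch: one process per GPU, with rank/world-size supplied by the
# launcher through environment variables.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(10, 10).cuda()   # placeholder model, not a real workload

# Module mode: constructing the Reducer broadcasts parameters from rank 0.
reducer = Reducer(model)

loss = model(torch.randn(4, 10).cuda()).sum()
loss.backward()
reducer.reduce()   # all-reduces .grad for every parameter that has one

# Gradient-list mode: pass the gradient tensors directly. No broadcast is
# performed, so parameter synchronization is the caller's responsibility.
# reducer = Reducer([p.grad for p in model.parameters()])
# reducer.reduce()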
examples/imagenet/main.py @ 73b62dde

@@ -19,7 +19,7 @@ import torchvision.models as models
 import numpy as np
 
 try:
-    from apex.parallel import DistributedDataParallel as DDP
+    from apex.parallel import Reducer as DDP
     from apex.fp16_utils import *
 except ImportError:
     raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
...
@@ -122,7 +122,8 @@ def main():
         model = network_to_half(model)
     if args.distributed:
         # shared param turns off bucketing in DDP, for lower latency runs this can improve perf
-        model = DDP(model, shared_param=True)
+        global reducer
+        reducer = DDP(model)
 
     global model_params, master_params
     if args.fp16:
...
@@ -307,6 +308,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
         if args.fp16:
             model.zero_grad()
             loss.backward()
+            reducer.reduce()
             model_grads_to_master_grads(model_params, master_params)
             if args.static_loss_scale != 1:
                 for param in master_params:
...
@@ -316,6 +318,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
         else:
             optimizer.zero_grad()
             loss.backward()
+            reducer.reduce()
             optimizer.step()
 
         torch.cuda.synchronize()
...
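Taken together, the example's changes drop the DistributedDataParallel wrapper and instead call reducer.reduce() between loss.backward() and the gradient-consuming step in both the fp16 and fp32 branches. A condensed, self-contained sketch of that pattern follows; it is an illustration rather than the example itself, with the model, criterion, optimizer, and synthetic data as placeholders. It assumes launch with torch.distributed.launch (one process per GPU, --nproc_per_node <= GPUs per node), as noted in the Reducer docstring.

import torch
import torch.distributed as dist
from apex.parallel import Reducer

dist.init_process_group(backend="nccl")   # assumes rank/world size come from the launcher

model = torch.nn.Linear(10, 2).cuda()                      # placeholder for the ImageNet model
criterion = torch.nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
reducer = Reducer(model)                                   # broadcasts parameters from rank 0

for _ in range(10):                                        # stand-in for the train_loader loop
    input = torch.randn(8, 10).cuda()
    target = torch.randint(0, 2, (8,)).cuda()

    output = model(input)
    loss = criterion(output, target)

    optimizer.zero_grad()
    loss.backward()
    reducer.reduce()    # gradients are combined across ranks here, mirroring
                        # the fp32 branch of the modified train()
    optimizer.step()

torch.cuda.synchronize()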