OpenDAS / apex · Commit 1c2ba890
"tests/vscode:/vscode.git/clone" did not exist on "ad15947f0ea9b34e15157dfad65b25f3a98e9ac8"
Authored Jun 13, 2019 by Thor Johnsen, committed by mcarilli on Jun 13, 2019

Add option to turn on/off allreduce in DDP (useful for gradient accumulation) (#356)
Parent: 47e3367f
Showing 1 changed file with 79 additions and 69 deletions. (Whitespace-only changes are hidden in this view, so re-indented lines appear as unchanged context.)
apex/parallel/distributed.py  (+79, -69)
@@ -216,6 +216,8 @@ class DistributedDataParallel(Module):
         self.module = module

+        self.disable_allreduce = False
+
         if self._backend == self.backend_enum_holder.NCCL:
             for param in self.module.parameters():
                 assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
@@ -250,6 +252,12 @@ class DistributedDataParallel(Module):
         del attrs['self.reduction_event']
         return attrs

+    def turn_on_allreduce(self):
+        self.disable_allreduce = False
+
+    def turn_off_allreduce(self):
+        self.disable_allreduce = True
+
     # Broadcast rank 0's bucket structure across all processes, and have all processes
     # regenerate their bucket structures to match.
     def sync_bucket_structure(self):
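The two added methods are thin setters over the new disable_allreduce flag introduced in the constructor. A minimal calling sketch, assuming net is a module already wrapped in apex.parallel.DistributedDataParallel (the variable name is illustrative):

    # 'net' is an apex.parallel.DistributedDataParallel instance (name assumed).
    net.turn_off_allreduce()   # disable_allreduce = True: backward passes skip the allreduce
    net.turn_on_allreduce()    # disable_allreduce = False: the next backward reduces gradients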
@@ -327,6 +335,7 @@ class DistributedDataParallel(Module):
             grad_acc = param_tmp.grad_fn.next_functions[0][0]

             def allreduce_hook(*unused):
+                if not self.disable_allreduce:
                 if self.delay_allreduce or self.needs_refresh:
                     # TODO: How do we want to handle multiple backward passes between
                     # each forward, e.g., backward passes with retain_graph=True?
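While disable_allreduce is True, the guard makes the per-parameter backward hook a no-op: autograd still accumulates gradients into each param.grad locally, only the inter-process reduction is skipped. An equivalent control-flow sketch (not the actual apex hook body):

    def allreduce_hook(*unused):
        # Sketch only: early-exit form of the guard added above.
        if self.disable_allreduce:
            return  # gradients keep accumulating in param.grad on this rank
        ...  # original bucketing / reduction logic runs here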
@@ -470,6 +479,7 @@ class DistributedDataParallel(Module):
     def forward(self, *inputs, **kwargs):
         result = self.module(*inputs, **kwargs)

+        if not self.disable_allreduce:
         if not self.delay_allreduce:
             param_list = [param for param in self.module.parameters() if param.requires_grad]
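Together these changes enable the gradient-accumulation pattern named in the commit title: skip the allreduce for the first N-1 micro-batches, then re-enable it so the last backward pass reduces the locally accumulated gradients in one shot. A minimal sketch, assuming model is wrapped in apex.parallel.DistributedDataParallel; optimizer, loader, criterion, and accumulation_steps are illustrative names, not part of this commit:

    accumulation_steps = 4
    optimizer.zero_grad()
    for step, (inputs, targets) in enumerate(loader):
        last_micro_batch = (step + 1) % accumulation_steps == 0
        if last_micro_batch:
            model.turn_on_allreduce()    # this backward allreduces the accumulated grads
        else:
            model.turn_off_allreduce()   # accumulate locally; skip communication

        loss = criterion(model(inputs), targets) / accumulation_steps
        loss.backward()

        if last_micro_batch:
            optimizer.step()
            optimizer.zero_grad()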