Commit 1c2ba890
Authored Jun 13, 2019 by Thor Johnsen; committed by mcarilli on Jun 13, 2019

Add option to turn on/off allreduce in DDP (useful for gradient accumulation) (#356)
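The new turn_off_allreduce() / turn_on_allreduce() methods let a training script skip the gradient allreduce on accumulation micro-batches and re-enable it on the step that actually updates the weights. The following is a minimal sketch of that pattern, not part of this commit: the module, optimizer, data, and accumulation_steps values are placeholders, and it assumes torch.distributed and CUDA have already been initialized.

    import torch
    import torch.nn as nn
    from apex.parallel import DistributedDataParallel as DDP

    model = DDP(nn.Linear(10, 10).cuda())                    # placeholder module
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.MSELoss()
    accumulation_steps = 4                                    # hypothetical micro-batches per update

    for step in range(16):
        inputs = torch.randn(8, 10, device="cuda")            # placeholder data
        targets = torch.randn(8, 10, device="cuda")

        # Allreduce only on the micro-batch that completes an accumulation window;
        # toggle before forward, since forward() also consults the flag.
        if (step + 1) % accumulation_steps == 0:
            model.turn_on_allreduce()
        else:
            model.turn_off_allreduce()

        loss = criterion(model(inputs), targets)
        loss.backward()

        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()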
parent 47e3367f

Changes: 1 changed file with 79 additions and 69 deletions
apex/parallel/distributed.py  +79  -69  (view file @ 1c2ba890)
@@ -215,6 +215,8 @@ class DistributedDataParallel(Module):
         self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
 
         self.module = module
 
+        self.disable_allreduce = False
+
         if self._backend == self.backend_enum_holder.NCCL:
             for param in self.module.parameters():
@@ -249,6 +251,12 @@ class DistributedDataParallel(Module):
             del attrs['self.reduction_stream']
             del attrs['self.reduction_event']
         return attrs
 
+    def turn_on_allreduce(self):
+        self.disable_allreduce = False
+
+    def turn_off_allreduce(self):
+        self.disable_allreduce = True
+
     # Broadcast rank 0's bucket structure across all processes, and have all processes
     # regenerate their bucket structures to match.
@@ -327,44 +335,45 @@ class DistributedDataParallel(Module):
                 grad_acc = param_tmp.grad_fn.next_functions[0][0]
 
                 def allreduce_hook(*unused):
-                    if self.delay_allreduce or self.needs_refresh:
-                        # TODO: How do we want to handle multiple backward passes between
-                        # each forward, e.g., backward passes with retain_graph=True?
-                        # needs_refresh and callback_queued are both vulnerable states.
-                        if not self.delay_allreduce and self.needs_refresh:
-                            # Use the backward pass to build the bucket structure on the fly.
-                            active_i = self.param_id_to_active_i[id(param)]
-
-                            # Float, half, and double tensors are grouped into buckets separately.
-                            current_type = self.param_type_to_tmp_i[param.type()]
-
-                            self.tmp_buckets[current_type].append(active_i)
-
-                            ship_tmp_bucket = False
-                            if self.custom_allreduce_triggers:
-                                if id(param) in self.allreduce_trigger_params:
-                                    ship_tmp_bucket = True
-                            else:
-                                self.tmp_numels[current_type] += param.numel()
-                                if self.tmp_numels[current_type] >= self.message_size:
-                                    ship_tmp_bucket = True
-
-                            # To consider: If custom_allreduce_triggers are in use, ship all
-                            # tmp_buckets, not just tmp_buckets[current_type].
-                            if ship_tmp_bucket:
-                                self.active_i_buckets.append(self.tmp_buckets[current_type])
-                                self.tmp_buckets[current_type] = []
-                                self.tmp_numels[current_type] = 0
-
-                        if not self.callback_queued:
-                            Variable._execution_engine.queue_callback(allreduce_params)
-                            self.callback_queued = True
-                    else:
-                        if not self.callback_queued:
-                            Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
-                            self.callback_queued = True
-
-                        self.comm_ready_buckets(param)
+                    if not self.disable_allreduce:
+                        if self.delay_allreduce or self.needs_refresh:
+                            # TODO: How do we want to handle multiple backward passes between
+                            # each forward, e.g., backward passes with retain_graph=True?
+                            # needs_refresh and callback_queued are both vulnerable states.
+                            if not self.delay_allreduce and self.needs_refresh:
+                                # Use the backward pass to build the bucket structure on the fly.
+                                active_i = self.param_id_to_active_i[id(param)]
+
+                                # Float, half, and double tensors are grouped into buckets separately.
+                                current_type = self.param_type_to_tmp_i[param.type()]
+
+                                self.tmp_buckets[current_type].append(active_i)
+
+                                ship_tmp_bucket = False
+                                if self.custom_allreduce_triggers:
+                                    if id(param) in self.allreduce_trigger_params:
+                                        ship_tmp_bucket = True
+                                else:
+                                    self.tmp_numels[current_type] += param.numel()
+                                    if self.tmp_numels[current_type] >= self.message_size:
+                                        ship_tmp_bucket = True
+
+                                # To consider: If custom_allreduce_triggers are in use, ship all
+                                # tmp_buckets, not just tmp_buckets[current_type].
+                                if ship_tmp_bucket:
+                                    self.active_i_buckets.append(self.tmp_buckets[current_type])
+                                    self.tmp_buckets[current_type] = []
+                                    self.tmp_numels[current_type] = 0
+
+                            if not self.callback_queued:
+                                Variable._execution_engine.queue_callback(allreduce_params)
+                                self.callback_queued = True
+                        else:
+                            if not self.callback_queued:
+                                Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
+                                self.callback_queued = True
+
+                            self.comm_ready_buckets(param)
 
                 grad_acc.register_hook(allreduce_hook)
                 self.grad_accs.append(grad_acc)
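For context, allreduce_hook above is registered on each parameter's gradient-accumulator node (reached through grad_fn.next_functions on a trivial view of the parameter), so it fires as soon as that parameter's gradient is ready during backward. A stripped-down illustration of the same registration pattern in plain PyTorch, separate from apex and from this commit:

    import torch

    param = torch.nn.Parameter(torch.randn(3))

    # expand_as(param) is a no-op view whose grad_fn.next_functions[0][0]
    # is the AccumulateGrad node that writes param.grad during backward.
    param_tmp = param.expand_as(param)
    grad_acc = param_tmp.grad_fn.next_functions[0][0]

    def hook(*unused):
        # Fires after the gradient for this parameter has been accumulated.
        print("gradient ready:", param.grad)

    grad_acc.register_hook(hook)

    (param * 2.0).sum().backward()   # prints the gradient from inside backward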
@@ -422,7 +431,7 @@ class DistributedDataParallel(Module):
         # training script, and overwritten in the next forward pass. So it's harmless.
         if self.retain_allreduce_buffers:
             self.allreduce_buffers = [None for _ in range(len(split_buckets))]
 
         for i, bucket in enumerate(split_buckets):
             allreduced = self.allreduce_maybe_retain(bucket, i)
@@ -469,38 +478,39 @@ class DistributedDataParallel(Module):
     def forward(self, *inputs, **kwargs):
         result = self.module(*inputs, **kwargs)
 
-        if not self.delay_allreduce:
-            param_list = [param for param in self.module.parameters() if param.requires_grad]
-
-            # Conditions under which to refresh self.record
-            # Forward has the authority to set needs_refresh to True, but only allreduce_params
-            # in backward has the authority to set needs_refresh to False.
-            # Parentheses are not necessary for correct order of operations, but make the intent clearer.
-            if ((not self.active_params) or
-                (len(param_list) != len(self.active_params)) or
-                any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])):
-                self.needs_refresh = True
-
-            if self.needs_refresh:
-                self.active_i_buckets = []
-                self.buckets = []
-                self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
-                self.tmp_numels = [0, 0, 0]
-                self.bucket_sizes = []
-                self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
-                self.param_id_to_bucket = {}
-            else:
-                self.buckets = [[None for _ in range(self.bucket_sizes[i])]
-                                for i in range(self.num_buckets)]
-                self.buckets_ready_size = [0 for i in range(self.num_buckets)]
-                if(self.retain_allreduce_buffers):
-                    self.allreduce_buffers = [None for _ in range(self.num_buckets)]
-                self.next_bucket = 0
-                self.ready_buckets_not_reduced = set()
-
-            self.active_params = param_list
-
-        self.callback_queued = False
+        if not self.disable_allreduce:
+            if not self.delay_allreduce:
+                param_list = [param for param in self.module.parameters() if param.requires_grad]
+
+                # Conditions under which to refresh self.record
+                # Forward has the authority to set needs_refresh to True, but only allreduce_params
+                # in backward has the authority to set needs_refresh to False.
+                # Parentheses are not necessary for correct order of operations, but make the intent clearer.
+                if ((not self.active_params) or
+                    (len(param_list) != len(self.active_params)) or
+                    any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])):
+                    self.needs_refresh = True
+
+                if self.needs_refresh:
+                    self.active_i_buckets = []
+                    self.buckets = []
+                    self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
+                    self.tmp_numels = [0, 0, 0]
+                    self.bucket_sizes = []
+                    self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
+                    self.param_id_to_bucket = {}
+                else:
+                    self.buckets = [[None for _ in range(self.bucket_sizes[i])]
+                                    for i in range(self.num_buckets)]
+                    self.buckets_ready_size = [0 for i in range(self.num_buckets)]
+                    if(self.retain_allreduce_buffers):
+                        self.allreduce_buffers = [None for _ in range(self.num_buckets)]
+                    self.next_bucket = 0
+                    self.ready_buckets_not_reduced = set()
+
+                self.active_params = param_list
+
+            self.callback_queued = False
 
         return result
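Note that the change to forward() above means the flag is consulted in both passes: when disable_allreduce is True at forward time, the bucket bookkeeping and the active_params refresh are skipped as well, not just the allreduce in backward, and callback_queued is left untouched. The simplest pattern is therefore to toggle the flag once per iteration, before the forward pass, as in the hypothetical snippet below (model and inputs are placeholders) and in the usage sketch near the top of this page.

    model.turn_off_allreduce()      # accumulation step: set before forward
    loss = model(inputs).sum()      # forward sees disable_allreduce == True
    loss.backward()                 # allreduce_hook skips the allreduce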