Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
c659e564
Commit
c659e564
authored
Mar 12, 2020
by
Thor Johnsen
Browse files
Add backwards compatible support for no inplace NCCL op
parent
68715149
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
2 deletions
+16
-2
apex/contrib/optimizers/distributed_fused_adam.py
apex/contrib/optimizers/distributed_fused_adam.py
+16
-2
No files found.
apex/contrib/optimizers/distributed_fused_adam.py
View file @
c659e564
...
...
@@ -153,6 +153,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self
.
_blk_st
.
append
(
torch
.
cuda
.
Stream
())
self
.
_works
=
[]
import
inspect
if
if
'inplace'
in
inspect
.
getfullargspec
(
torch
.
distributed
.
reduce_scatter
).
args
:
self
.
_pg_supports_inplace
=
True
else
:
self
.
_pg_supports_inplace
=
False
print
(
"WARNING! torch.distributed.reduce_scatter does not support inplace op."
)
def
set_last_step
(
self
,
last_step
):
self
.
_last_step
=
last_step
...
...
@@ -180,7 +188,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
end
=
start
+
self
.
_block_size
grad_block
=
flat_grads
[
start
:
end
]
grad_shards
=
[
grad_block
[
i
*
self
.
_shard_size
:(
i
+
1
)
*
self
.
_shard_size
]
for
i
in
range
(
self
.
_group_size
)]
work
=
torch
.
distributed
.
reduce_scatter
(
grad_shards
[
self
.
_rank_in_group
],
grad_shards
,
group
=
self
.
_rs_pg
[
block_id
%
len
(
self
.
_rs_pg
)],
async_op
=
True
,
inplace
=
True
)
if
self
.
_pg_supports_inplace
:
work
=
torch
.
distributed
.
reduce_scatter
(
grad_shards
[
self
.
_rank_in_group
],
grad_shards
,
group
=
self
.
_rs_pg
[
block_id
%
len
(
self
.
_rs_pg
)],
async_op
=
True
,
inplace
=
True
)
else
:
work
=
torch
.
distributed
.
reduce_scatter
(
grad_shards
[
self
.
_rank_in_group
],
grad_shards
,
group
=
self
.
_rs_pg
[
block_id
%
len
(
self
.
_rs_pg
)],
async_op
=
True
)
if
self
.
_num_groups
>
1
:
work
.
wait
()
work
=
torch
.
distributed
.
all_reduce
(
grad_shards
[
self
.
_rank_in_group
],
group
=
self
.
_ar_pg
[
block_id
%
len
(
self
.
_ar_pg
)],
async_op
=
True
)
...
...
@@ -199,7 +210,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
shard_end
=
shard_start
+
self
.
_shard_size
block_id
=
start
//
self
.
_block_size
self
.
_partial_step_single_shard
(
block_id
)
work
=
torch
.
distributed
.
all_gather
(
new_params_shards
,
new_params_shards
[
self
.
_rank_in_group
],
group
=
self
.
_ag_pg
[
block_id
%
len
(
self
.
_ag_pg
)],
async_op
=
True
,
inplace
=
True
)
if
self
.
_pg_supports_inplace
:
work
=
torch
.
distributed
.
all_gather
(
new_params_shards
,
new_params_shards
[
self
.
_rank_in_group
],
group
=
self
.
_ag_pg
[
block_id
%
len
(
self
.
_ag_pg
)],
async_op
=
True
,
inplace
=
True
)
else
:
work
=
torch
.
distributed
.
all_gather
(
new_params_shards
,
new_params_shards
[
self
.
_rank_in_group
],
group
=
self
.
_ag_pg
[
block_id
%
len
(
self
.
_ag_pg
)],
async_op
=
True
)
return
work
def
_pipeline_block
(
self
,
block_id
,
flat_grads
,
new_params
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment