Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
d662f9ca
Commit
d662f9ca
authored
Mar 16, 2020
by
Thor Johnsen
Browse files
Rename inplace to no_copy to make effect clearer
parent
9f6c0da5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
8 deletions
+8
-8
apex/contrib/optimizers/distributed_fused_adam.py
apex/contrib/optimizers/distributed_fused_adam.py
+8
-8
No files found.
apex/contrib/optimizers/distributed_fused_adam.py
View file @
d662f9ca
...
...
@@ -154,11 +154,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self
.
_works
=
[]
import
inspect
if
if
'inplace
'
in
inspect
.
getfullargspec
(
torch
.
distributed
.
reduce_scatter
).
args
:
self
.
_pg_supports_
inplace
=
True
if
'no_copy
'
in
inspect
.
getfullargspec
(
torch
.
distributed
.
reduce_scatter
).
args
:
self
.
_pg_supports_
no_copy
=
True
else
:
self
.
_pg_supports_
inplace
=
False
print
(
"WARNING! torch.distributed.reduce_scatter does not support
inplace
op."
)
self
.
_pg_supports_
no_copy
=
False
print
(
"WARNING! torch.distributed.reduce_scatter does not support
no_copy
op."
)
def
set_last_step
(
self
,
last_step
):
...
...
@@ -188,8 +188,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
end
=
start
+
self
.
_block_size
grad_block
=
flat_grads
[
start
:
end
]
grad_shards
=
[
grad_block
[
i
*
self
.
_shard_size
:(
i
+
1
)
*
self
.
_shard_size
]
for
i
in
range
(
self
.
_group_size
)]
if
self
.
_pg_supports_
inplace
:
work
=
torch
.
distributed
.
reduce_scatter
(
grad_shards
[
self
.
_rank_in_group
],
grad_shards
,
group
=
self
.
_rs_pg
[
block_id
%
len
(
self
.
_rs_pg
)],
async_op
=
True
,
inplace
=
True
)
if
self
.
_pg_supports_
no_copy
:
work
=
torch
.
distributed
.
reduce_scatter
(
grad_shards
[
self
.
_rank_in_group
],
grad_shards
,
group
=
self
.
_rs_pg
[
block_id
%
len
(
self
.
_rs_pg
)],
async_op
=
True
,
no_copy
=
True
)
else
:
work
=
torch
.
distributed
.
reduce_scatter
(
grad_shards
[
self
.
_rank_in_group
],
grad_shards
,
group
=
self
.
_rs_pg
[
block_id
%
len
(
self
.
_rs_pg
)],
async_op
=
True
)
if
self
.
_num_groups
>
1
:
...
...
@@ -210,8 +210,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
shard_end
=
shard_start
+
self
.
_shard_size
block_id
=
start
//
self
.
_block_size
self
.
_partial_step_single_shard
(
block_id
)
if
self
.
_pg_supports_
inplace
:
work
=
torch
.
distributed
.
all_gather
(
new_params_shards
,
new_params_shards
[
self
.
_rank_in_group
],
group
=
self
.
_ag_pg
[
block_id
%
len
(
self
.
_ag_pg
)],
async_op
=
True
,
inplace
=
True
)
if
self
.
_pg_supports_
no_copy
:
work
=
torch
.
distributed
.
all_gather
(
new_params_shards
,
new_params_shards
[
self
.
_rank_in_group
],
group
=
self
.
_ag_pg
[
block_id
%
len
(
self
.
_ag_pg
)],
async_op
=
True
,
no_copy
=
True
)
else
:
work
=
torch
.
distributed
.
all_gather
(
new_params_shards
,
new_params_shards
[
self
.
_rank_in_group
],
group
=
self
.
_ag_pg
[
block_id
%
len
(
self
.
_ag_pg
)],
async_op
=
True
)
return
work
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment