OpenDAS / apex · Commit aa90d31f

Add internal pipelining option

Authored Apr 08, 2020 by Thor Johnsen
Parent: be4c41c2

Showing 1 changed file with 24 additions and 9 deletions:

apex/contrib/optimizers/distributed_fused_adam.py (+24, -9)
@@ -46,7 +46,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  compute_L2_grad_norm=False, distributed_weight_update=0,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
                  dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False,
-                 dwu_num_chunks=4, predivide=True):
+                 dwu_num_chunks=4, predivide=True, internal_pipeline=False):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
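Note: the new internal_pipeline flag defaults to False, so existing callers keep the old reduction behavior. A minimal usage sketch follows; the model, hyperparameters, and process-group setup are illustrative assumptions, not part of this commit:

    # Sketch: constructing the optimizer with the new option enabled.
    # Assumes a multi-GPU job launched with one process per GPU (NCCL backend).
    import torch
    import torch.distributed as dist
    from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = torch.nn.Linear(1024, 1024).cuda()    # placeholder model
    optimizer = DistributedFusedAdam(model.parameters(),
                                     lr=1e-3,                 # standard Adam args still apply
                                     dwu_num_chunks=4,        # chunks per shard (default)
                                     internal_pipeline=True)  # new flag from this commit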
@@ -79,6 +79,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._num_blocks = dwu_num_blocks
         self._num_chunks = dwu_num_chunks
         self._predivide = predivide
+        self._internal_pipeline = internal_pipeline
         self._full_pipeline = full_pipeline
         self._compute_L2_grad_norm = compute_L2_grad_norm
         self._L2_grad_norm = torch.zeros([]).cuda() if self._compute_L2_grad_norm else None
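Note: the stored flag is only consulted in the reduction path below. For orientation, the buffer partitioning these fields describe nests as blocks → per-rank shards → chunks; a worked sketch of the arithmetic (the sizes are hypothetical, and the block/shard relation is inferred from the slicing in the next hunk):

    # Hypothetical sizes, chosen only to make the divisions exact.
    net_size   = 2**26                     # total flat gradient elements (assumed)
    num_blocks = 4                         # dwu_num_blocks
    group_size = 8                         # dwu_group_size, e.g. GPUs per node
    num_chunks = 4                         # dwu_num_chunks

    block_size = net_size // num_blocks    # 16777216 elements per block
    shard_size = block_size // group_size  # 2097152 elements per rank's shard
    chunk_size = shard_size // num_chunks  # 524288 elements per pipelined chunk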
@@ -209,19 +210,33 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         grad_block = flat_grads[start:end]
         grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
-        if self._pg_supports_no_copy:
-            work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True,no_copy=True)
-        else:
-            work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True)
-        works = [work]
-        if self._num_groups > 1:
-            work.wait()
+        if self._internal_pipeline:
             works = []
             chunk_size = self._shard_size // self._num_chunks
             for i in range(self._num_chunks):
                 chunks = [grad_shards[j][i*chunk_size:(i+1)*chunk_size] for j in range(self._group_size)]
-                work = torch.distributed.all_reduce(chunks[self._rank_in_group],group=self._ar_pg[i%len(self._ar_pg)],async_op=True)
+                if self._pg_supports_no_copy:
+                    work = torch.distributed.reduce_scatter(chunks[self._rank_in_group],chunks,group=self._rs_pg[i%len(self._rs_pg)],async_op=True,no_copy=True)
+                else:
+                    work = torch.distributed.reduce_scatter(chunks[self._rank_in_group],chunks,group=self._rs_pg[i%len(self._rs_pg)],async_op=True)
+                if self._num_groups > 1:
+                    work.wait()
+                    work = torch.distributed.all_reduce(chunks[self._rank_in_group],group=self._ar_pg[i%len(self._ar_pg)],async_op=True)
                 works.append(work)
+        else:
+            if self._pg_supports_no_copy:
+                work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True,no_copy=True)
+            else:
+                work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True)
+            works = [work]
+            if self._num_groups > 1:
+                work.wait()
+                works = []
+                chunk_size = self._shard_size // self._num_chunks
+                for i in range(self._num_chunks):
+                    chunks = [grad_shards[j][i*chunk_size:(i+1)*chunk_size] for j in range(self._group_size)]
+                    work = torch.distributed.all_reduce(chunks[self._rank_in_group],group=self._ar_pg[i%len(self._ar_pg)],async_op=True)
+                    works.append(work)
         if self._compute_L2_grad_norm:
             with torch.cuda.stream(self._blk_st[0]):
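Note: the pre-existing path reduce-scatters a whole block shard and only then chunks it for the inter-group all-reduce, whereas the new internal_pipeline path issues a per-chunk reduce-scatter immediately followed by that chunk's all-reduce, so chunk i's all-reduce overlaps chunk i+1's reduce-scatter. A standalone sketch of the same pattern (process-group setup and argument names are illustrative; the no_copy fast path, which needs NVIDIA's patched PyTorch build, is omitted):

    # Sketch of the chunked reduce-scatter + all-reduce pipeline added here.
    # Assumes torch.distributed is initialized with the NCCL backend and that
    # rs_pg / ar_pg are lists of intra-group / inter-group process groups.
    import torch.distributed as dist

    def pipelined_block_reduction(grad_shards, rank_in_group, group_size,
                                  num_chunks, rs_pg, ar_pg, num_groups):
        works = []
        chunk_size = grad_shards[0].numel() // num_chunks
        for i in range(num_chunks):
            # Slice chunk i out of every rank's shard of this block.
            chunks = [grad_shards[j][i*chunk_size:(i+1)*chunk_size]
                      for j in range(group_size)]
            # Intra-group reduce-scatter of just this chunk.
            work = dist.reduce_scatter(chunks[rank_in_group], chunks,
                                       group=rs_pg[i % len(rs_pg)], async_op=True)
            if num_groups > 1:
                # Wait for this chunk, then launch its inter-group all-reduce
                # asynchronously; it runs while the next chunk's reduce-scatter
                # is being issued, which is the point of the pipelining.
                work.wait()
                work = dist.all_reduce(chunks[rank_in_group],
                                       group=ar_pg[i % len(ar_pg)], async_op=True)
            works.append(work)
        return works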