Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
2622d7f1
Commit
2622d7f1
authored
Apr 16, 2020
by
Thor Johnsen
Browse files
Use glob_chunk to index streams and process groups
parent
85497632
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
6 deletions
+8
-6
apex/contrib/optimizers/distributed_fused_adam.py
apex/contrib/optimizers/distributed_fused_adam.py
+8
-6
No files found.
apex/contrib/optimizers/distributed_fused_adam.py
View file @
2622d7f1
...
@@ -222,17 +222,18 @@ class DistributedFusedAdam(torch.optim.Optimizer):
...
@@ -222,17 +222,18 @@ class DistributedFusedAdam(torch.optim.Optimizer):
works
=
[
None
]
*
self
.
_num_chunks
works
=
[
None
]
*
self
.
_num_chunks
for
chunk
in
range
(
self
.
_num_chunks
):
for
chunk
in
range
(
self
.
_num_chunks
):
glob_chunk
=
block_id
*
self
.
_num_chunks
+
chunk
grad_chunk
=
grad_block
[
chunk
*
self
.
_chunk_size
:(
chunk
+
1
)
*
self
.
_chunk_size
]
grad_chunk
=
grad_block
[
chunk
*
self
.
_chunk_size
:(
chunk
+
1
)
*
self
.
_chunk_size
]
grad_shards
=
[
grad_chunk
[
i
*
self
.
_shard_size
:(
i
+
1
)
*
self
.
_shard_size
]
for
i
in
range
(
self
.
_group_size
)]
grad_shards
=
[
grad_chunk
[
i
*
self
.
_shard_size
:(
i
+
1
)
*
self
.
_shard_size
]
for
i
in
range
(
self
.
_group_size
)]
rs_stream
=
self
.
_rs_st
[
chunk
%
self
.
_num_rs_pg
]
rs_stream
=
self
.
_rs_st
[
glob_
chunk
%
self
.
_num_rs_pg
]
rs_stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
rs_stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
with
torch
.
cuda
.
stream
(
rs_stream
):
with
torch
.
cuda
.
stream
(
rs_stream
):
work
=
torch
.
distributed
.
reduce_scatter
(
grad_shards
[
self
.
_rank_in_group
],
grad_shards
,
group
=
self
.
_rs_pg
[
chunk
%
self
.
_num_rs_pg
],
async_op
=
True
,
no_copy
=
True
)
work
=
torch
.
distributed
.
reduce_scatter
(
grad_shards
[
self
.
_rank_in_group
],
grad_shards
,
group
=
self
.
_rs_pg
[
glob_
chunk
%
self
.
_num_rs_pg
],
async_op
=
True
,
no_copy
=
True
)
if
self
.
_num_groups
>
1
:
if
self
.
_num_groups
>
1
:
ar_stream
=
self
.
_ar_st
[
chunk
%
self
.
_num_ar_pg
]
ar_stream
=
self
.
_ar_st
[
glob_
chunk
%
self
.
_num_ar_pg
]
with
torch
.
cuda
.
stream
(
ar_stream
):
with
torch
.
cuda
.
stream
(
ar_stream
):
work
.
wait
()
work
.
wait
()
work
=
torch
.
distributed
.
all_reduce
(
grad_shards
[
self
.
_rank_in_group
],
group
=
self
.
_ar_pg
[
chunk
%
self
.
_num_ar_pg
],
async_op
=
True
)
work
=
torch
.
distributed
.
all_reduce
(
grad_shards
[
self
.
_rank_in_group
],
group
=
self
.
_ar_pg
[
glob_
chunk
%
self
.
_num_ar_pg
],
async_op
=
True
)
works
[
chunk
]
=
work
works
[
chunk
]
=
work
if
self
.
_compute_L2_grad_norm
:
if
self
.
_compute_L2_grad_norm
:
...
@@ -262,13 +263,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
...
@@ -262,13 +263,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
works
=
[
None
]
*
self
.
_num_chunks
works
=
[
None
]
*
self
.
_num_chunks
for
chunk
in
range
(
self
.
_num_chunks
):
for
chunk
in
range
(
self
.
_num_chunks
):
glob_chunk
=
block_id
*
self
.
_num_chunks
+
chunk
new_params_chunk
=
new_params_block
[
chunk
*
self
.
_chunk_size
:(
chunk
+
1
)
*
self
.
_chunk_size
]
new_params_chunk
=
new_params_block
[
chunk
*
self
.
_chunk_size
:(
chunk
+
1
)
*
self
.
_chunk_size
]
new_params_shards
=
[
new_params_chunk
[
i
*
self
.
_shard_size
:(
i
+
1
)
*
self
.
_shard_size
]
for
i
in
range
(
self
.
_group_size
)]
new_params_shards
=
[
new_params_chunk
[
i
*
self
.
_shard_size
:(
i
+
1
)
*
self
.
_shard_size
]
for
i
in
range
(
self
.
_group_size
)]
ag_stream
=
self
.
_ag_st
[
chunk
%
self
.
_num_ag_pg
]
ag_stream
=
self
.
_ag_st
[
glob_
chunk
%
self
.
_num_ag_pg
]
with
torch
.
cuda
.
stream
(
ag_stream
):
with
torch
.
cuda
.
stream
(
ag_stream
):
self
.
_reductions_works
[
block_id
][
chunk
].
wait
()
self
.
_reductions_works
[
block_id
][
chunk
].
wait
()
self
.
_partial_step_single_shard
(
block_id
,
chunk
)
self
.
_partial_step_single_shard
(
block_id
,
chunk
)
work
=
torch
.
distributed
.
all_gather
(
new_params_shards
,
new_params_shards
[
self
.
_rank_in_group
],
group
=
self
.
_ag_pg
[
chunk
%
self
.
_num_ag_pg
],
async_op
=
True
,
no_copy
=
True
)
work
=
torch
.
distributed
.
all_gather
(
new_params_shards
,
new_params_shards
[
self
.
_rank_in_group
],
group
=
self
.
_ag_pg
[
glob_
chunk
%
self
.
_num_ag_pg
],
async_op
=
True
,
no_copy
=
True
)
works
[
chunk
]
=
work
works
[
chunk
]
=
work
self
.
_allgather_works
[
block_id
]
=
works
self
.
_allgather_works
[
block_id
]
=
works
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment