Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
f2c9aa33
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "799f5b4e12c5350872b6fe5ebc28be423d2570c3"
Commit
f2c9aa33
authored
Apr 01, 2020
by
Thor Johnsen
Browse files
Add support for 4 all-reduce IB communicators
parent
5c1cf020
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
11 deletions
+20
-11
apex/contrib/optimizers/distributed_fused_adam.py
apex/contrib/optimizers/distributed_fused_adam.py
+20
-11
No files found.
apex/contrib/optimizers/distributed_fused_adam.py
View file @
f2c9aa33
...
@@ -199,12 +199,19 @@ class DistributedFusedAdam(torch.optim.Optimizer):
...
@@ -199,12 +199,19 @@ class DistributedFusedAdam(torch.optim.Optimizer):
work
=
torch
.
distributed
.
reduce_scatter
(
grad_shards
[
self
.
_rank_in_group
],
grad_shards
,
group
=
self
.
_rs_pg
[
block_id
%
len
(
self
.
_rs_pg
)],
async_op
=
True
,
no_copy
=
True
)
work
=
torch
.
distributed
.
reduce_scatter
(
grad_shards
[
self
.
_rank_in_group
],
grad_shards
,
group
=
self
.
_rs_pg
[
block_id
%
len
(
self
.
_rs_pg
)],
async_op
=
True
,
no_copy
=
True
)
else
:
else
:
work
=
torch
.
distributed
.
reduce_scatter
(
grad_shards
[
self
.
_rank_in_group
],
grad_shards
,
group
=
self
.
_rs_pg
[
block_id
%
len
(
self
.
_rs_pg
)],
async_op
=
True
)
work
=
torch
.
distributed
.
reduce_scatter
(
grad_shards
[
self
.
_rank_in_group
],
grad_shards
,
group
=
self
.
_rs_pg
[
block_id
%
len
(
self
.
_rs_pg
)],
async_op
=
True
)
works
=
[
work
]
if
self
.
_num_groups
>
1
:
if
self
.
_num_groups
>
1
:
work
.
wait
()
sliver_size
=
self
.
_shard_size
//
len
(
self
.
_ar_pg
)
work
=
torch
.
distributed
.
all_reduce
(
grad_shards
[
self
.
_rank_in_group
],
group
=
self
.
_ar_pg
[
block_id
%
len
(
self
.
_ar_pg
)],
async_op
=
True
)
works
=
[]
for
i
,
ar_pg
in
enumerate
(
self
.
_ar_pg
):
work
.
wait
()
works
.
append
(
torch
.
distributed
.
all_reduce
(
grad_shards
[
self
.
_rank_in_group
][
i
*
sliver_size
:(
i
+
1
)
*
sliver_size
],
group
=
ar_pg
,
async_op
=
True
)
)
if
self
.
_compute_L2_grad_norm
:
if
self
.
_compute_L2_grad_norm
:
with
torch
.
cuda
.
stream
(
self
.
_blk_st
[
0
]):
with
torch
.
cuda
.
stream
(
self
.
_blk_st
[
0
]):
work
.
wait
()
for
work
in
works
:
work
.
wait
()
if
block_id
+
1
==
self
.
_num_blocks
:
if
block_id
+
1
==
self
.
_num_blocks
:
self
.
_L2_grad_norm
=
grad_shards
[
self
.
_rank_in_group
].
norm
(
dtype
=
torch
.
float32
,
p
=
2
)
**
2
self
.
_L2_grad_norm
=
grad_shards
[
self
.
_rank_in_group
].
norm
(
dtype
=
torch
.
float32
,
p
=
2
)
**
2
elif
block_id
!=
0
:
elif
block_id
!=
0
:
...
@@ -213,7 +220,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
...
@@ -213,7 +220,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self
.
_L2_grad_norm
+=
grad_shards
[
self
.
_rank_in_group
].
norm
(
dtype
=
torch
.
float32
,
p
=
2
)
**
2
self
.
_L2_grad_norm
+=
grad_shards
[
self
.
_rank_in_group
].
norm
(
dtype
=
torch
.
float32
,
p
=
2
)
**
2
torch
.
distributed
.
all_reduce
(
self
.
_L2_grad_norm
,
group
=
self
.
_rs_pg
[
0
])
torch
.
distributed
.
all_reduce
(
self
.
_L2_grad_norm
,
group
=
self
.
_rs_pg
[
0
])
self
.
_L2_grad_norm
.
sqrt_
()
self
.
_L2_grad_norm
.
sqrt_
()
return
work
return
works
# NB!
# NB!
# self._global_scale is used by this method.
# self._global_scale is used by this method.
...
@@ -229,9 +237,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
...
@@ -229,9 +237,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
return
work
return
work
def
_pipeline_block
(
self
,
block_id
,
flat_grads
,
new_params
):
def
_pipeline_block
(
self
,
block_id
,
flat_grads
,
new_params
):
work
=
self
.
_pipeline_block_reductions
(
block_id
,
flat_grads
)
works
=
self
.
_pipeline_block_reductions
(
block_id
,
flat_grads
)
if
work
is
not
None
:
for
work
in
works
:
work
.
wait
()
if
work
is
not
None
:
work
.
wait
()
return
self
.
_pipeline_block_step
(
block_id
,
flat_grads
,
new_params
)
return
self
.
_pipeline_block_step
(
block_id
,
flat_grads
,
new_params
)
def
_do_overlapped_reduction
(
self
,
param_i
,
param_grads_size
,
param_offset
,
grad
):
def
_do_overlapped_reduction
(
self
,
param_i
,
param_grads_size
,
param_offset
,
grad
):
...
@@ -251,8 +260,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
...
@@ -251,8 +260,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
work
=
self
.
_pipeline_block
(
block_id
,
self
.
_flat_grads
,
self
.
_new_params
)
work
=
self
.
_pipeline_block
(
block_id
,
self
.
_flat_grads
,
self
.
_new_params
)
self
.
_works
.
append
(
work
)
self
.
_works
.
append
(
work
)
else
:
else
:
work
=
self
.
_pipeline_block_reductions
(
block_id
,
self
.
_flat_grads
)
work
s
=
self
.
_pipeline_block_reductions
(
block_id
,
self
.
_flat_grads
)
self
.
_works
.
append
(
work
)
self
.
_works
+=
work
s
flush_block
=
self
.
_get_flush_block
()
flush_block
=
self
.
_get_flush_block
()
...
@@ -463,8 +472,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
...
@@ -463,8 +472,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
block_id
=
self
.
_num_blocks
-
inv_block_id
-
1
block_id
=
self
.
_num_blocks
-
inv_block_id
-
1
self
.
_blk_st
[
block_id
%
len
(
self
.
_blk_st
)].
wait_stream
(
torch
.
cuda
.
current_stream
())
self
.
_blk_st
[
block_id
%
len
(
self
.
_blk_st
)].
wait_stream
(
torch
.
cuda
.
current_stream
())
with
torch
.
cuda
.
stream
(
self
.
_blk_st
[
block_id
%
len
(
self
.
_blk_st
)]):
with
torch
.
cuda
.
stream
(
self
.
_blk_st
[
block_id
%
len
(
self
.
_blk_st
)]):
work
=
self
.
_pipeline_block_reductions
(
block_id
,
self
.
_flat_grads
)
work
s
=
self
.
_pipeline_block_reductions
(
block_id
,
self
.
_flat_grads
)
self
.
_works
.
append
(
work
)
self
.
_works
+=
work
s
self
.
_copy_to_fp32
=
False
self
.
_copy_to_fp32
=
False
self
.
_decomp_stats
=
None
self
.
_decomp_stats
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment