OpenDAS / apex / Commits / 2e98baa7
"vscode:/vscode.git/clone" did not exist on "1e0395e7918f54a53ce14f14279ad07913038bcf"
Unverified commit 2e98baa7, authored Sep 02, 2021 by Burc Eryilmaz, committed by GitHub on Sep 02, 2021

use prescaling for collective (#1157)

Parent: 1cb9c5c3
Showing 1 changed file with 46 additions and 4 deletions.
apex/contrib/optimizers/distributed_fused_lamb.py (+46, -4) · View file @ 2e98baa7
@@ -88,7 +88,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                  dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, fused_norm=False,
                  e5m2_allgather=False, verbose=False, clip_after_ar=True,
-                 full_ar=False):
+                 full_ar=False, fuse_scale=False):
         defaults = dict(lr=lr, bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         grad_averaging=grad_averaging,
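The new fuse_scale flag defaults to False, so existing callers keep the old behavior of multiplying the flat gradient buffer by the pre-reduction scale before the collectives run. Below is a minimal construction sketch, assuming apex is installed with its contrib extensions and torch.distributed is already initialized with the NCCL backend; the model and hyperparameter values are placeholders, not part of this commit:

import torch
from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = DistributedFusedLAMB(
    model.parameters(),
    lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01,
    full_ar=True,     # single fused all-reduce path (placeholder choice)
    fuse_scale=True,  # fold the pre-reduction scale into the collective (this commit)
)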
@@ -121,6 +121,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._verbose = verbose
         self._clip_after_ar = clip_after_ar
         self._full_ar = full_ar
+        self._fuse_scale = fuse_scale
         self._L2_grad_norm = None
         self._fused_norm = fused_norm
         self._current_process_group = c10d._get_default_group()
@@ -544,6 +545,17 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         return flush_block

+    def _full_all_reduce_scale(self, block_id, scale):
+        works = [None]*self._num_chunks
+        for chunk_id in range(self._num_chunks):
+            glob_chunk_id = block_id * self._num_chunks + chunk_id
+            ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
+            ar_stream.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(ar_stream):
+                works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=torch.distributed.make_nccl_premul_sum((scale,)))
+        self._reductions_works[block_id] = works
+
     def _full_all_reduce(self, block_id):
         works = [None]*self._num_chunks
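The added _full_all_reduce_scale mirrors the existing _full_all_reduce, except that it hands NCCL a premul-sum reduce op so each chunk is multiplied by scale while it is being reduced, rather than in a separate elementwise kernel beforehand. torch.distributed.make_nccl_premul_sum is only exposed in PyTorch builds with sufficiently new NCCL support, so its availability depends on the build. The sketch below shows only the generic chunked, stream-overlapped async all-reduce skeleton the method follows, with a plain sum op and illustrative names; it assumes dist.init_process_group("nccl") was already called (e.g. under torchrun) and that the chunks are CUDA tensors:

import torch
import torch.distributed as dist

def async_all_reduce_chunks(chunks, process_groups, streams):
    # Launch one async all-reduce per chunk, round-robining over side streams
    # and process groups so the reductions can overlap with other work.
    works = [None] * len(chunks)
    for i, chunk in enumerate(chunks):
        stream = streams[i % len(streams)]
        # Make the side stream wait for the default stream that produced the gradients.
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            works[i] = dist.all_reduce(chunk,
                                       group=process_groups[i % len(process_groups)],
                                       async_op=True)
    return works  # callers must call .wait() on each handle before reading the chunks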
@@ -555,6 +567,29 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
         self._reductions_works[block_id] = works

+    def _reduce_scatter_and_all_reduce_scale(self, block_id, scale):
+        # Reduction within each node
+        # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
+        # The output format is the same as the fp32 master parameters
+        works = [None]*self._num_chunks
+        for chunk_id in range(self._num_chunks):
+            glob_chunk_id = block_id * self._num_chunks + chunk_id
+            rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
+            rs_stream.wait_stream(torch.cuda.current_stream())
+            rs_stream.wait_stream(self._l2_grad_norm_st)
+            with torch.cuda.stream(rs_stream):
+                works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True,op=torch.distributed.make_nccl_premul_sum((scale,)))
+
+        # Reduction across nodes for each rank
+        if self._num_groups > 1:
+            for chunk_id in range(self._num_chunks):
+                glob_chunk_id = block_id * self._num_chunks + chunk_id
+                ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
+                with torch.cuda.stream(ar_stream):
+                    works[chunk_id].wait()
+                    works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
+        self._reductions_works[block_id] = works
+
     def _reduce_scatter_and_all_reduce(self, block_id):
         # Reduction within each node
         # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
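_reduce_scatter_and_all_reduce_scale follows the same hierarchical pattern as the existing unscaled method: an intra-node reduce-scatter that leaves each rank holding one shard, followed by a cross-node all-reduce over that shard when there is more than one node. The premul-sum scale is applied only in the first stage; applying it again in the cross-node all-reduce would scale the gradients twice. Note that the no_copy=True argument is not part of stock PyTorch's reduce_scatter signature and appears to come from NVIDIA's patched builds. A bare two-level reduction sketch with standard calls follows; the names are illustrative and the process groups are assumed to have been created with dist.new_group beforehand:

import torch
import torch.distributed as dist

def two_level_reduce(shard_inputs, shard_out, intra_node_group, inter_node_group):
    # Stage 1: sum within the node; each rank ends up holding one reduced shard.
    dist.reduce_scatter(shard_out, shard_inputs, group=intra_node_group)
    # Stage 2: sum the matching shard across nodes.
    dist.all_reduce(shard_out, group=inter_node_group)
    return shard_out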
@@ -620,10 +655,17 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             tmp = torch.cat(((self._one), (coeff)))
             index = (coeff+1>coeff).int()
             scale = tmp.index_select(0, index).half()/self._world_size
-            self._flat_grads.mul_(scale)
+            if not self._fuse_scale:
+                self._flat_grads.mul_(scale)
             if self._full_ar:
-                self._full_all_reduce(block_id)
+                if self._fuse_scale:
+                    self._full_all_reduce_scale(block_id, scale)
+                else:
+                    self._full_all_reduce(block_id)
             else:
-                self._reduce_scatter_and_all_reduce(block_id)
+                if self._fuse_scale:
+                    self._reduce_scatter_and_all_reduce_scale(block_id, scale)
+                else:
+                    self._reduce_scatter_and_all_reduce(block_id)
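The dispatch keeps both paths intact: with fuse_scale off, the flat gradient buffer is still multiplied by scale on the default stream before the collectives run; with it on, that multiply is skipped and the same factor rides inside the premul-sum reduce op. (The index_select over torch.cat((self._one, coeff)) appears to pick coeff when it is finite and fall back to 1.0 when the clipping coefficient is NaN or inf, since those comparisons evaluate to False.) Because the scale commutes with the summation, fusing it changes where the multiply happens, not the result, up to floating-point rounding; what it saves is the extra kernel launch and pass over the gradient buffer. A single-process sketch of that equivalence, with placeholder tensors and no distributed setup:

import torch

torch.manual_seed(0)
world_size = 4
scale = 1.0 / world_size
grads = [torch.randn(8) for _ in range(world_size)]  # one gradient per "rank"

prescaled = sum(g * scale for g in grads)  # scale applied per contribution (premul-sum style)
postscaled = sum(grads) * scale            # scale applied once after the reduction
assert torch.allclose(prescaled, postscaled, atol=1e-6)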