OpenDAS / apex · Commits
"...text-generation-inference.git" did not exist on "773aabdda6197cae3d2092f1cd6d9ce08d649185"
Commit e5c8e3c9, authored Apr 01, 2020 by Thor Johnsen
Add separate dwu_num_chunks argument
Parent: f2c9aa33
Showing 1 changed file with 8 additions and 5 deletions (+8, -5)
apex/contrib/optimizers/distributed_fused_adam.py (+8, -5)
@@ -45,7 +45,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
                  compute_L2_grad_norm=False, distributed_weight_update=0,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
-                 dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False):
+                 dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False,
+                 dwu_num_chunks=4):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
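For orientation, a minimal usage sketch of the widened constructor follows. Only dwu_num_ar_pg and dwu_num_chunks appear in the signature shown above; the toy model and the lr keyword are assumptions borrowed from the usual Adam-style optimizer interface, not part of this commit.

# Hypothetical usage sketch, not part of this commit.
# Assumptions: a CUDA-capable toy model and an Adam-style lr kwarg;
# only dwu_num_ar_pg and dwu_num_chunks come from the signature in the diff.
import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

model = torch.nn.Linear(1024, 1024).cuda()   # assumed toy model
optimizer = DistributedFusedAdam(
    model.parameters(),        # assumed: standard iterable of parameters
    lr=1e-3,                   # assumed: Adam-style learning-rate kwarg
    dwu_num_ar_pg=4,           # existing argument: number of all-reduce process groups
    dwu_num_chunks=4,          # new argument added by this commit (default 4)
)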
@@ -76,6 +77,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._overlap_reductions = overlap_reductions
         self._global_scale = None
         self._num_blocks = dwu_num_blocks
+        self._num_chunks = dwu_num_chunks
         self._full_pipeline = full_pipeline
         self._compute_L2_grad_norm = compute_L2_grad_norm
         self._L2_grad_norm = torch.zeros([]).cuda() if self._compute_L2_grad_norm else None
@@ -202,11 +204,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             works = [work]
             if self._num_groups > 1:
-                sliver_size = self._shard_size // len(self._ar_pg)
+                sliver_size = self._shard_size // self._num_chunks
+                assert ((sliver_size * self._num_chunks) == self._shard_size), "Shard size not a multiple of dwu_num_chunks"
                 works = []
-                for i,ar_pg in enumerate(self._ar_pg):
-                    work.wait()
-                    works.append(torch.distributed.all_reduce(grad_shards[self._rank_in_group][i*sliver_size:(i+1)*sliver_size],group=ar_pg,async_op=True))
+                work.wait()
+                for i in range(self._num_chunks):
+                    works.append(torch.distributed.all_reduce(grad_shards[self._rank_in_group][i*sliver_size:(i+1)*sliver_size],group=self._ar_pg[i%len(self._ar_pg)],async_op=True))
             if self._compute_L2_grad_norm:
                 with torch.cuda.stream(self._blk_st[0]):
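Taken together, the change replaces the old one-all-reduce-per-process-group loop with a fixed number of shard chunks that are distributed round-robin over the available process groups. Below is a minimal, standalone sketch of that pattern (an illustration, not the apex implementation); it assumes torch.distributed has already been initialized, for example via torchrun, and that ar_pg is a list of process groups created with dist.new_group.

# Standalone sketch of the chunked all-reduce pattern used above; this is an
# illustration, not apex code. Assumes an initialized default process group
# and a list of process groups `ar_pg` built with dist.new_group(...).
import torch
import torch.distributed as dist

def chunked_all_reduce(shard, ar_pg, num_chunks):
    # Split a flat gradient shard into `num_chunks` equal slivers and launch
    # one asynchronous all-reduce per sliver, cycling over the process groups.
    sliver_size = shard.numel() // num_chunks
    assert sliver_size * num_chunks == shard.numel(), \
        "Shard size not a multiple of num_chunks"
    works = []
    for i in range(num_chunks):
        sliver = shard[i * sliver_size:(i + 1) * sliver_size]
        works.append(dist.all_reduce(sliver, group=ar_pg[i % len(ar_pg)], async_op=True))
    # Wait for every in-flight reduction before the shard is consumed.
    for w in works:
        w.wait()
    return shard

Decoupling the chunk count from len(self._ar_pg) lets the number of in-flight all-reduces be tuned independently of how many process groups exist, which is the point of the new dwu_num_chunks argument.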