OpenDAS / apex · Commit 5c1cf020

Authored Mar 31, 2020 by Thor Johnsen
Parent: 3f4fb81f

Move partial_step out of complete reductions
Showing 1 changed file with 31 additions and 55 deletions.

apex/contrib/optimizers/distributed_fused_adam.py  (+31, -55)
@@ -45,7 +45,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
                  compute_L2_grad_norm=False, distributed_weight_update=0,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
-                 dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1):
+                 dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
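For context, a minimal sketch of how the changed constructor might be invoked; the model, learning rate, and other values here are hypothetical, and the snippet assumes torch.distributed is already initialized and the fused_adam_cuda extension is built. flat_mt is the flag this commit adds; it defaults to False, so existing call sites keep their behavior.

    import torch
    from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

    # Hypothetical construction; requires an initialized process group.
    model = torch.nn.Linear(1024, 1024).cuda()
    optimizer = DistributedFusedAdam(model.parameters(), lr=1e-3,
                                     compute_L2_grad_norm=True, flat_mt=False)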
@@ -78,7 +78,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._num_blocks = dwu_num_blocks
         self._full_pipeline = full_pipeline
         self._compute_L2_grad_norm = compute_L2_grad_norm
-        self._L2_grad_norm = None
+        self._L2_grad_norm = torch.zeros([]).cuda() if self._compute_L2_grad_norm else None
         self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
         self._world_size = torch.distributed.get_world_size()
         self._num_groups = self._world_size // self._group_size
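The hunk above changes _L2_grad_norm from a plain None into a persistent zero-dim CUDA tensor, which is what lets the reduction pipeline below accumulate it in place. A sketch of that pattern, with illustrative shapes (the shard tensors are stand-ins, not the optimizer's real buffers):

    import torch

    # A zero-dim CUDA tensor can accumulate squared shard norms in place
    # across blocks, then be square-rooted in place, with no host round-trip.
    norm_sq = torch.zeros([]).cuda()
    for shard in torch.randn(30, device="cuda").chunk(3):
        norm_sq += shard.norm(dtype=torch.float32, p=2) ** 2
    norm_sq.sqrt_()  # mirrors self._L2_grad_norm.sqrt_()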
@@ -202,6 +202,17 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         if self._num_groups > 1:
             work.wait()
             work = torch.distributed.all_reduce(grad_shards[self._rank_in_group],group=self._ar_pg[block_id%len(self._ar_pg)],async_op=True)
+        if self._compute_L2_grad_norm:
+            with torch.cuda.stream(self._blk_st[0]):
+                work.wait()
+                if block_id+1 == self._num_blocks:
+                    self._L2_grad_norm = grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
+                elif block_id != 0:
+                    self._L2_grad_norm += grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
+                else:
+                    self._L2_grad_norm += grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
+                    torch.distributed.all_reduce(self._L2_grad_norm,group=self._rs_pg[0])
+                    self._L2_grad_norm.sqrt_()
         return work

     # NB!
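Blocks are reduced in descending block_id order, so the branch on block_id+1 == self._num_blocks initializes the running sum on the first block processed, intermediate blocks accumulate, and block 0 finishes with the scalar all_reduce and sqrt. The math that makes this valid: for a flat gradient split into disjoint shards, the squared global norm is the sum of squared shard norms. A self-contained check (CPU tensors, illustrative sizes):

    import torch

    # ||g||^2 == sum_i ||shard_i||^2 for disjoint shards, so each rank can sum
    # squared local norms, all-reduce one scalar, and take a single sqrt.
    g = torch.randn(12)
    norm_sq = sum(s.norm(p=2) ** 2 for s in g.chunk(4))
    assert torch.allclose(norm_sq.sqrt(), g.norm(p=2))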
@@ -431,16 +442,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 bias_correction,
                 group['weight_decay'])

-    def _do_compute_L2_grad_norm(self):
-        partial_sum = torch.zeros([]).cuda()
-        for block in range(self._num_blocks):
-            grad_block = self._flat_grads[block*self._block_size:(block+1)*self._block_size]
-            grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
-            shard_grad_norm = grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)
-            partial_sum += (shard_grad_norm*shard_grad_norm)
-        torch.distributed.all_reduce(partial_sum,group=self._rs_pg[0],async_op=False)
-        self._L2_grad_norm = partial_sum.sqrt().item()
-
     def complete_reductions(self):
         """Complete reductions if full pipeline is not selected or overlap is not allowed.
         """
@@ -456,49 +457,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             self._flat_grads[param_offset:param_offset+param_size].zero_()
         self._grads_generated[param_i]=True

-        if self._last_step or not self._overlap_reductions or not self._full_pipeline:
-            if self._new_params is None:
-                self._new_params = torch.zeros_like(self._flat_grads)
-            if self._last_step or not self._overlap_reductions:
-                # nothing done so far, run full pipeline after reductions
-                if self._compute_L2_grad_norm:
-                    # do reductions, wait, complete L2, do step
-                    for inv_block_id in range(self._num_blocks):
-                        block_id = self._num_blocks - inv_block_id - 1
-                        self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
-                        with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
-                            work = self._pipeline_block_reductions(block_id, self._flat_grads)
-                            self._works.append(work)
-                    self._wait_works()
-                    self._do_compute_L2_grad_norm()
-                    for inv_block_id in range(self._num_blocks):
-                        block_id = self._num_blocks - inv_block_id - 1
-                        self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
-                        with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
-                            work = self._pipeline_block_step(block_id, self._flat_grads, self._new_params)
-                            self._works.append(work)
-                else:
-                    # run full pipeline
-                    for inv_block_id in range(self._num_blocks):
-                        block_id = self._num_blocks - inv_block_id - 1
-                        self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
-                        with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
-                            work = self._pipeline_block(block_id, self._flat_grads, self._new_params)
-                            self._works.append(work)
-            else:
-                # reductions done.
-                if self._compute_L2_grad_norm:
-                    self._do_compute_L2_grad_norm()
-                # do step
-                for inv_block_id in range(self._num_blocks):
-                    block_id = self._num_blocks - inv_block_id - 1
-                    self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
-                    with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
-                        work = self._pipeline_block_step(block_id, self._flat_grads, self._new_params)
-                        self._works.append(work)
-        else:
-            if self._compute_L2_grad_norm:
-                self._do_compute_L2_grad_norm()
+        if self._last_step or not self._overlap_reductions:
+            # nothing done so far, run full pipeline after reductions
+            for inv_block_id in range(self._num_blocks):
+                block_id = self._num_blocks - inv_block_id - 1
+                self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
+                with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
+                    work = self._pipeline_block_reductions(block_id, self._flat_grads)
+                    self._works.append(work)

         self._copy_to_fp32 = False
         self._decomp_stats = None
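This hunk is the heart of the commit: complete_reductions() now only finishes gradient reductions (with the L2 norm fused into _pipeline_block_reductions), while the blockwise _pipeline_block_step calls move into step(), shown in the next hunk. A hypothetical training-step sketch of the resulting contract; the model, loss, and surrounding setup are omitted and assumed, and only the method names come from the diff:

    # Assumed setup: model, inputs, and optimizer already exist.
    loss = model(inputs).sum()       # placeholder loss
    loss.backward()                  # grad hooks fill the optimizer's _flat_grads
    optimizer.complete_reductions()  # reductions only; no partial step here anymore
    optimizer.step()                 # runs _pipeline_block_step per block, then checks overflow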
@@ -517,6 +483,16 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()

+        if self._last_step or not self._full_pipeline:
+            if self._new_params is None:
+                self._new_params = torch.zeros_like(self._flat_grads)
+            for inv_block_id in range(self._num_blocks):
+                block_id = self._num_blocks - inv_block_id - 1
+                self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
+                with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
+                    work = self._pipeline_block_step(block_id, self._flat_grads, self._new_params)
+                    self._works.append(work)
+
         self._wait_works()

         # Check for overflow
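Both the moved step loop and the reduction loop share one stream idiom: each block's side stream first waits on the current stream (so the flat gradient buffer is ready), then launches that block's work on the side stream, letting blocks overlap. A minimal runnable sketch of the idiom, where the buffer size, stream count, and the mul_ stand-in are illustrative rather than taken from the optimizer:

    import torch

    num_blocks = 4
    blk_st = [torch.cuda.Stream() for _ in range(2)]
    flat_grads = torch.randn(1024, device="cuda")
    for inv_block_id in range(num_blocks):
        block_id = num_blocks - inv_block_id - 1
        stream = blk_st[block_id % len(blk_st)]
        stream.wait_stream(torch.cuda.current_stream())  # buffer is ready
        with torch.cuda.stream(stream):
            flat_grads.mul_(1.0)  # stand-in for _pipeline_block_step(block_id, ...)
    for s in blk_st:
        torch.cuda.current_stream().wait_stream(s)  # rejoin before reusing the buffer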