Project: OpenDAS / apex

Commit 3f717d95, authored Apr 01, 2020 by Thor Johnsen
Parent: 17160f34

Bug fix in internal pipelining

Showing 1 changed file with 49 additions and 61 deletions:
apex/contrib/optimizers/distributed_fused_adam.py (+49, -61)
@@ -46,7 +46,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  compute_L2_grad_norm=False, distributed_weight_update=0,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
                  dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False,
-                 dwu_num_chunks=4):
+                 dwu_num_chunks=4, predivide=True):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")

@@ -78,6 +78,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._global_scale = None
         self._num_blocks = dwu_num_blocks
         self._num_chunks = dwu_num_chunks
+        self._predivide = predivide
         self._full_pipeline = full_pipeline
         self._compute_L2_grad_norm = compute_L2_grad_norm
         self._L2_grad_norm = torch.zeros([]).cuda() if self._compute_L2_grad_norm else None
@@ -160,7 +161,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._blk_st = []
         for i in range(self._num_blk_st):
             self._blk_st.append(torch.cuda.Stream())
-        self._works = []
         import inspect
         if 'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args:
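Aside: the no_copy keyword probed here is not part of stock PyTorch; it only exists in patched builds, which is why the code feature-detects it with inspect rather than assuming it. A minimal, hedged sketch of that detection pattern (a standalone illustration, not the optimizer's actual helper):

import inspect

import torch
import torch.distributed


def pg_supports_no_copy():
    # True only if this torch.distributed build exposes the optional
    # 'no_copy' keyword on reduce_scatter (a vendor extension, absent upstream).
    return 'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args


def reduce_scatter_maybe_no_copy(output, input_list, group=None):
    # Pass no_copy only when it is known to exist; otherwise fall back to the
    # standard signature. Assumes torch.distributed is already initialized.
    if pg_supports_no_copy():
        return torch.distributed.reduce_scatter(output, input_list, group=group,
                                                async_op=True, no_copy=True)
    return torch.distributed.reduce_scatter(output, input_list, group=group,
                                            async_op=True)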
@@ -197,19 +197,20 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         end = start + self._block_size
         grad_block = flat_grads[start:end]
         grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
         if self._pg_supports_no_copy:
             work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True,no_copy=True)
         else:
             work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True)
         works = [work]
         if self._num_groups > 1:
-            sliver_size = self._shard_size // self._num_chunks
-            assert ((sliver_size*self._num_chunks) == self._shard_size), "Shard size not a multiple of dwu_num_chunks"
-            works = []
             work.wait()
+            works = []
+            chunk_size = self._shard_size // self._num_chunks
             for i in range(self._num_chunks):
-                works.append(torch.distributed.all_reduce(grad_shards[self._rank_in_group][i*sliver_size:(i+1)*sliver_size],group=self._ar_pg[i%len(self._ar_pg)],async_op=True))
+                chunks = [grad_shards[j][i*chunk_size:(i+1)*chunk_size] for j in range(self._group_size)]
+                work = torch.distributed.all_reduce(chunks[self._rank_in_group],group=self._ar_pg[i%len(self._ar_pg)],async_op=True)
+                works.append(work)
         if self._compute_L2_grad_norm:
             with torch.cuda.stream(self._blk_st[0]):
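For orientation, the reworked reduction reduce-scatters a block within the group, waits, and then all-reduces the local shard across groups in dwu_num_chunks slices spread over several process groups. A minimal sketch of that chunking pattern, with hypothetical shard, num_chunks, and ar_groups standing in for the optimizer's internals (assumes an already-initialized process group):

import torch
import torch.distributed as dist


def chunked_all_reduce(shard, num_chunks, ar_groups):
    # Split the local shard into equal chunks and launch one asynchronous
    # all-reduce per chunk, round-robining over the available process groups.
    chunk_size = shard.numel() // num_chunks
    works = []
    for i in range(num_chunks):
        chunk = shard[i * chunk_size:(i + 1) * chunk_size]
        work = dist.all_reduce(chunk, group=ar_groups[i % len(ar_groups)], async_op=True)
        works.append(work)
    # Block until every chunk has been reduced before the shard is consumed.
    for work in works:
        work.wait()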
@@ -224,7 +225,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 torch.distributed.all_reduce(self._L2_grad_norm,group=self._rs_pg[0])
                 self._L2_grad_norm.sqrt_()
-        return works
+        for work in works:
+            work.wait()

     # NB!
     # self._global_scale is used by this method.
@@ -234,21 +236,17 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         new_params_shards = [new_params[start+shard_i*self._shard_size:start+(shard_i+1)*self._shard_size] for shard_i in range(self._group_size)]
         self._partial_step_single_shard(block_id)
         if self._pg_supports_no_copy:
-            work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],async_op=True,no_copy=True)
+            torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],no_copy=True)
         else:
-            work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],async_op=True)
-        return work
+            torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)])

     def _pipeline_block(self, block_id, flat_grads, new_params):
-        works = self._pipeline_block_reductions(block_id, flat_grads)
-        for work in works:
-            if work is not None:
-                work.wait()
-        return self._pipeline_block_step(block_id, flat_grads, new_params)
+        self._pipeline_block_reductions(block_id, flat_grads)
+        self._pipeline_block_step(block_id, flat_grads, new_params)

     def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, grad):
         # handle overlapped reductions
-        torch.div(grad.view(-1), self._world_size, out=self._flat_grads[param_offset:param_offset+param_grads_size])
+        torch.div(grad.view(-1), self._world_size if self._predivide else 1.0, out=self._flat_grads[param_offset:param_offset+param_grads_size])
         self._grads_generated[param_i]=True
         if not self._last_step:
             if self._overlap_reductions:
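The new predivide flag only changes the divisor used when a gradient is copied into the flat buffer: with pre-division the gradient is scaled by 1/world_size before the reductions, otherwise it is copied unscaled and the averaging has to happen elsewhere. A small self-contained sketch of just that arithmetic, with made-up sizes:

import torch

world_size = 8
predivide = True

flat_grads = torch.zeros(16)     # stand-in for the optimizer's flat gradient buffer
grad = torch.ones(4, 4)          # stand-in for one parameter's gradient
param_offset, param_grads_size = 0, grad.numel()

# Mirrors the changed torch.div call: divide by world_size only when predividing.
torch.div(grad.view(-1), world_size if predivide else 1.0,
          out=flat_grads[param_offset:param_offset + param_grads_size])

print(flat_grads[:4])  # 0.125 each when predivide=True, 1.0 each otherwise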
@@ -260,20 +258,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     if self._full_pipeline:
                         if self._new_params is None:
                             self._new_params = torch.zeros_like(self._flat_grads)
-                        work = self._pipeline_block(block_id, self._flat_grads, self._new_params)
-                        self._works.append(work)
+                        self._pipeline_block(block_id, self._flat_grads, self._new_params)
                     else:
-                        works = self._pipeline_block_reductions(block_id, self._flat_grads)
-                        self._works += works
+                        self._pipeline_block_reductions(block_id, self._flat_grads)

                     flush_block = self._get_flush_block()

-    def _wait_works(self):
-        for work in self._works:
-            if work is not None:
-                work.wait()
-        self._works = []
-
     def set_global_scale(self, global_scale):
         """Set global scale.
         """
@@ -457,7 +447,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     def complete_reductions(self):
         """Complete reductions if full pipeline is not selected or overlap is not allowed.
         """
-        self._wait_works()
         if self._last_step:
             # zero out gradients that have not been completed yet
@@ -475,8 +464,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             block_id = self._num_blocks - inv_block_id - 1
             self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
-                works = self._pipeline_block_reductions(block_id, self._flat_grads)
-                self._works += works
+                self._pipeline_block_reductions(block_id, self._flat_grads)

         self._copy_to_fp32 = False
         self._decomp_stats = None
@@ -486,7 +474,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     def revert_step(self):
         """Revert effect of previously calling partial_step.
         """
-        self._wait_works()
         for block_id in range(self._num_blocks):
             self._partial_step_single_shard(block_id, undo=True)
@@ -500,37 +487,38 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             self._new_params = torch.zeros_like(self._flat_grads)
         for inv_block_id in range(self._num_blocks):
             block_id = self._num_blocks - inv_block_id - 1
-            self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
-                work = self._pipeline_block_step(block_id, self._flat_grads, self._new_params)
-                self._works.append(work)
-        self._wait_works()
-        # Check for overflow
-        self.strided_check_finite(self._new_params,stride=self._shard_size,start=0,end=self._net_total_param_size)
-        # Store state for loss scaler calculation
-        if self.peek_overflow:
-            print("Reverting step")
-            self.revert_step()
-        else:
-            # Copy self._new_params to model params
-            with torch.no_grad():
-                param_i = 0
-                for group in self.param_groups:
-                    for p in group['params']:
-                        if not p.requires_grad:
-                            continue
-                        state = self.state[p]
-                        if len(state) == 0:
-                            state['step'] = 0
-                        state['step'] += 1
-                        nels = p.numel()
-                        offset = self._grads_info[param_i]['param_offset']
-                        p.set_(self._new_params[offset:offset+nels].view_as(p))
-                        param_i += 1
-        self._new_params = None
+                self._pipeline_block_step(block_id, self._flat_grads, self._new_params)
+        with torch.cuda.stream(self._blk_st[0]):
+            for i, blk_st in enumerate(self._blk_st):
+                torch.cuda.current_stream().wait_stream(blk_st)
+            # Check for overflow
+            self.strided_check_finite(self._new_params,stride=self._shard_size,start=0,end=self._net_total_param_size)
+            # Store state for loss scaler calculation
+            if self.peek_overflow:
+                print("Reverting step")
+                self.revert_step()
+            else:
+                # Copy self._new_params to model params
+                with torch.no_grad():
+                    param_i = 0
+                    for group in self.param_groups:
+                        for p in group['params']:
+                            if not p.requires_grad:
+                                continue
+                            state = self.state[p]
+                            if len(state) == 0:
+                                state['step'] = 0
+                            state['step'] += 1
+                            nels = p.numel()
+                            offset = self._grads_info[param_i]['param_offset']
+                            p.set_(self._new_params[offset:offset+nels].view_as(p))
+                            param_i += 1
+            self._new_params = None
+        torch.cuda.current_stream().wait_stream(self._blk_st[0])

         return loss
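The reworked step path keys its ordering off CUDA streams instead of work handles: per-block work runs on its own block stream, one stream then waits on all of the block streams and does the finishing work, and the default stream finally waits on that stream before control returns. A minimal fan-out/fan-in sketch in the same spirit (hypothetical buffers, requires a CUDA device):

import torch

assert torch.cuda.is_available()

num_blocks = 4
block_streams = [torch.cuda.Stream() for _ in range(num_blocks)]
buffers = [torch.zeros(1024, device='cuda') for _ in range(num_blocks)]

# Fan out: enqueue per-block work on its own stream.
for stream, buf in zip(block_streams, buffers):
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        buf.add_(1.0)

# Fan in: one stream waits on every block stream, then does the finishing work.
# Waiting on itself (block_streams[0]) is harmless.
finish_stream = block_streams[0]
with torch.cuda.stream(finish_stream):
    for stream in block_streams:
        torch.cuda.current_stream().wait_stream(stream)
    total = torch.stack([b.sum() for b in buffers]).sum()

# The caller's (default) stream waits on the finishing stream before using results.
torch.cuda.current_stream().wait_stream(finish_stream)
print(total.item())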