OpenDAS / apex

Commit db8fb976, authored Apr 01, 2020 by Thor Johnsen

Add back support for multi tensor scale flattening

Parent: 3f717d95

Showing 1 changed file with 50 additions and 15 deletions:

apex/contrib/optimizers/distributed_fused_adam.py (+50, -15)
@@ -107,6 +107,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             p_offset = ((p_offset+63)//64)*64
             p_i += 1
         self._grads_generated = [False]*len(self._grads_info)
+        self._flat_mt = flat_mt
+        self._grads = [None]*len(self._grads_info) if self._flat_mt else None
         if self._overlap_reductions:
             self._current_block = self._num_blocks
@@ -118,6 +120,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._shard_size = self._block_size // self._group_size
         print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size, dwu_min_page_size, self._block_size, self._shard_size))
+        self._low_param_i = [0]*self._num_blocks
+        for block_id in range(self._num_blocks-1,-1,-1):
+            p_i = len(self._grads_info)-1
+            while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
+                p_i -= 1
+            self._low_param_i[block_id] = p_i
+        print(self._low_param_i)
         self._flat_grads = torch.zeros([self._total_param_size]).half().cuda()
         self._new_params = None
         self._fp32_p = None
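For reference, the loop added above precomputes, for each reduction block, the index of the lowest-numbered parameter whose flattened offset lies past the start of that block. A minimal standalone sketch of the same scan, using hypothetical toy offsets and block size rather than values from apex:

# Standalone sketch of the _low_param_i scan added above.
# grads_info and block_size are hypothetical toy values, not taken from apex.
grads_info = [
    {"param_offset": 0},
    {"param_offset": 100},
    {"param_offset": 300},
    {"param_offset": 450},
]
block_size = 256
num_blocks = 2

low_param_i = [0] * num_blocks
for block_id in range(num_blocks - 1, -1, -1):
    # Walk backwards until we find the last parameter that starts
    # at or before the beginning of this block.
    p_i = len(grads_info) - 1
    while p_i > 0 and grads_info[p_i]["param_offset"] > block_id * block_size:
        p_i -= 1
    low_param_i[block_id] = p_i

print(low_param_i)  # [0, 1]: block 1 starts at offset 256, and param 1 (offset 100) is the last to start at or before it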
@@ -175,6 +185,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
     def _get_flush_block(self):
         flush_block = []
+        if self._grads_generated[self._low_param_i[self._current_block-1]]:
             num_grads = len(self._grads_generated)
             contiguous_idx = num_grads
             while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
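The guard added to _get_flush_block builds on that table: a block is only considered for flushing once the gradient of its lowest overlapping parameter has been generated. A toy, standalone illustration (all values below are hypothetical, not taken from apex):

# Toy illustration of the flush guard (hypothetical values, not apex code).
low_param_i = [0, 1]                          # lowest parameter index per block
grads_generated = [True, False, True, True]   # which parameter grads have arrived
current_block = 2                             # blocks are flushed from the top down

# Only scan for a contiguous run of ready gradients if the lowest parameter
# overlapping the current block already has its gradient.
if grads_generated[low_param_i[current_block - 1]]:
    print("block", current_block - 1, "is a flush candidate")
else:
    print("block", current_block - 1, "is not ready yet")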
@@ -244,8 +255,30 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             self._pipeline_block_reductions(block_id, flat_grads)
             self._pipeline_block_step(block_id, flat_grads, new_params)
 
+    def _flatten_grad_mt(self, scale):
+        if self._flat_mt:
+            grads = []
+            flat_grads = []
+            for p_i, (grads_info, grad) in enumerate(zip(self._grads_info, self._grads)):
+                if grad is not None:
+                    grads.append(grad)
+                    flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]])
+                    self._grads[p_i] = None
+            if len(grads) > 0:
+                import amp_C
+                from apex.multi_tensor_apply import multi_tensor_applier
+                self._overflow_buf.zero_()
+                multi_tensor_applier(amp_C.multi_tensor_scale, self._overflow_buf, [grads, flat_grads], scale)
+
     def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, grad):
         # handle overlapped reductions
-        torch.div(grad.view(-1), self._world_size if self._predivide else 1.0, out=self._flat_grads[param_offset:param_offset+param_grads_size])
+        if self._flat_mt:
+            self._grads[param_i] = grad
+        else:
+            torch.div(grad.view(-1), self._world_size if self._predivide else 1.0, out=self._flat_grads[param_offset:param_offset+param_grads_size])
         self._grads_generated[param_i] = True
         if not self._last_step:
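The new _flatten_grad_mt path is the core of this commit: instead of issuing one torch.div copy per parameter, stashed gradients are written into the flat buffer with a single fused amp_C.multi_tensor_scale launch. A rough standalone sketch of that pattern, assuming apex is installed with the amp_C extension built and a CUDA device is available; the tensor sizes and world_size below are made up for illustration:

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

# Hypothetical example tensors standing in for per-parameter gradients.
grads = [torch.randn(1024, device="cuda"), torch.randn(4096, device="cuda")]

# Flat half-precision destination buffer, analogous to the optimizer's _flat_grads.
flat = torch.zeros(1024 + 4096, device="cuda").half()
flat_views = [flat[:1024], flat[1024:]]

# Shared overflow flag checked by the fused kernel.
overflow_buf = torch.zeros(1, dtype=torch.int, device="cuda")

# One kernel launch scales every gradient by `scale` and writes it into the
# matching view of the flat buffer (here: pre-division by the world size).
world_size = 8
multi_tensor_applier(amp_C.multi_tensor_scale,
                     overflow_buf,
                     [grads, flat_views],
                     1.0 / world_size)

Passing 1.0/world_size as the scale folds the optional gradient pre-division into the same kernel, which is what the calls added in the next two hunks do.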
@@ -253,6 +286,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             flush_block = self._get_flush_block()
             while flush_block:
                 block_id = flush_block[0] // self._block_size
+                self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
                 self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
                     if self._full_pipeline:
@@ -462,6 +496,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 # nothing done so far, run full pipeline after reductions
                 for inv_block_id in range(self._num_blocks):
                     block_id = self._num_blocks - inv_block_id - 1
+                    self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
                     self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
                     with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
                         self._pipeline_block_reductions(block_id, self._flat_grads)