OpenDAS / apex

Commit 02fd7341, authored May 30, 2020 by Thor Johnsen
Parent: 9a09107c

Add optional accumulation step
Showing 1 changed file with 15 additions and 52 deletions:

apex/contrib/optimizers/distributed_fused_lamb.py (+15, -52)
...

@@ -94,9 +94,9 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         import amp_C
         self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
+        self._is_accumulation_step = False
         self._last_step = False
         self._overlap_reductions = overlap_reductions
         self._global_scale = None
         self._num_blocks = dwu_num_blocks
         self._num_chunks = dwu_num_chunks
         self._e5m2_allgather = e5m2_allgather
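For orientation, here is a hypothetical construction sketch, not taken from the commit: the keyword argument names are assumed to mirror the local names captured in this hunk (overlap_reductions, dwu_num_blocks, dwu_num_chunks, e5m2_allgather), and an already-initialized torch.distributed process group is assumed.

# Hypothetical usage sketch, not from this repository. Keyword names are assumed from
# the attribute assignments above; torch.distributed.init_process_group() is assumed
# to have been called already, since the optimizer shards its work across ranks.
import torch
from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = DistributedFusedLAMB(
    model.parameters(),
    lr=1e-3,
    overlap_reductions=True,   # stream block reduce-scatters while backward is running
    dwu_num_blocks=4,          # blocking/chunking of the flat gradient buffer
    dwu_num_chunks=4,
    e5m2_allgather=False,      # presumably: whether the parameter all-gather is compressed to e5m2
)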
...

@@ -363,6 +363,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         import inspect
         assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"

+    def set_is_accumulation_step(self, is_accumulation_step):
+        self._is_accumulation_step = is_accumulation_step
+
     def set_last_step(self, last_step):
         self._last_step = last_step
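The two setters above suggest how a gradient-accumulation loop might drive the new flag. The sketch below is illustrative only: the data, loss, and accumulation length are invented, and it builds on the hypothetical construction sketch shown after the first hunk.

# Illustrative accumulation loop (assumptions: 'model' and 'optimizer' come from the
# construction sketch above; data and accumulation length are made up for the example).
import torch

criterion = torch.nn.MSELoss()
batches = [(torch.randn(8, 1024, device="cuda"), torch.randn(8, 1024, device="cuda"))
           for _ in range(8)]
accum_steps = 4  # micro-batches per real optimizer step

for step, (inp, tgt) in enumerate(batches):
    is_accum = (step + 1) % accum_steps != 0
    # New in this commit: while accumulating, _do_overlapped_reduction skips queueing
    # gradients for the overlapped reduce-scatter pipeline.
    optimizer.set_is_accumulation_step(is_accum)
    # Pre-existing flag (see the first hunk); presumably lets step() handle the final
    # reduction itself instead of the overlapped flush path.
    optimizer.set_last_step(step == len(batches) - 1)

    loss = criterion(model(inp), tgt) / accum_steps
    loss.backward()

    if not is_accum:
        optimizer.step()
        optimizer.zero_grad()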
...

@@ -492,58 +494,19 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._grads_fp32 = []

     def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
-        # handle overlapped reductions
-        if param.dtype == torch.float16:
-            self._grads_fp16.append( (param.grad, self._individual_flat_grads[param_i]) )
-        else:
-            self._grads_fp32.append( (param.grad, self._individual_flat_grads[param_i]) )
-        self._grads_generated[param_i] = True
-        if self._overlap_reductions and not self._last_step:
-            flush_block = self._get_flush_block()
-            while flush_block:
-                block_id = flush_block[0] // self._block_size
-                self._pipeline_block_reductions(block_id)
+        if not self._is_accumulation_step:
+            # handle overlapped reductions
+            if param.dtype == torch.float16:
+                self._grads_fp16.append( (param.grad, self._individual_flat_grads[param_i]) )
+            else:
+                self._grads_fp32.append( (param.grad, self._individual_flat_grads[param_i]) )
+            self._grads_generated[param_i] = True
+            if self._overlap_reductions and not self._last_step:
+                flush_block = self._get_flush_block()
-
-    def set_global_scale(self, global_scale):
-        """Set global scale.
-        """
-        self._global_scale = global_scale
-
-    @property
-    def global_scale(self):
-        return self._global_scale
-
-    @property
-    def has_overflow(self):
-        """Check if overflows were detected by any call to step(...) method.
-        Clears the overflow flag.
-        """
-        has_overflow = self._has_overflow
-        self._has_overflow = False
-        return has_overflow
-
-    @property
-    def peek_overflow(self):
-        """Check if overflows were detected by any call to step(...) method.
-        Does not clear overflow flag.
-        """
-        return self._has_overflow
-
-    def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
-        """Strided check for overflow.
-        You can get status by calling has_overflow.
-        """
-        if start >= 0 and start < end:
-            out_p = output_params[start:end]
-        else:
-            out_p = output_params
-        fused_adam_cuda.strided_check_finite(self._overflow_buf, out_p, stride, 1 if clear else 0)
-        self._has_overflow = False if self._overflow_buf.item() == 0 else True
-        return self._has_overflow
+                while flush_block:
+                    block_id = flush_block[0] // self._block_size
+                    self._pipeline_block_reductions(block_id)
+                    flush_block = self._get_flush_block()

     @property
     def L2_grad_norm(self):
...
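One plausible reason the guard lives inside _do_overlapped_reduction rather than in the training loop: this method sits on the per-gradient callback path, invoked as each parameter's gradient is produced during backward, so it is where the optimizer can decide, per micro-batch, whether to queue a gradient for an overlapped reduction. The sketch below is a generic illustration of that hook-driven pattern with invented names; it is not code from this repository.

# Generic illustration of the hook-driven pattern (invented names, not apex code):
# a per-parameter hook fires as each gradient is produced, and the accumulation flag
# decides there and then whether the gradient is queued for communication.
import torch

class TinyOverlappedReducer:
    def __init__(self, params):
        self.is_accumulation_step = False
        self.pending = []  # (index, grad) pairs queued for a hypothetical reduce-scatter
        for i, p in enumerate(params):
            if p.requires_grad:
                p.register_hook(lambda grad, idx=i: self._on_grad(idx, grad))

    def _on_grad(self, idx, grad):
        # Mirrors the commit's change: during accumulation micro-batches the gradient is
        # left to accumulate locally and nothing is queued for communication.
        if not self.is_accumulation_step:
            self.pending.append((idx, grad.detach()))
        return grad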