OpenDAS / ColossalAI · Commits · 14a70942

Unverified commit 14a70942, authored Mar 16, 2022 by Frank Lee, committed via GitHub on Mar 16, 2022.

fixed fp16 optimizer none grad bug (#432)

parent fce9432f
Showing 1 changed file with 5 additions and 104 deletions.

colossalai/amp/naive_amp/_fp16_optimizer.py (+5 / -104)
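
The bug being fixed: parameters that never receive a gradient (frozen parameters, or ones unused in the forward pass) keep p.grad as None, so an unconditional inf/nan check on p.grad raises an exception instead of reporting overflow. Below is a minimal sketch of the failure and of the guard added in this commit; has_inf_or_nan here is a stand-in for the library helper, and the toy model is hypothetical:

    import torch

    def has_inf_or_nan(t: torch.Tensor) -> bool:
        # stand-in for the library helper: True if any element is inf or nan
        return not torch.isfinite(t).all().item()

    model = torch.nn.ModuleDict({
        "used": torch.nn.Linear(4, 4),
        "unused": torch.nn.Linear(4, 4),  # never used in the forward pass, so its .grad stays None
    })
    model["used"](torch.randn(2, 4)).sum().backward()

    # Unguarded scan crashes on the unused parameters:
    #     has_inf_or_nan(p.grad)  ->  TypeError, because p.grad is None
    # Guarded scan, mirroring the fix in this commit:
    overflow = any(p.grad is not None and has_inf_or_nan(p.grad) for p in model.parameters())
    print(overflow)  # False in this toy example
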
@@ -39,106 +39,6 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
             that_.copy_(this_)
 
 
-class DynamicGradScaler:
-
-    def __init__(self,
-                 initial_scale,
-                 min_scale,
-                 growth_factor,
-                 backoff_factor,
-                 growth_interval,
-                 hysteresis,
-                 max_scale: int = None,
-                 verbose: bool = False):
-        """Grad scaler with dynamic scale that gets adjusted
-        during training."""
-        assert initial_scale > 0.0
-        self._scale = torch.cuda.FloatTensor([initial_scale])
-
-        # Lower bound on the scale.
-        assert min_scale > 0.0
-        assert min_scale <= initial_scale
-        self.min_scale = torch.cuda.FloatTensor([min_scale])
-
-        # Growth and backoff factors for the scale.
-        assert growth_factor > 1.0
-        self.growth_factor = torch.cuda.FloatTensor([growth_factor])
-        assert backoff_factor < 1.0
-        assert backoff_factor > 0.0
-        self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
-
-        # Interval over which, if we don't see any inf/nan,
-        # we will scale the grad scale by the growth factor.
-        assert growth_interval > 0
-        self.growth_interval = growth_interval
-
-        # Number of inf/nans we should see before scaling down
-        # the grad scale by the backoff factor.
-        assert hysteresis > 0
-        self.hysteresis = hysteresis
-
-        if max_scale is not None:
-            assert max_scale > 1 and initial_scale <= max_scale
-        self._max_scale = max_scale
-
-        # Trackers.
-        self._growth_tracker = 0
-        self._hysteresis_tracker = self.hysteresis
-
-        self._logger = get_dist_logger()
-        self.verbose = verbose
-
-    @property
-    def scale(self):
-        return self._scale
-
-    @property
-    def inv_scale(self):
-        return self._scale.double().reciprocal().float()
-
-    def update(self, found_inf):
-        # If we have an inf/nan, the growth tracker is set to 0
-        # and the hysteresis tracker is reduced by 1.
-        if found_inf:
-            self._growth_tracker = 0
-            self._hysteresis_tracker -= 1
-            # Now if we are out of hysteresis count, scale down the loss.
-            if self._hysteresis_tracker <= 0:
-                self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale)
-            if self.verbose:
-                self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
-        else:
-            # If there is no nan/inf, increment the growth tracker.
-            self._growth_tracker += 1
-            # If we have had enough consecutive intervals with no nan/inf:
-            if self._growth_tracker == self.growth_interval:
-                # Reset the growth and hysteresis trackers,
-                self._growth_tracker = 0
-                self._hysteresis_tracker = self.hysteresis
-                # and scale up the loss scale.
-                if self._max_scale is not None and self._scale >= self._max_scale:
-                    if self.verbose:
-                        self._logger.info(
-                            f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed',
-                            ranks=[0])
-                else:
-                    self._scale = self._scale * self.growth_factor
-                    if self.verbose:
-                        self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])
-
-    def state_dict(self):
-        state_dict = {}
-        state_dict['max_scale'] = self._max_scale
-        state_dict['scale'] = self._scale
-        state_dict['growth_tracker'] = self._growth_tracker
-        state_dict['hysteresis_tracker'] = self._hysteresis_tracker
-        return state_dict
-
-    def load_state_dict(self, state_dict):
-        self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
-        self._growth_tracker = state_dict['growth_tracker']
-        self._hysteresis_tracker = state_dict['hysteresis_tracker']
-        self._max_scale = state_dict['max_scale']
-
-
 class FP16Optimizer(Optimizer):
     """Float16 optimizer for fp16 and bf16 data types.
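
For context on the removed class: a dynamic grad scaler of this shape is normally driven once per iteration, with the loss multiplied by scale before backward(), gradients unscaled by inv_scale, and update(found_inf) growing or backing off the scale. A rough sketch of that loop under assumed names (scaler, check_overflow); it is not the exact ColossalAI training API:

    # Rough fp16 training-step sketch; `scaler` is assumed to expose
    # .scale, .inv_scale and .update(found_inf) like the class above.
    def fp16_step(model, optimizer, scaler, loss, check_overflow):
        # scale the loss so small fp16 gradients do not underflow
        (loss * scaler.scale).backward()

        # unscale gradients back to their true magnitude
        for p in model.parameters():
            if p.grad is not None:
                p.grad.mul_(scaler.inv_scale)

        found_inf = check_overflow(model.parameters())  # True if any grad is inf/nan
        scaler.update(found_inf)  # back off on overflow, grow after a clean streak

        if found_inf:
            optimizer.zero_grad()  # skip this step; the scale was just reduced
        else:
            optimizer.step()
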
@@ -284,7 +184,7 @@ class FP16Optimizer(Optimizer):
         # check for overflow
         for group in self._optimizer.param_groups:
             for p in group['params']:
-                if has_inf_or_nan(p.grad):
+                if p.grad is not None and has_inf_or_nan(p.grad):
                     self._found_overflow.fill_(1.0)
                     break
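
A self-contained version of the guarded scan above, written against a plain torch.optim optimizer instead of the wrapped self._optimizer; found_overflow and the use of torch.isfinite are illustrative choices, not the class's real helpers:

    import torch

    def found_overflow(optimizer: torch.optim.Optimizer) -> bool:
        # Skip parameters whose grad is None (frozen, unused, or already cleared),
        # which is exactly what the added "p.grad is not None" guard achieves.
        for group in optimizer.param_groups:
            for p in group['params']:
                if p.grad is not None and not torch.isfinite(p.grad).all():
                    return True
        return False

    # usage
    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    model(torch.randn(1, 4)).sum().backward()
    print(found_overflow(opt))  # False
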
@@ -316,6 +216,7 @@ class FP16Optimizer(Optimizer):
         # This only needs to be done for the float16 group.
         for fp16_param_group, fp32_master_param_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
             for fp16_param, fp32_param in zip(fp16_param_group, fp32_master_param_group):
                 if fp16_param.grad is not None:
                     fp32_param.grad = fp16_param.grad.float()
+                    # clear unneeded grad on fp16 param
+                    fp16_param.grad = None
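
The last hunk follows the usual master-weight pattern: gradients produced on fp16 parameters are copied into fp32 master copies (and released on the fp16 side to save memory) before the optimizer steps in fp32. A standalone sketch of that round trip; the flat lists here simplify the class's internal _fp16_param_groups / _fp32_master_param_groups layout:

    import torch

    # fp16 "model" parameter and its fp32 master copy (one flat group for simplicity)
    fp16_params = [torch.nn.Parameter(torch.randn(4, 4).half())]
    fp32_masters = [torch.nn.Parameter(p.detach().float()) for p in fp16_params]
    optimizer = torch.optim.SGD(fp32_masters, lr=0.1)

    # backward on the fp16 parameter leaves an fp16 gradient on it
    loss = (fp16_params[0].float() ** 2).sum()
    loss.backward()

    # copy grads to the fp32 masters and drop the fp16 grads, as in this commit
    for fp16_p, fp32_p in zip(fp16_params, fp32_masters):
        if fp16_p.grad is not None:
            fp32_p.grad = fp16_p.grad.float()
            fp16_p.grad = None  # clear unneeded grad on fp16 param

    optimizer.step()  # the update happens in fp32

    # push the updated fp32 weights back into the fp16 parameters
    with torch.no_grad():
        for fp16_p, fp32_p in zip(fp16_params, fp32_masters):
            fp16_p.copy_(fp32_p.half())
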