MMCV (OpenDAS) · Commit 2d52809c (unverified)
Authored by Cao Yuhang on Oct 31, 2020; committed by GitHub on Oct 31, 2020.

Add dynamic scale (#585)

* add dynamic scale
* add type check of loss scale
* fix lint
* minor fix
Parent: 03214fd4

Showing 3 changed files with 117 additions and 21 deletions (+117 -21):

    mmcv/runner/__init__.py           +2   -2
    mmcv/runner/fp16_utils.py         +84  -0
    mmcv/runner/hooks/optimizer.py    +31  -19
mmcv/runner/__init__.py

@@ -6,7 +6,7 @@ from .checkpoint import (_load_checkpoint, load_checkpoint, load_state_dict,
 from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
                          init_dist, master_only)
 from .epoch_based_runner import EpochBasedRunner, Runner
-from .fp16_utils import auto_fp16, force_fp32, wrap_fp16_model
+from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
 from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
                     EMAHook, Fp16OptimizerHook, Hook, IterTimerHook,
                     LoggerHook, LrUpdaterHook, MlflowLoggerHook, OptimizerHook,

@@ -33,5 +33,5 @@ __all__ = [
     'build_optimizer', 'build_optimizer_constructor', 'IterLoader',
     'set_random_seed', 'auto_fp16', 'force_fp32', 'wrap_fp16_model',
     'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner',
-    'RUNNERS', 'allreduce_grads', 'allreduce_params'
+    'RUNNERS', 'allreduce_grads', 'allreduce_params', 'LossScaler'
 ]
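
The only change in this file is re-exporting the new LossScaler from the mmcv.runner package. A minimal sketch of what that enables, assuming an mmcv build that includes this commit:

    # After this commit, LossScaler is importable alongside the fp16 helpers.
    from mmcv.runner import Fp16OptimizerHook, LossScaler

    scaler = LossScaler(mode='dynamic')
    print(scaler.loss_scale)  # initial scale, 2**32 by default
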
mmcv/runner/fp16_utils.py

@@ -264,3 +264,87 @@ def patch_forward_method(func, src_type, dst_type, convert_output=True):
         return output
 
     return new_forward
+
+
+class LossScaler:
+    """Class that manages loss scaling in mixed precision training which
+    supports both dynamic and static mode.
+
+    The implementation refers to
+    https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py.
+    Dynamic loss scaling is selected by supplying ``mode='dynamic'``.
+    It is important to understand how :class:`LossScaler` operates.
+    Loss scaling is designed to combat the problem of underflowing
+    gradients encountered when training fp16 networks for a long time.
+    Dynamic loss scaling begins by attempting a very high loss
+    scale. Ironically, this may result in OVERflowing gradients.
+    If overflowing gradients are encountered, :class:`Fp16OptimizerHook`
+    skips the update step for this particular iteration/minibatch,
+    and :class:`LossScaler` adjusts the loss scale to a lower value.
+    If a certain number of iterations occur without overflowing gradients
+    detected, :class:`LossScaler` increases the loss scale once more.
+    In this way :class:`LossScaler` attempts to "ride the edge" of always
+    using the highest loss scale possible without incurring overflow.
+
+    Args:
+        init_scale (float): Initial loss scale value. Default: 2**32.
+        mode (str): Loss scaling mode, 'dynamic' or 'static'.
+            Default: 'dynamic'.
+        scale_factor (float): Factor used when adjusting the loss scale.
+            Default: 2.
+        scale_window (int): Number of consecutive iterations without an
+            overflow to wait before increasing the loss scale. Default: 1000.
+    """
+
+    def __init__(self,
+                 init_scale=2**32,
+                 mode='dynamic',
+                 scale_factor=2.,
+                 scale_window=1000):
+        self.cur_scale = init_scale
+        self.cur_iter = 0
+        assert mode in ('dynamic',
+                        'static'), 'mode can only be dynamic or static'
+        self.mode = mode
+        self.last_overflow_iter = -1
+        self.scale_factor = scale_factor
+        self.scale_window = scale_window
+
+    def has_overflow(self, params):
+        """Check if any gradient in ``params`` contains inf or NaN."""
+        if self.mode != 'dynamic':
+            return False
+        for p in params:
+            if p.grad is not None and LossScaler._has_inf_or_nan(p.grad.data):
+                return True
+        return False
+
+    def _has_inf_or_nan(x):
+        """Check if the tensor ``x`` contains inf or NaN (static helper)."""
+        try:
+            cpu_sum = float(x.float().sum())
+        except RuntimeError as instance:
+            # treat an overflow during the conversion to a Python float as
+            # inf/NaN in the gradients; re-raise any other RuntimeError
+            if 'value cannot be converted' not in instance.args[0]:
+                raise
+            return True
+        else:
+            if cpu_sum == float('inf') or cpu_sum == -float('inf') \
+                    or cpu_sum != cpu_sum:
+                return True
+            return False
+
+    def update_scale(self, overflow):
+        """Update the current loss scale value after each iteration."""
+        if self.mode != 'dynamic':
+            return
+        if overflow:
+            # decrease the scale (never below 1) and remember the iteration
+            self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
+            self.last_overflow_iter = self.cur_iter
+        else:
+            # increase the scale every `scale_window` overflow-free iterations
+            if (self.cur_iter - self.last_overflow_iter) % \
+                    self.scale_window == 0:
+                self.cur_scale *= self.scale_factor
+        self.cur_iter += 1
+
+    @property
+    def loss_scale(self):
+        return self.cur_scale
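
A minimal sketch of the dynamic behaviour described in the docstring above, assuming an mmcv build that contains this commit; only the constructor arguments and methods shown in the diff are used, and the small scale_window is chosen just to make the adjustment visible quickly:

    import torch

    from mmcv.runner import LossScaler

    # deliberately small window so the re-growth is easy to observe
    scaler = LossScaler(init_scale=2.**16, mode='dynamic',
                        scale_factor=2., scale_window=2)

    # a parameter whose gradient has overflowed to inf
    p = torch.nn.Parameter(torch.ones(4))
    p.grad = torch.tensor([float('inf'), 0., 0., 0.])

    overflow = scaler.has_overflow([p])  # True: inf detected in p.grad
    scaler.update_scale(overflow)        # scale is halved to 2**15
    print(scaler.loss_scale)

    # after two overflow-free iterations the scale is raised again
    p.grad = torch.zeros(4)
    for _ in range(2):
        scaler.update_scale(scaler.has_overflow([p]))
    print(scaler.loss_scale)             # back to 2**16
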
mmcv/runner/hooks/optimizer.py

@@ -6,7 +6,7 @@ from itertools import chain
 from torch.nn.utils import clip_grad
 
 from ..dist_utils import allreduce_grads
-from ..fp16_utils import wrap_fp16_model
+from ..fp16_utils import LossScaler, wrap_fp16_model
 from .hook import HOOKS, Hook

@@ -48,7 +48,8 @@ class Fp16OptimizerHook(OptimizerHook):
     Refer to https://arxiv.org/abs/1710.03740 for more details.
 
     Args:
-        loss_scale (float): Scale factor multiplied with loss.
+        loss_scale (float | str): Scale factor multiplied with loss. If
+            'dynamic' is specified, then dynamic loss scaling will be used.
     """
 
     def __init__(self,

@@ -60,8 +61,13 @@ class Fp16OptimizerHook(OptimizerHook):
         self.grad_clip = grad_clip
         self.coalesce = coalesce
         self.bucket_size_mb = bucket_size_mb
-        self.loss_scale = loss_scale
         self.distributed = distributed
+        if loss_scale == 'dynamic':
+            self.loss_scaler = LossScaler(mode='dynamic')
+        elif isinstance(loss_scale, float):
+            self.loss_scaler = LossScaler(init_scale=loss_scale, mode='static')
+        else:
+            raise ValueError('loss_scale must be of type float or str')
 
     def before_run(self, runner):
         """Preparing steps before Mixed Precision Training.

@@ -100,7 +106,8 @@ class Fp16OptimizerHook(OptimizerHook):
             fp16_param.data.copy_(fp32_param.data)
 
     def after_train_iter(self, runner):
-        """Backward optimization steps for Mixed Precision Training.
+        """Backward optimization steps for Mixed Precision Training. For
+        dynamic loss scaling, please refer to the LossScaler in fp16_utils.py.
 
         1. Scale the loss by a scale factor.
         2. Backward the loss to obtain the gradients (fp16).

@@ -112,9 +119,10 @@ class Fp16OptimizerHook(OptimizerHook):
         runner.model.zero_grad()
         runner.optimizer.zero_grad()
         # scale the loss value
-        scaled_loss = runner.outputs['loss'] * self.loss_scale
+        scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale
         scaled_loss.backward()
         # copy fp16 grads in the model to fp32 params in the optimizer
         fp32_weights = []
         for param_group in runner.optimizer.param_groups:
             fp32_weights += param_group['params']

@@ -122,17 +130,21 @@ class Fp16OptimizerHook(OptimizerHook):
         # allreduce grads
         if self.distributed:
             allreduce_grads(fp32_weights, self.coalesce, self.bucket_size_mb)
-        # scale the gradients back
-        for param in fp32_weights:
-            if param.grad is not None:
-                param.grad.div_(self.loss_scale)
-        if self.grad_clip is not None:
-            grad_norm = self.clip_grads(fp32_weights)
-            if grad_norm is not None:
-                # Add grad norm to the logger
-                runner.log_buffer.update({'grad_norm': float(grad_norm)},
-                                         runner.outputs['num_samples'])
-        # update fp32 params
-        runner.optimizer.step()
-        # copy fp32 params to the fp16 model
-        self.copy_params_to_fp16(runner.model, fp32_weights)
+        has_overflow = self.loss_scaler.has_overflow(fp32_weights)
+        # if the grads overflowed, skip this iteration
+        if not has_overflow:
+            # scale the gradients back
+            for param in fp32_weights:
+                if param.grad is not None:
+                    param.grad.div_(self.loss_scaler.loss_scale)
+            if self.grad_clip is not None:
+                self.clip_grads(fp32_weights)
+            # update fp32 params
+            runner.optimizer.step()
+            # copy fp32 params to the fp16 model
+            self.copy_params_to_fp16(runner.model, fp32_weights)
+        self.loss_scaler.update_scale(has_overflow)
+        if has_overflow:
+            runner.logger.warning('Check overflow, downscale loss scale '
+                                  f'to {self.loss_scaler.cur_scale}')
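
With these changes, dynamic loss scaling is opted into by passing loss_scale='dynamic' to Fp16OptimizerHook; a plain float keeps the old static behaviour. A minimal usage sketch, assuming an mmcv build containing this commit (runner setup is shown only in comments, since model, optimizer, and logger are placeholders):

    from mmcv.runner import Fp16OptimizerHook

    # static scaling, as before: a fixed factor, which must be a float
    # (an int such as 512 is rejected by the new type check)
    static_hook = Fp16OptimizerHook(loss_scale=512., distributed=False)

    # new in this commit: let LossScaler adjust the factor on the fly
    dynamic_hook = Fp16OptimizerHook(loss_scale='dynamic', distributed=False)

    # anything else raises ValueError('loss_scale must be of type float or str')
    # Fp16OptimizerHook(loss_scale=None)

    # the hook is then registered on a runner like any other optimizer hook:
    # runner.register_hook(dynamic_hook)
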