OpenDAS / apex / Commits / fc6c5a25

Commit fc6c5a25
authored Apr 10, 2019 by Michael Carilli

some cleanup

parent 683b6e0e
Showing 3 changed files with 77 additions and 115 deletions (+77 -115):

    apex/amp/_initialize.py        (+1  -21)
    apex/amp/_process_optimizer.py (+72 -83)
    apex/amp/handle.py             (+4  -11)
apex/amp/_initialize.py (+1 -21)

@@ -107,22 +107,6 @@ def check_optimizers(optimizers):
                 "on the specified opt_level (and optional overridden properties)."
             )
 
-def wrap_fused_adam(optimizer, properties):
-    msg = 'Currently, the usage of FusedAdam is restricted to ' \
-          'amp.initialize(..., opt_level="O2", keep_batchnorm_fp32=False, ' \
-          'loss_scale=float or "dynamic"). We are working on enabling more general usage.'
-
-    assert properties.master_weights is True, msg
-    assert properties.cast_model_type is torch.float16, msg
-    assert (properties.keep_batchnorm_fp32 is False or
-            properties.keep_batchnorm_fp32 is None), msg
-
-    if properties.loss_scale == "dynamic":
-        return FP16_Optimizer_for_fused(optimizer, dynamic_loss_scale=True)
-    else:
-        return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)
-
 def _initialize(models, optimizers, properties, num_losses=1):
     from apex.parallel import DistributedDataParallel as apex_DDP
     from .amp import init as amp_init

@@ -184,10 +168,6 @@ def _initialize(models, optimizers, properties, num_losses=1):
             optimizer.load_state_dict(optimizer.state_dict())
 
     for i, optimizer in enumerate(optimizers):
-        # Still need to special case this for the first pass
-        if isinstance(optimizer, FusedAdam):
-            optimizers[i] = wrap_fused_adam(optimizer, properties)
-        else:
-            optimizers[i] = _process_optimizer(optimizer, properties)
+        optimizers[i] = _process_optimizer(optimizer, properties)
 
     _amp_state.loss_scalers = []
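Note (not part of the commit): the deleted wrap_fused_adam path is what previously restricted FusedAdam to the configuration named in its assertion message; after this change FusedAdam goes through the same _process_optimizer path as every other optimizer. A minimal sketch of the call a user had to make under that restriction, with placeholder model, optimizer, and hyperparameters chosen only for illustration:

    import torch
    from apex import amp
    from apex.optimizers import FusedAdam

    # Placeholder model and optimizer, for illustration only.
    model = torch.nn.Linear(1024, 1024).cuda()
    optimizer = FusedAdam(model.parameters(), lr=1e-3)

    # Prior to this commit, FusedAdam was only usable with the settings that
    # wrap_fused_adam asserted on: opt_level="O2" (fp16 model cast plus fp32
    # master weights), keep_batchnorm_fp32=False, and a float or "dynamic"
    # loss scale.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level="O2",
                                      keep_batchnorm_fp32=False,
                                      loss_scale="dynamic")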
apex/amp/_process_optimizer.py (+72 -83)

@@ -3,6 +3,7 @@ from ..fp16_utils import master_params_to_model_params
 from ..multi_tensor_apply import multi_tensor_applier
 from ._amp_state import maybe_print
 import torch
+from ..optimizers import FusedAdam
 
 
 class AmpOptimizerState(object):

@@ -73,6 +74,40 @@ def lazy_init_with_master_weights(self):
     self.load_state_dict(self.state_dict())
 
 
+def post_backward_models_are_masters(scaler, params, stashed_grads):
+    # This is a lot of python overhead...
+    grads_needing_unscale = []
+    grads_needing_unscale_with_stash = []
+    stashed = []
+    for param, stashed_grad in zip(params, stashed_grads):
+        if param.grad is None and stashed_grad is not None:
+            param.grad = stashed_grad
+        elif param.grad is not None and stashed_grad is None:
+            grads_needing_unscale.append(param.grad)
+        elif param.grad is not None and stashed_grad is not None:
+            grads_needing_unscale_with_stash.append(param.grad)
+            stashed.append(stashed_grad)
+        else:
+            # param.grad is None and stashed_grad is None
+            continue
+
+    if len(grads_needing_unscale) > 0:
+        scaler.unscale(
+            grads_needing_unscale,
+            grads_needing_unscale,
+            scaler.loss_scale(),
+            models_are_masters=True)
+
+    if len(grads_needing_unscale_with_stash) > 0:
+        scaler.unscale_with_stashed(
+            grads_needing_unscale_with_stash,
+            stashed,
+            grads_needing_unscale_with_stash)
+
+    # Clear the stash.
+    for i in range(len(stashed_grads)):
+        stashed_grads[i] = None
+
+
 def prepare_backward_with_master_weights(self):
     stash = self._amp_stash
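Note (not part of the commit): the new post_backward_models_are_masters helper factors out the per-parameter case analysis that the following hunks delete from the individual post_backward_* functions. A framework-free sketch of that bucketing logic, using SimpleNamespace objects and plain floats as stand-ins for torch parameters and gradients:

    from types import SimpleNamespace

    def partition_grads(params, stashed_grads):
        # Mirrors the three buckets built by post_backward_models_are_masters.
        needs_unscale, needs_unscale_with_stash, stashed = [], [], []
        for param, stashed_grad in zip(params, stashed_grads):
            if param.grad is None and stashed_grad is not None:
                # No new gradient this backward pass; promote the stashed one.
                param.grad = stashed_grad
            elif param.grad is not None and stashed_grad is None:
                # Fresh scaled gradient, nothing accumulated: plain unscale.
                needs_unscale.append(param.grad)
            elif param.grad is not None and stashed_grad is not None:
                # Fresh scaled gradient plus a stashed one: unscale and combine.
                needs_unscale_with_stash.append(param.grad)
                stashed.append(stashed_grad)
            # else: both are None, nothing to do
        return needs_unscale, needs_unscale_with_stash, stashed

    params = [SimpleNamespace(grad=2.0), SimpleNamespace(grad=None), SimpleNamespace(grad=4.0)]
    stashed_grads = [None, 1.0, 3.0]
    print(partition_grads(params, stashed_grads))  # ([2.0], [4.0], [3.0])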
@@ -129,37 +164,10 @@ def post_backward_with_master_weights(self, scaler):
                                        preexisting_fp32_grads)
 
     # fp32 params can be treated as they would be in the "no_master_weights" case.
-    grads_needing_unscale = []
-    grads_needing_unscale_with_stash = []
-    stashed = []
-    for param, stashed_grad in zip(stash.all_fp32_from_fp32_params,
-                                   stash.all_fp32_from_fp32_grad_stash):
-        if param.grad is None and stashed_grad is not None:
-            param.grad = stashed_grad
-        elif param.grad is not None and stashed_grad is None:
-            grads_needing_unscale.append(param.grad)
-        elif param.grad is not None and stashed_grad is not None:
-            grads_needing_unscale_with_stash.append(param.grad)
-            stashed.append(stashed_grad)
-        else:
-            # param.grad is None and stashed_grad is None:
-            continue
-
-    if len(grads_needing_unscale) > 0:
-        scaler.unscale(
-            grads_needing_unscale,
-            grads_needing_unscale,
-            scaler.loss_scale(),
-            models_are_masters=True)
-
-    if len(grads_needing_unscale_with_stash) > 0:
-        scaler.unscale_with_stashed(
-            grads_needing_unscale_with_stash,
-            stashed,
-            grads_needing_unscale_with_stash)
-
-    # Clear the stash.
-    for i in range(len(stash.all_fp32_from_fp32_grad_stash)):
-        stash.all_fp32_from_fp32_grad_stash[i] = None
+    post_backward_models_are_masters(
+        scaler,
+        stash.all_fp32_from_fp32_params,
+        stash.all_fp32_from_fp32_grad_stash)
 
 
 def lazy_init_no_master_weights(self):
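Note (not part of the commit): both call sites hand the helper a scaler exposing unscale and unscale_with_stashed. Below is a toy stand-in that matches only the call signatures used above, assuming unscale divides scaled gradients by the current loss scale and unscale_with_stashed additionally folds in the previously stashed, already-unscaled gradients; apex's actual LossScaler (apex/amp/scaler.py) also performs overflow detection and uses fused multi-tensor kernels, none of which is sketched here:

    class ToyLossScaler:
        # Illustrative stand-in, not apex's LossScaler.
        def __init__(self, scale=1024.0):
            self._scale = scale

        def loss_scale(self):
            return self._scale

        def unscale(self, model_grads, master_grads, scale, models_are_masters=False):
            # Divide each scaled gradient by the loss scale (in place here, since the
            # helper above passes the same list as model_grads and master_grads).
            for i, g in enumerate(model_grads):
                master_grads[i] = g / scale

        def unscale_with_stashed(self, model_grads, stashed_master_grads, master_grads):
            # Unscale the fresh gradients and add the stashed, already-unscaled ones,
            # which is what keeps gradient accumulation across backward passes correct.
            for i, (g, s) in enumerate(zip(model_grads, stashed_master_grads)):
                master_grads[i] = g / self._scale + s

    scaler = ToyLossScaler()
    grads = [2048.0]
    scaler.unscale(grads, grads, scaler.loss_scale(), models_are_masters=True)
    print(grads)  # [2.0]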
@@ -206,37 +214,7 @@ def post_backward_no_master_weights(self, scaler):
                    (stash.all_fp32_params, stash.all_fp32_grad_stash))
 
     for params, stashed_grads in split_types:
-        # This is a lot of python overhead...
-        grads_needing_unscale = []
-        grads_needing_unscale_with_stash = []
-        stashed = []
-        for param, stashed_grad in zip(params, stashed_grads):
-            if param.grad is None and stashed_grad is not None:
-                param.grad = stashed_grad
-            elif param.grad is not None and stashed_grad is None:
-                grads_needing_unscale.append(param.grad)
-            elif param.grad is not None and stashed_grad is not None:
-                grads_needing_unscale_with_stash.append(param.grad)
-                stashed.append(stashed_grad)
-            else:
-                # param.grad is None and stashed_grad is None
-                continue
-
-        if len(grads_needing_unscale) > 0:
-            scaler.unscale(
-                grads_needing_unscale,
-                grads_needing_unscale,
-                scaler.loss_scale(),
-                models_are_masters=True)
-
-        if len(grads_needing_unscale_with_stash) > 0:
-            scaler.unscale_with_stashed(
-                grads_needing_unscale_with_stash,
-                stashed,
-                grads_needing_unscale_with_stash)
-
-        # Clear the stash.
-        for i in range(len(stashed_grads)):
-            stashed_grads[i] = None
+        post_backward_models_are_masters(scaler, params, stashed_grads)
 
 
 def _master_params_to_model_params(self):
@@ -283,6 +261,7 @@ def _process_optimizer(optimizer, properties):
         optimizer._master_params_to_model_params = types.MethodType(
             _master_params_to_model_params, optimizer)
 
+        if not isinstance(optimizer, FusedAdam):
             old_step = optimizer.step
             def new_step(self):
                 retval = old_step()

@@ -313,18 +292,28 @@ def _process_optimizer(optimizer, properties):
                     param.grad = None
         optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)
 
+        if isinstance(optimizer, FusedAdam):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_with_master_weights_fused, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_with_master_weights_fused, optimizer)
+        else:
             optimizer._prepare_amp_backward = types.MethodType(
                 prepare_backward_with_master_weights, optimizer)
             optimizer._post_amp_backward = types.MethodType(
                 post_backward_with_master_weights, optimizer)
     else:
         optimizer._lazy_init_maybe_master_weights = types.MethodType(
             lazy_init_no_master_weights, optimizer)
 
+        if isinstance(optimizer, FusedAdam):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_no_master_weights_fused, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_no_master_weights_fused, optimizer)
+        else:
             optimizer._prepare_amp_backward = types.MethodType(
                 prepare_backward_no_master_weights, optimizer)
             optimizer._post_amp_backward = types.MethodType(
                 post_backward_no_master_weights, optimizer)
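Note (not part of the commit): every branch of _process_optimizer attaches its hooks by binding free functions onto the optimizer instance with types.MethodType. A minimal, self-contained sketch of that pattern; DummyOptimizer and the printed messages are invented for illustration, and only the MethodType mechanics match the code above:

    import types

    class DummyOptimizer:
        # Placeholder for a torch.optim.Optimizer subclass.
        def step(self):
            print("base step")

    def _prepare_amp_backward(self):
        # 'self' is the instance the function gets bound to below.
        print("stashing grads for", type(self).__name__)

    opt = DummyOptimizer()
    # Bind the free function as a method on this particular instance only;
    # this is how amp adds per-optimizer hooks without subclassing.
    opt._prepare_amp_backward = types.MethodType(_prepare_amp_backward, opt)
    opt._prepare_amp_backward()  # prints: stashing grads for DummyOptimizer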
apex/amp/handle.py (+4 -11)

@@ -6,8 +6,6 @@ from . import utils
 from .opt import OptimWrapper
 from .scaler import LossScaler
 from ._amp_state import _amp_state, master_params, maybe_print
-from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
-from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
 
 
 # There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.

@@ -82,11 +80,6 @@ def scale_loss(loss,
     if isinstance(optimizers, torch.optim.Optimizer):
         optimizers = [optimizers]
 
-    # this is what happens when i have to support tools from different sources under the same API...
-    # TODO: Rewrite FusedAdam to use multi-tensor apply and the same loss scaler.
-    if isinstance(optimizers, FP16_Optimizer_for_fused):
-        loss_scale = optimizers.cur_scale
-    else:
-        loss_scaler = _amp_state.loss_scalers[loss_id]
-        loss_scale = loss_scaler.loss_scale()
+    loss_scaler = _amp_state.loss_scalers[loss_id]
+    loss_scale = loss_scaler.loss_scale()

@@ -113,8 +106,8 @@ def scale_loss(loss,
         for optimizer in optimizers:
             optimizer._amp_stash.params_have_scaled_gradients = True
     else:
-        # FusedAdam and FusedSGD will take care of unscaling as part of their step() methods.
-        if not isinstance(optimizers, FP16_Optimizer_for_fused):
+        # FusedAdam and FusedSGD may take care of unscaling as part of their step() methods.
+        # if not isinstance(optimizers, FP16_Optimizer_for_fused):
             loss_scaler.clear_overflow_state()
             for optimizer in optimizers:
                 optimizer._post_amp_backward(loss_scaler)
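Note (not part of the commit): scale_loss above is the context manager users reach through amp.scale_loss, and the code after its yield is where each optimizer's _post_amp_backward runs. A typical call site, sketched with placeholder model, data, and hyperparameters:

    import torch
    from apex import amp

    model = torch.nn.Linear(1024, 10).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    data = torch.randn(32, 1024).cuda()
    target = torch.randint(0, 10, (32,)).cuda()

    loss = torch.nn.functional.cross_entropy(model(data), target)
    # The context manager yields the loss multiplied by the current loss scale;
    # after the block exits it unscales gradients via _post_amp_backward.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    optimizer.zero_grad()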