OpenDAS / apex · Commits

Commit 7ce2e04f, authored Apr 27, 2018 by Michael Carilli
Reorganizing fp16_optimizer
Parent: 31cee8e7

Showing 1 changed file with 35 additions and 28 deletions.

apex/fp16_utils/fp16_optimizer.py  (+35, -28)
@@ -138,7 +138,7 @@ class FP16_Optimizer(object):
     the loss scale is not recommended.
 
     **Multi_GPU training**:  If the wrapped ``init_optimizer`` was created from a model wrapped in
-    Pytorch DataParallel or DistributedDataParallel, :class:`FP16_Optimizer` should still work as
+    Pytorch DistributedDataParallel, :class:`FP16_Optimizer` should still work as
     intended.
     """
@@ -198,29 +198,11 @@ class FP16_Optimizer(object):
         self.overflow = False
         self.first_closure_call_this_step = True
 
-    # Promote optimizer.state, and optimizer.param_groups, to accommodate user code that
-    # directly manipulates "optimizer.param_groups" (for example, to adjust the learning rate).
-    def __getattribute__(self, name):
-        # I could condense the two cases by saying
-        # if name in ['state', 'param_groups']:
-        #     return self.optimizer.__dict__[name],
-        # but this would bypass self.optimizer's custom getters and setters, if it chose to define any.
-        # I could also use properties, as for loss_scale, but I don't know if properties bypass
-        # self.optimizer's custom getters and setters.
-        if name == 'state':
-            return self.optimizer.state
-        elif name == 'param_groups':
-            return self.optimizer.param_groups
-        else:
-            return object.__getattribute__(self, name)
-
-    def __setattr__(self, name, value):
-        if name == 'state':
-            self.optimizer.state = value
-        elif name == 'param_groups':
-            self.optimizer.param_groups = value
-        else:
-            object.__setattr__(self, name, value)
+    def __getstate__(self):
+        raise RuntimeError("FP16_Optimizer should be serialized using state_dict().")
+
+    def __setstate__(self, state):
+        raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().")
 
     def zero_grad(self):
         """
@@ -250,6 +232,10 @@ class FP16_Optimizer(object):
     def _update_scale(self, has_overflow=False):
         self.loss_scaler.update_scale(has_overflow)
 
+    def _master_params_to_model_params(self):
+        for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
+            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
+
     # To consider:  Integrate distributed with this wrapper by registering a hook on each variable
     # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
     def _model_grads_to_master_grads(self):
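
For orientation, the master_params_to_model_params helper called here copies updated fp32 master weights back into the corresponding fp16 model weights after each optimizer step. A rough re-implementation sketch of that idea, not the actual apex helper:

    def copy_master_to_model(model_params, master_params):
        # After the wrapped optimizer steps on the fp32 master params, push the new
        # values back into the fp16 model params that the forward pass actually uses.
        with torch.no_grad():
            for model_param, master_param in zip(model_params, master_params):
                model_param.copy_(master_param)  # copy_ casts fp32 -> fp16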
@@ -286,10 +272,6 @@ class FP16_Optimizer(object):
         else:
             return -1
 
-    def _master_params_to_model_params(self):
-        for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
-            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
-
     def state_dict(self):
         """
         Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
@@ -540,7 +522,8 @@ class FP16_Optimizer(object):
             return None
         else:
             return None
 
+    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
     def _get_loss_scale(self):
         return self.loss_scaler.loss_scale
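
The comment being added documents behavior that already exists below (loss_scale = property(_get_loss_scale, _set_loss_scale)), so the scale can be read or overridden directly on the wrapper. A small usage sketch, assuming optimizer is an FP16_Optimizer instance:

    print("current loss scale:", optimizer.loss_scale)  # reads self.loss_scaler.loss_scale

    # Manually forcing a new scale is possible, though the class docstring cautions
    # against adjusting the loss scale by hand in normal use.
    optimizer.loss_scale = 256.0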
@@ -548,3 +531,27 @@ class FP16_Optimizer(object):
         self.loss_scaler.cur_scale = value
 
     loss_scale = property(_get_loss_scale, _set_loss_scale)
+
+    # Promote optimizer.state, and optimizer.param_groups, to accommodate user code that
+    # directly manipulates "optimizer.param_groups" (for example, to adjust the learning rate).
+    def __getattribute__(self, name):
+        # I could condense the two cases by saying
+        # if name in ['state', 'param_groups']:
+        #     return self.optimizer.__dict__[name],
+        # but this would bypass self.optimizer's custom getters and setters, if it chose to define any.
+        # I could also use properties, as for loss_scale, but I don't know if properties bypass
+        # self.optimizer's custom getters and setters.
+        if name == 'state':
+            return self.optimizer.state
+        elif name == 'param_groups':
+            return self.optimizer.param_groups
+        else:
+            return object.__getattribute__(self, name)
+
+    def __setattr__(self, name, value):
+        if name == 'state':
+            self.optimizer.state = value
+        elif name == 'param_groups':
+            self.optimizer.param_groups = value
+        else:
+            object.__setattr__(self, name, value)