Commit 7ce2e04f authored Apr 27, 2018 by Michael Carilli

Reorganizing fp16_optimizer

parent 31cee8e7
Showing 1 changed file with 35 additions and 28 deletions

apex/fp16_utils/fp16_optimizer.py  (+35, -28)
...
...
@@ -138,7 +138,7 @@ class FP16_Optimizer(object):
     the loss scale is not recommended.

     **Multi_GPU training**:  If the wrapped ``init_optimizer`` was created from a model wrapped in
-    Pytorch DataParallel or DistributedDataParallel, :class:`FP16_Optimizer` should still work as
+    Pytorch DistributedDataParallel, :class:`FP16_Optimizer` should still work as
     intended.
     """
...
...
@@ -198,29 +198,11 @@ class FP16_Optimizer(object):
         self.overflow = False
         self.first_closure_call_this_step = True

-    # Promote optimizer.state, and optimizer.param_groups, to accommodate user code that
-    # directly manipulates "optimizer.param_groups" (for example, to adjust the learning rate).
-    def __getattribute__(self, name):
-        # I could condense the two cases by saying
-        # if name in ['state', 'param_groups']:
-        #     return self.optimizer.__dict__[name],
-        # but this would bypass self.optimizer's custom getters and setters, if it chose to define any.
-        # I could also use properties, as for loss_scale, but I don't know if properties bypass
-        # self.optimizer's custom getters and setters.
-        if name == 'state':
-            return self.optimizer.state
-        elif name == 'param_groups':
-            return self.optimizer.param_groups
-        else:
-            return object.__getattribute__(self, name)
-
+    def __getstate__(self):
+        raise RuntimeError("FP16_Optimizer should be serialized using state_dict().")
+
-    def __setattr__(self, name, value):
-        if name == 'state':
-            self.optimizer.state = value
-        elif name == 'param_groups':
-            self.optimizer.param_groups = value
-        else:
-            object.__setattr__(self, name, value)
+    def __setstate__(self, state):
+        raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().")

     def zero_grad(self):
         """
...
...
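The __getstate__ and __setstate__ guards introduced here make the wrapper refuse plain pickling and point users at state_dict()/load_state_dict() instead. A brief, hedged sketch of the intended checkpointing flow (the file name and surrounding objects are illustrative only):

    import torch

    # Save: serialize through state_dict(), not by pickling the FP16_Optimizer itself.
    torch.save({'optimizer': optimizer.state_dict()}, 'checkpoint.pt')

    # Restore: rebuild the model and FP16_Optimizer as usual, then load the saved state.
    checkpoint = torch.load('checkpoint.pt')
    optimizer.load_state_dict(checkpoint['optimizer'])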
@@ -250,6 +232,10 @@ class FP16_Optimizer(object):
def
_update_scale
(
self
,
has_overflow
=
False
):
self
.
loss_scaler
.
update_scale
(
has_overflow
)
def
_master_params_to_model_params
(
self
):
for
fp16_group
,
fp32_from_fp16_group
in
zip
(
self
.
fp16_groups
,
self
.
fp32_from_fp16_groups
):
master_params_to_model_params
(
fp16_group
,
fp32_from_fp16_group
)
# To consider: Integrate distributed with this wrapper by registering a hook on each variable
# that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
def
_model_grads_to_master_grads
(
self
):
...
...
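The relocated _master_params_to_model_params simply walks the paired fp16 and fp32-from-fp16 groups and defers to the module-level master_params_to_model_params helper. A rough sketch of what that helper presumably does, copying each fp32 master parameter back into its fp16 model counterpart, is below; the actual implementation lives elsewhere in fp16_utils and may differ:

    def master_params_to_model_params(model_params, master_params):
        # Assumed behavior: overwrite each fp16 model parameter with its
        # fp32 master copy; copy_() performs the implicit downcast.
        for model_param, master_param in zip(model_params, master_params):
            model_param.data.copy_(master_param.data)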
@@ -286,10 +272,6 @@ class FP16_Optimizer(object):
         else:
             return -1

-    def _master_params_to_model_params(self):
-        for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
-            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
-
     def state_dict(self):
         """
         Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
...
...
@@ -540,7 +522,8 @@ class FP16_Optimizer(object):
             return None
         else:
             return None

+    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
     def _get_loss_scale(self):
         return self.loss_scaler.loss_scale
...
...
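The comment added in the hunk above, together with the property kept at the top of the next one, exposes the wrapped loss scaler's scale directly on the wrapper. A short usage sketch (the numeric value is purely illustrative):

    # Read the current loss scale through the promoted property...
    current_scale = optimizer.loss_scale

    # ...or assign to it, which forwards to the wrapped loss scaler.
    optimizer.loss_scale = 256.0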
@@ -548,3 +531,27 @@ class FP16_Optimizer(object):
         self.loss_scaler.cur_scale = value

     loss_scale = property(_get_loss_scale, _set_loss_scale)

+    # Promote optimizer.state, and optimizer.param_groups, to accommodate user code that
+    # directly manipulates "optimizer.param_groups" (for example, to adjust the learning rate).
+    def __getattribute__(self, name):
+        # I could condense the two cases by saying
+        # if name in ['state', 'param_groups']:
+        #     return self.optimizer.__dict__[name],
+        # but this would bypass self.optimizer's custom getters and setters, if it chose to define any.
+        # I could also use properties, as for loss_scale, but I don't know if properties bypass
+        # self.optimizer's custom getters and setters.
+        if name == 'state':
+            return self.optimizer.state
+        elif name == 'param_groups':
+            return self.optimizer.param_groups
+        else:
+            return object.__getattribute__(self, name)
+
+    def __setattr__(self, name, value):
+        if name == 'state':
+            self.optimizer.state = value
+        elif name == 'param_groups':
+            self.optimizer.param_groups = value
+        else:
+            object.__setattr__(self, name, value)
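With __getattribute__ and __setattr__ moved to the end of the class, reads and writes of state and param_groups on the wrapper still forward to the inner optimizer, precisely so that user code like the learning-rate adjustment mentioned in the comment keeps working unchanged:

    # 'param_groups' is forwarded to self.optimizer, so an LR schedule written
    # against a plain torch optimizer works on the FP16_Optimizer wrapper too.
    for param_group in optimizer.param_groups:
        param_group['lr'] *= 0.1  # illustrative decay factor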