OpenDAS / apex

Commit 57ad1840 (Unverified)
Authored Feb 05, 2019 by mcarilli; committed by GitHub on Feb 05, 2019
Merge pull request #123 from donglixp/patch-1
apex.optimizers.FP16_Optimizer: add state_dict() and load_state_dict()
Parents: 45537d34, 475fca23
Showing 1 changed file with 66 additions and 0 deletions:

apex/optimizers/fp16_optimizer.py (+66, -0)
@@ -214,3 +214,69 @@ class FP16_Optimizer(object):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['cur_scale'] = self.cur_scale
        state_dict['cur_iter'] = self.cur_iter
        if state_dict['dynamic_loss_scale']:
            state_dict['last_overflow_iter'] = self.last_overflow_iter
            state_dict['scale_factor'] = self.scale_factor
            state_dict['scale_window'] = self.scale_window
        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
        state_dict['fp32_groups_flat'] = self.fp32_groups_flat
        return state_dict

    def load_state_dict(self, state_dict):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.

        Example::

            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        if state_dict['dynamic_loss_scale']:
            self.last_overflow_iter = state_dict['last_overflow_iter']
            self.scale_factor = state_dict['scale_factor']
            self.scale_window = state_dict['scale_window']
        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
        # out of date.  There are two options.
        # 1:  Refresh the master params from the model's fp16 params.
        #     This requires less storage but incurs precision loss.
        # 2:  Save and restore the fp32 master copies separately.
        #     We choose option 2.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
        # of their associated parameters, because it's possible those buffers might not exist yet in
        # the current optimizer instance.  In our case, as long as the current FP16_Optimizer has been
        # constructed in the same way as the one whose state_dict we are loading, the same master params
        # are guaranteed to exist, so we can just copy_() from the saved master params.
        for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
            current.data.copy_(saved.data)
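
The comment block above motivates option 2: the fp32 master copies are saved and restored directly rather than rebuilt from the fp16 model weights. A minimal standalone sketch (not part of this commit; the tensor values are purely illustrative) of the precision loss that option 1 would incur:

import torch

# Round-tripping an fp32 master value through fp16 discards low-order bits,
# which is the drift option 1 ("refresh master params from fp16 params")
# would bake into a restored checkpoint.
master = torch.tensor([0.1234567, 1e-5, 3.1415926], dtype=torch.float32)
rebuilt = master.half().float()   # what rebuilding from fp16 params recovers
print(master - rebuilt)           # nonzero differences: the master copy has drifted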
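Taken together, the two docstring examples suggest the following save-and-resume flow. This is a sketch assembled from those examples rather than code from the commit; D_in, D_out, the SGD learning rate, and the static_loss_scale value are illustrative, and the import path for FP16_Optimizer is assumed to match the changed file.

import torch
from apex.optimizers import FP16_Optimizer  # assumed import path, matching the changed file

# Build a half-precision model and wrap its optimizer (values are illustrative).
D_in, D_out = 64, 10
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)

# ... train for a while, then checkpoint both model and optimizer ...
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()  # includes loss-scale state and fp32 master copies
torch.save(checkpoint, "saved.pth")

# Later, reconstruct the model and FP16_Optimizer the same way, then restore both.
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])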