OpenDAS / apex / Commits / 62ce27d2

Commit 62ce27d2 (unverified), authored Mar 08, 2019 by ngimel, committed by GitHub on Mar 08, 2019

Merge pull request #182 from FDecaYed/deyuf/update_norm

Remove LoadLibrary half norm hack

Parents: e3053736, 2f0bf594
Changes: 1 changed file, with 272 additions and 291 deletions.

apex/optimizers/fp16_optimizer.py (+272, -291), view file @ 62ce27d2
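For context before the diff: the hack being removed computed the 2-norm of an fp16 CUDA gradient tensor by calling THCudaHalfTensor_normall through ctypes, while the new code uses torch.norm with an fp32 accumulation dtype and falls back to an explicit cast on older PyTorch. The following is a minimal standalone sketch of that replacement, not part of the commit; the variable name flat_grads is hypothetical, and it assumes a CUDA build of PyTorch, which the wrapper itself already requires.

import torch

# any flattened fp16 CUDA tensor stands in for the flattened gradients here
flat_grads = torch.randn(1024, device='cuda', dtype=torch.float16)

try:
    # newer torch: accumulate the 2-norm in fp32 directly
    norm = float(torch.norm(flat_grads, 2.0, dtype=torch.float32))
except TypeError:
    # torch <= 1.0.1: torch.norm has no dtype argument, so cast to fp32 first
    norm = float(torch.norm(flat_grads.float(), 2.0))

# the same check the wrapper uses to flag overflow (inf or NaN norm)
overflow = norm == float('inf') or norm == -float('inf') or norm != norm
print(norm, overflow)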
 import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-import ctypes
-
-stashed_err = None
-try:
-    lib = ctypes.cdll.LoadLibrary(None)
-    lib.THCudaHalfTensor_normall.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
-    lib.THCudaHalfTensor_normall.restype = ctypes.c_float
-
-    def fused_norm(input):
-        if input.type() == 'torch.cuda.HalfTensor':
-            # 16384 is half 2 if you stare at it long enough
-            return lib.THCudaHalfTensor_normall(torch.cuda._state_cdata,
-                                                input._cdata, 16384)
-        else:
-            return input.norm()
-except TypeError as err:
-    stashed_err = err
-    def fused_norm(input):
-        raise RuntimeError("Failed to create fused_norm. This may happen on Windows "
-                           "because of lib = ctypes.cdll.LoadLibrary(None): you can't "
-                           "LoadLibrary with None. Original exception message was ",
-                           stashed_err)
 
 class FP16_Optimizer(object):
     """
     :class:`FP16_Optimizer` A cutdown version of apex.fp16_utils.FP16_Optimizer.
     Designed only to wrap apex.optimizers.FusedAdam.
     Refer to apex.fp16_utils documents for more information.
 
     Example::
         model = torch.nn.Linear(D_in, D_out).cuda().half()
         optimizer = apex.optimizers.FusedAdam(model.parameters())
         # Name the FP16_Optimizer instance to replace the existing optimizer
         # (recommended but not required):
         optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
         ...
         # loss.backward() becomes:
         optimizer.backward(loss)
         ...
 
     Example with dynamic loss scaling::
         ...
         optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
                                    # optional arg to control dynamic loss scaling behavior
                                    # dynamic_loss_args={'scale_window' : 500})
                                    # Usually, dynamic_loss_args is not necessary.
     """
 
     def __init__(self,
                  init_optimizer,
                  static_loss_scale=1.0,
                  dynamic_loss_scale=False,
                  dynamic_loss_args=None,
                  verbose=True):
 
         # The fused optimizer does all the work. We need this layer for two reason:
         # 1. maintain same user API from apex.fp16_utils
         # 2. keep common stuff here in case we need to add new fused optimizer later
 
         # differences from apex.fp16_utils:
         # - assume all model params in fp16
         # - assume all params requires grad
         # - flat by groups, not keeping state. TODO: remove state explicitly?
         # - master gard and unflat master weight never exist. TODO: a way to save out unflat master?
         if not torch.cuda.is_available:
             raise SystemError("Cannot use fp16 without CUDA.")
         self.optimizer = init_optimizer
 
         # param flattened by groups
         self.fp16_groups = []
         self.fp16_groups_flat = []
         self.fp32_groups_flat = []
 
         # loop to deal with groups
         for i, param_group in enumerate(self.optimizer.param_groups):
             # push this group to list before modify
             self.fp16_groups.append(param_group['params'])
             # init fp16 weight buffer, flattened
             self.fp16_groups_flat.append(_flatten_dense_tensors([p.clone().detach() for p in self.fp16_groups[i]]))
             # set model fp16 weight to slices of flattened buffer
             updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
             for p, q in zip(self.fp16_groups[i], updated_params):
                 p.data = q.data
             # init master weight, flattened
             self.fp32_groups_flat.append(self.fp16_groups_flat[i].clone().float().detach())
             # modify optimizer of have flat master weight
             self.fp32_groups_flat[i].requires_grad = True  # keep this in case internal optimizer uses it
             param_group['params'] = [self.fp32_groups_flat[i]]
 
         # we may have a way of fusing dynamic scale. Do not support for now
         if dynamic_loss_scale:
             if dynamic_loss_args is not None:
                 raise SystemError("Do not support dynamic loss scale args for now.")
             self.dynamic_loss_scale = True
             self.cur_scale = 2**16
             self.cur_iter = 0
             self.last_overflow_iter = -1
             self.scale_factor = 2
             self.scale_window = 1000
         else:
             self.dynamic_loss_scale = False
             self.cur_iter = 0
             self.cur_scale = static_loss_scale
 
     def zero_grad(self, set_grads_to_None=True):
         """
         Zero FP16 parameter grads.
         """
         # FP32 grad should never exist.
         # For speed, set model fp16 grad to None by default
         for group in self.fp16_groups:
             for p in group:
                 if set_grads_to_None:
                     p.grad = None
                 else:
                     if p.grad is not None:
                         p.grad.detach_()
                         p.grad.zero_()
 
     def _compute_grad_norm(self, fp16_grads_flat, norm_type=2):
         """
         Compute fp16 grad norm for later clipping(fused with update).
         Internal accumulated in fp32.
         Also fused in NaN check. Possibly other reduction needed for grad.
 
         Args:
             fp16_grads_flat (tensor): fp16 grad flattened
             norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
                 infinity norm.
 
         Returns:
             Total norm of the current fp16 gradients (viewed as a single vector).
             Returns -1 if the most recently computed fp16 gradients overflowed
         """
-        # TODO: currently using pre-1.0 api, and not most efficient with copy to cpu and sync
-        # only support 2-norm now
-        norm = float(fused_norm(fp16_grads_flat))
+        # TODO: Not most efficient with copy to cpu and sync
+        # only support 2-norm now
+        # for torch version <= 1.0.1, torch.norm with dtype will fail and fall back to cast
+        try:
+            norm = float(torch.norm(fp16_grads_flat, 2.0, dtype=torch.float32))
+        except TypeError as err:
+            norm = float(torch.norm(fp16_grads_flat.float(), 2.0))
         if norm == float('inf') or norm == -float('inf') or norm != norm:
             return -1
         else:
             return norm
 
     def step(self, closure=None):
         """
         Not supporting closure.
         """
         # First compute norm for all group so we know if there is overflow
         grads_groups_flat = []
         norm_groups = []
         skip = False
         for i, group in enumerate(self.fp16_groups):
             grads_groups_flat.append(_flatten_dense_tensors([p.grad for p in group]))
             norm_groups.append(self._compute_grad_norm(grads_groups_flat[i]))
             if norm_groups[i] == -1:  #TODO: early break
                 skip = True
 
         if skip:
             self._update_scale(skip)
             return
 
         # norm is in fact norm*cur_scale
         self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                             output_params=[[p] for p in self.fp16_groups_flat],
                             scale=self.cur_scale,
                             grad_norms=norm_groups)
 
         # TODO: we probably don't need this? just to be safe
         for i in range(len(norm_groups)):
             updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
             for p, q in zip(self.fp16_groups[i], updated_params):
                 p.data = q.data
 
         self._update_scale(False)
         return
 
     def backward(self, loss):
         """
         :attr:`backward` performs the following steps:
 
         1. fp32_loss = loss.float()
         2. scaled_loss = fp32_loss*loss_scale
         3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
         """
         scaled_loss = (loss.float()) * self.cur_scale
         scaled_loss.backward()
 
     def _update_scale(self, skip):
         if self.dynamic_loss_scale:
             if skip:
                 print("\nGrad overflow on iteration", self.cur_iter)
                 print("Using dynamic loss scale of", self.cur_scale)
                 self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
                 self.last_overflow_iter = self.cur_iter
             else:
                 if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                     self.cur_scale *= self.scale_factor
         else:
             if skip:
                 print("\nGrad overflow on iteration", self.cur_iter)
                 print("Using static loss scale of", self.cur_scale)
         self.cur_iter += 1
         return
 
     # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
     def _get_state(self):
         return self.optimizer.state
 
     def _set_state(self, value):
         self.optimizer.state = value
 
     state = property(_get_state, _set_state)
 
     # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
     # (for example, to adjust the learning rate)
     def _get_param_groups(self):
         return self.optimizer.param_groups
 
     def _set_param_groups(self, value):
         self.optimizer.param_groups = value
 
     param_groups = property(_get_param_groups, _set_param_groups)
 
     def state_dict(self):
         """
         Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
         This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
         of the contained Pytorch optimizer.
 
         Example::
             checkpoint = {}
             checkpoint['model'] = model.state_dict()
             checkpoint['optimizer'] = optimizer.state_dict()
             torch.save(checkpoint, "saved.pth")
         """
         state_dict = {}
         state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
         state_dict['cur_scale'] = self.cur_scale
         state_dict['cur_iter'] = self.cur_iter
         if state_dict['dynamic_loss_scale']:
             state_dict['last_overflow_iter'] = self.last_overflow_iter
             state_dict['scale_factor'] = self.scale_factor
             state_dict['scale_window'] = self.scale_window
         state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
         state_dict['fp32_groups_flat'] = self.fp32_groups_flat
         return state_dict
 
     def load_state_dict(self, state_dict):
         """
         Loads a state_dict created by an earlier call to state_dict().
         If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
         whose parameters in turn came from ``model``, it is expected that the user
         will call ``model.load_state_dict()`` before
         ``fp16_optimizer_instance.load_state_dict()`` is called.
 
         Example::
             model = torch.nn.Linear(D_in, D_out).cuda().half()
             optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
             optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
             ...
             checkpoint = torch.load("saved.pth")
             model.load_state_dict(checkpoint['model'])
             optimizer.load_state_dict(checkpoint['optimizer'])
         """
         # I think it should actually be ok to reload the optimizer before the model.
         self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
         self.cur_scale = state_dict['cur_scale']
         self.cur_iter = state_dict['cur_iter']
         if state_dict['dynamic_loss_scale']:
             self.last_overflow_iter = state_dict['last_overflow_iter']
             self.scale_factor = state_dict['scale_factor']
             self.scale_window = state_dict['scale_window']
         self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
         # At this point, the optimizer's references to the model's fp32 parameters are up to date.
         # The optimizer's hyperparameters and internal buffers are also up to date.
         # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
         # out of date. There are two options.
         # 1: Refresh the master params from the model's fp16 params.
         #    This requires less storage but incurs precision loss.
         # 2: Save and restore the fp32 master copies separately.
         #    We choose option 2.
         #
         # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
         # of their associated parameters, because it's possible those buffers might not exist yet in
         # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
         # constructed in the same way as the one whose state_dict we are loading, the same master params
         # are guaranteed to exist, so we can just copy_() from the saved master params.
         for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
             current.data.copy_(saved.data)
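For completeness, a minimal training-loop sketch using the wrapper's public API as documented above (zero_grad, backward, step). It is an illustration, not part of the commit; the model, the data, and the availability of apex built with its CUDA extensions are assumptions, and D_in, D_out, inputs, and targets are hypothetical placeholders.

import torch
import torch.nn.functional as F
from apex.optimizers import FusedAdam, FP16_Optimizer

D_in, D_out = 64, 32
model = torch.nn.Linear(D_in, D_out).cuda().half()
# wrap FusedAdam, the only optimizer this class is designed for
optimizer = FP16_Optimizer(FusedAdam(model.parameters()), dynamic_loss_scale=True)

inputs = torch.randn(16, D_in, device='cuda', dtype=torch.float16)
targets = torch.randn(16, D_out, device='cuda', dtype=torch.float16)

for _ in range(10):
    optimizer.zero_grad()                   # fp16 grads are set to None by default
    loss = F.mse_loss(model(inputs).float(), targets.float())
    optimizer.backward(loss)                # scales the loss, then backprops
    optimizer.step()                        # checks overflow, updates flat fp32 master weights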