OpenDAS / apex · Commit 62ce27d2

Unverified commit 62ce27d2, authored Mar 08, 2019 by ngimel, committed by GitHub on Mar 08, 2019.

Merge pull request #182 from FDecaYed/deyuf/update_norm

Remove LoadLibrary half norm hack

Parents: e3053736, 2f0bf594

Showing 1 changed file with 272 additions and 291 deletions.

apex/optimizers/fp16_optimizer.py (+272, -291), view file @ 62ce27d2
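The substantive change is in `_compute_grad_norm`: the 2-norm of the flattened fp16 gradients is now computed with `torch.norm` accumulating in fp32, instead of calling `THCudaHalfTensor_normall` through a module-level ctypes hack. A minimal sketch of the new code path (illustrative variable names, assuming a CUDA device is available):

import torch

# Flattened fp16 gradients, as _compute_grad_norm receives them.
fp16_grads_flat = torch.randn(1024, device='cuda', dtype=torch.float16)

# New path: let torch.norm accumulate in fp32.
norm = float(torch.norm(fp16_grads_flat, 2.0, dtype=torch.float32))

# Fallback kept for torch <= 1.0.1, where the dtype argument raises TypeError:
norm_fallback = float(torch.norm(fp16_grads_flat.float(), 2.0))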
Removed: the module-level ctypes workaround that called THCudaHalfTensor_normall directly to take the 2-norm of fp16 tensors, together with the fused_norm call in _compute_grad_norm that depended on it.

import ctypes

stashed_err = None
try:
    lib = ctypes.cdll.LoadLibrary(None)
    lib.THCudaHalfTensor_normall.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
    lib.THCudaHalfTensor_normall.restype = ctypes.c_float
    def fused_norm(input):
        if input.type() == 'torch.cuda.HalfTensor':
            # 16384 is half 2 if you stare at it long enough
            return lib.THCudaHalfTensor_normall(torch.cuda._state_cdata,
                                                input._cdata, 16384)
        else:
            return input.norm()
except TypeError as err:
    stashed_err = err
    def fused_norm(input):
        raise RuntimeError("Failed to create fused_norm. This may happen on Windows "
                           "because of lib = ctypes.cdll.LoadLibrary(None): you can't "
                           "LoadLibrary with None. Original exception message was ",
                           stashed_err)

Also removed, inside _compute_grad_norm:

    # TODO: currently using pre-1.0 api, and not most efficient with copy to cpu and sync
    # only support 2-norm now
    norm = float(fused_norm(fp16_grads_flat))
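About the magic constant in the removed code: 16384 is 0x4000, the IEEE 754 half-precision bit pattern of 2.0, i.e. the order of the p-norm handed to the C function as raw half bits. A quick check (editor's sketch using numpy, not part of the file):

import numpy as np

# Reinterpret the bits 0x4000 (= 16384) as a float16: the value is exactly 2.0.
assert np.array([16384], dtype=np.uint16).view(np.float16)[0] == np.float16(2.0)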
apex/optimizers/fp16_optimizer.py at 62ce27d2:

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

class FP16_Optimizer(object):
    """
    :class:`FP16_Optimizer` A cutdown version of apex.fp16_utils.FP16_Optimizer.
    Designed only to wrap apex.optimizers.FusedAdam.
    Refer to apex.fp16_utils documents for more information.

    Example::

        model = torch.nn.Linear(D_in, D_out).cuda().half()
        optimizer = apex.optimizers.FusedAdam(model.parameters())
        # Name the FP16_Optimizer instance to replace the existing optimizer
        # (recommended but not required):
        optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
        ...
        # loss.backward() becomes:
        optimizer.backward(loss)
        ...

    Example with dynamic loss scaling::

        ...
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
                                   # optional arg to control dynamic loss scaling behavior
                                   # dynamic_loss_args={'scale_window' : 500})
                                   # Usually, dynamic_loss_args is not necessary.
    """
    def __init__(self,
                 init_optimizer,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True):

        # The fused optimizer does all the work. We need this layer for two reasons:
        # 1. maintain the same user API as apex.fp16_utils
        # 2. keep common stuff here in case we need to add new fused optimizers later

        # differences from apex.fp16_utils:
        # - assume all model params are fp16
        # - assume all params require grad
        # - flat by groups, not keeping state. TODO: remove state explicitly?
        # - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
        if not torch.cuda.is_available():
            raise SystemError("Cannot use fp16 without CUDA.")

        self.optimizer = init_optimizer

        # params flattened by group
        self.fp16_groups = []
        self.fp16_groups_flat = []
        self.fp32_groups_flat = []

        # loop to deal with groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # push this group to the list before modifying it
            self.fp16_groups.append(param_group['params'])
            # init fp16 weight buffer, flattened
            self.fp16_groups_flat.append(_flatten_dense_tensors([p.clone().detach() for p in self.fp16_groups[i]]))
            # set model fp16 weights to slices of the flattened buffer
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data
            # init master weights, flattened
            self.fp32_groups_flat.append(self.fp16_groups_flat[i].clone().float().detach())
            # modify optimizer to have flat master weights
            self.fp32_groups_flat[i].requires_grad = True  # keep this in case the internal optimizer uses it
            param_group['params'] = [self.fp32_groups_flat[i]]

        # we may have a way of fusing dynamic scale. Do not support for now
        if dynamic_loss_scale:
            if dynamic_loss_args is not None:
                raise SystemError("Do not support dynamic loss scale args for now.")
            self.dynamic_loss_scale = True
            self.cur_scale = 2**16
            self.cur_iter = 0
            self.last_overflow_iter = -1
            self.scale_factor = 2
            self.scale_window = 1000
        else:
            self.dynamic_loss_scale = False
            self.cur_iter = 0
            self.cur_scale = static_loss_scale
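As an aside (an illustrative sketch, not part of the file): the constructor relies on the flatten/unflatten round trip from torch._utils. _flatten_dense_tensors concatenates the parameters into one contiguous 1-D buffer, and _unflatten_dense_tensors hands back views with the original shapes, which are assigned to p.data so the model weights share storage with the flat buffer.

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

params = [torch.randn(2, 3), torch.randn(5)]
flat = _flatten_dense_tensors(params)           # one contiguous 1-D tensor of 11 elements
views = _unflatten_dense_tensors(flat, params)  # views with the original shapes (2, 3) and (5,)

for p, q in zip(params, views):
    p.data = q.data                              # p now aliases a slice of `flat`

flat.zero_()                                     # zeroing the buffer zeroes every parameter
assert all(p.abs().sum() == 0 for p in params)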
    def zero_grad(self, set_grads_to_None=True):
        """
        Zero FP16 parameter grads.
        """
        # FP32 grad should never exist.
        # For speed, set model fp16 grad to None by default.
        for group in self.fp16_groups:
            for p in group:
                if set_grads_to_None:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()
    def _compute_grad_norm(self, fp16_grads_flat, norm_type=2):
        """
        Compute the fp16 grad norm for later clipping (fused with the update).
        Internally accumulated in fp32.
        The NaN check is also fused. Possibly other reductions needed for grad.

        Args:
            fp16_grads_flat (tensor): fp16 grad flattened
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.

        Returns:
            Total norm of the current fp16 gradients (viewed as a single vector).
            Returns -1 if the most recently computed fp16 gradients overflowed.
        """
        # TODO: not most efficient, with copy to cpu and sync
        # only support 2-norm now
        # for torch version <= 1.0.1, torch.norm with dtype will fail and fall back to cast
        try:
            norm = float(torch.norm(fp16_grads_flat, 2.0, dtype=torch.float32))
        except TypeError as err:
            norm = float(torch.norm(fp16_grads_flat.float(), 2.0))
        if norm == float('inf') or norm == -float('inf') or norm != norm:
            return -1
        else:
            return norm
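As an aside (an illustrative sketch, not part of the file): gradients that overflowed during the scaled backward pass show up here as inf or NaN, and `norm != norm` is the usual NaN self-equality test, so `_compute_grad_norm` returns the -1 sentinel that `step` treats as "skip this iteration".

import torch

overflowed = torch.full((4,), float('inf'), dtype=torch.float16)
norm = float(torch.norm(overflowed.float(), 2.0))

assert norm == float('inf')            # caught by the == float('inf') branch
assert float('nan') != float('nan')    # NaN fails self-equality, caught by norm != norm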
    def step(self, closure=None):
        """
        Not supporting closure.
        """
        # First compute the norm for all groups so we know if there is an overflow.
        grads_groups_flat = []
        norm_groups = []
        skip = False
        for i, group in enumerate(self.fp16_groups):
            grads_groups_flat.append(_flatten_dense_tensors([p.grad for p in group]))
            norm_groups.append(self._compute_grad_norm(grads_groups_flat[i]))
            if norm_groups[i] == -1:  # TODO: early break
                skip = True

        if skip:
            self._update_scale(skip)
            return

        # norm is in fact norm*cur_scale
        self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                            output_params=[[p] for p in self.fp16_groups_flat],
                            scale=self.cur_scale,
                            grad_norms=norm_groups)

        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

        self._update_scale(False)
        return
    def backward(self, loss):
        """
        :attr:`backward` performs the following steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
        """
        scaled_loss = (loss.float()) * self.cur_scale
        scaled_loss.backward()
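As an aside, a minimal end-to-end sketch under the docstring's assumptions (apex built with its CUDA extensions, a GPU available); the toy shapes and the mse loss are illustrative, not from the file:

import torch
from apex.optimizers import FusedAdam
from apex.optimizers.fp16_optimizer import FP16_Optimizer

model = torch.nn.Linear(16, 4).cuda().half()
optimizer = FP16_Optimizer(FusedAdam(model.parameters()), static_loss_scale=128.0)

inputs = torch.randn(8, 16).cuda().half()
targets = torch.randn(8, 4).cuda().half()

for _ in range(10):
    optimizer.zero_grad()                                        # fp16 grads reset to None by default
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    optimizer.backward(loss)                                     # scales the loss, then calls backward()
    optimizer.step()                                             # skips the update if the grads overflowed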
    def _update_scale(self, skip):
        if self.dynamic_loss_scale:
            if skip:
                print("\nGrad overflow on iteration", self.cur_iter)
                print("Using dynamic loss scale of", self.cur_scale)
                self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
                self.last_overflow_iter = self.cur_iter
            else:
                if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                    self.cur_scale *= self.scale_factor
        else:
            if skip:
                print("\nGrad overflow on iteration", self.cur_iter)
                print("Using static loss scale of", self.cur_scale)
        self.cur_iter += 1
        return
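As an aside, a toy sketch (not part of the file) of the schedule this implements with the constructor defaults: every overflow halves the scale (floored at 1) and restarts the window, and every scale_window clean iterations the scale doubles again.

scale, factor, window, last_overflow = 2.0 ** 16, 2, 1000, -1
for it in range(2000):
    overflow = (it == 500)                 # pretend iteration 500 overflows
    if overflow:
        scale = max(scale / factor, 1)     # 65536 -> 32768
        last_overflow = it
    elif (it - last_overflow) % window == 0:
        scale *= factor                    # recovers after 1000 clean iterations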
    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)
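As an aside (illustrative, not part of the file): because param_groups proxies to the wrapped optimizer, external code such as a learning-rate schedule can keep treating the wrapper like an ordinary optimizer.

for group in optimizer.param_groups:   # forwards to the wrapped FusedAdam's param_groups
    group['lr'] *= 0.1                 # decay the learning rate in place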
    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['cur_scale'] = self.cur_scale
        state_dict['cur_iter'] = self.cur_iter
        if state_dict['dynamic_loss_scale']:
            state_dict['last_overflow_iter'] = self.last_overflow_iter
            state_dict['scale_factor'] = self.scale_factor
            state_dict['scale_window'] = self.scale_window
        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
        state_dict['fp32_groups_flat'] = self.fp32_groups_flat
        return state_dict

    def load_state_dict(self, state_dict):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.

        Example::

            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        if state_dict['dynamic_loss_scale']:
            self.last_overflow_iter = state_dict['last_overflow_iter']
            self.scale_factor = state_dict['scale_factor']
            self.scale_window = state_dict['scale_window']
        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
        # out of date. There are two options.
        # 1: Refresh the master params from the model's fp16 params.
        #    This requires less storage but incurs precision loss.
        # 2: Save and restore the fp32 master copies separately.
        # We choose option 2.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
        # of their associated parameters, because it's possible those buffers might not exist yet in
        # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
        # constructed in the same way as the one whose state_dict we are loading, the same master params
        # are guaranteed to exist, so we can just copy_() from the saved master params.
        for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
            current.data.copy_(saved.data)