OpenDAS / apex · Commit c19ee275
Authored Dec 18, 2019 by Kexin Yu; committed by mcarilli, Dec 18, 2019

updated apex.contrib.optimizers.FP16_Optimizer and FusedSGD (#657)

parent 4ad9b3bd

Showing 2 changed files with 255 additions and 78 deletions:
    apex/contrib/optimizers/fp16_optimizer.py    +44  -78
    apex/contrib/optimizers/fused_sgd.py        +211   -0
apex/contrib/optimizers/fp16_optimizer.py   +44 -78   (view file @ c19ee275)

 import torch
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from apex.multi_tensor_apply import multi_tensor_applier

 class FP16_Optimizer(object):
     """
     :class:`FP16_Optimizer` A cutdown version of apex.fp16_utils.FP16_Optimizer.
-    Designed only to wrap apex.optimizers.FusedAdam.
+    Designed only to wrap apex.contrib.optimizers.FusedAdam, FusedSGD.
     Refer to apex.fp16_utils documents for more information.
     Example::
         model = torch.nn.Linear(D_in, D_out).cuda().half()
-        optimizer = apex.optimizers.FusedAdam(model.parameters())
+        optimizer = apex.contrib.optimizers.FusedSGD(model.parameters())
         # Name the FP16_Optimizer instance to replace the existing optimizer
         # (recommended but not required):
         optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
         ...
         # loss.backward() becomes:
         optimizer.backward(loss)
         ...
     Example with dynamic loss scaling::
         ...
         optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
         # optional arg to control dynamic loss scaling behavior
         ...
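The updated docstring above now points at the contrib optimizers. Filled out into a runnable form, the wrapping it describes looks roughly like the sketch below; this is not part of the commit, and it assumes a CUDA device, an apex build with the contrib/amp_C extensions, and illustrative sizes D_in/D_out.

import torch
from apex.contrib.optimizers.fused_sgd import FusedSGD
from apex.contrib.optimizers.fp16_optimizer import FP16_Optimizer

D_in, D_out = 64, 32                                   # illustrative sizes, not from the commit
model = torch.nn.Linear(D_in, D_out).cuda().half()     # fp16 model params, as in the docstring

# FusedSGD does the fused update; FP16_Optimizer adds loss scaling and fp32 master weights.
optimizer = FusedSGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)

x = torch.randn(8, D_in, device='cuda', dtype=torch.half)
loss = model(x).float().sum()                          # stand-in loss

optimizer.zero_grad()
optimizer.backward(loss)                               # replaces loss.backward(); applies the loss scale
optimizer.step()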
@@ -35,41 +29,36 @@ class FP16_Optimizer(object):
                  dynamic_loss_args=None,
                  verbose=True):

-        print("\nfp16_optimizer is designed to only work with apex.optimizers, and will be removed in future")
+        print("\nThis fp16_optimizer is designed to only work with apex.contrib.optimizers.*")
         print("To update, use updated optimizers with AMP.")
         # The fused optimizer does all the work. We need this layer for two reason:
         # 1. maintain same user API from apex.fp16_utils
         # 2. keep common stuff here in case we need to add new fused optimizer later

-        # differences from apex.fp16_utils:
-        # - assume all model params in fp16
-        # - assume all params requires grad
-        # - flat by groups, not keeping state. TODO: remove state explicitly?
-        # - master gard and unflat master weight never exist. TODO: a way to save out unflat master?
         if not torch.cuda.is_available:
             raise SystemError("Cannot use fp16 without CUDA.")
         self.optimizer = init_optimizer

-        # param flattened by groups
-        self.fp16_groups = []
-        self.fp16_groups_flat = []
-        self.fp32_groups_flat = []
+        self.fp16_groups = []  # model params
+        self.fp32_groups = []  # master weights

-        # loop to deal with groups
-        for i, param_group in enumerate(self.optimizer.param_groups):
-            # push this group to list before modify
-            self.fp16_groups.append(param_group['params'])
-            # init fp16 weight buffer, flattened
-            self.fp16_groups_flat.append(_flatten_dense_tensors([p.clone().detach() for p in self.fp16_groups[i]]))
-            # set model fp16 weight to slices of flattened buffer
-            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
-            for p, q in zip(self.fp16_groups[i], updated_params):
-                p.data = q.data
-            # init master weight, flattened
-            self.fp32_groups_flat.append(self.fp16_groups_flat[i].clone().float().detach())
-            # modify optimizer of have flat master weight
-            self.fp32_groups_flat[i].requires_grad = True  # keep this in case internal optimizer uses it
-            param_group['params'] = [self.fp32_groups_flat[i]]
+        # iterate over param_groups
+        for param_group in self.optimizer.param_groups:
+            fp16_group = []
+            fp32_group = []
+            for p in param_group['params']:
+                fp16_group.append(p)
+                fp32_group.append(p.clone().float().detach())
+            self.fp16_groups.append(fp16_group)
+            self.fp32_groups.append(fp32_group)
+            param_group['params'] = fp32_group
+
+        if multi_tensor_applier.available:
+            import amp_C
+            self.overflow_buf = torch.cuda.IntTensor([0])
+            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
+        else:
+            raise RuntimeError('FP16_Optimizer requires cuda extensions')

         # we may have a way of fusing dynamic scale. Do not support for now
         if dynamic_loss_scale:
             ...
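The __init__ rewrite above drops the flattened fp16/fp32 buffers and instead keeps one detached fp32 master copy per parameter, grouped the same way as the wrapped optimizer's param_groups. A small standalone sketch of just that bookkeeping (the helper name and variables are illustrative, not from the commit):

import torch

def build_master_copies(param_groups):
    # Mirrors the new __init__ loop: per-parameter fp32 master weights, no flattening.
    fp16_groups, fp32_groups = [], []
    for param_group in param_groups:
        fp16_group = list(param_group['params'])
        fp32_group = [p.clone().float().detach() for p in fp16_group]
        fp16_groups.append(fp16_group)
        fp32_groups.append(fp32_group)
        # the wrapped optimizer now steps on the fp32 copies, not the fp16 model params
        param_group['params'] = fp32_group
    return fp16_groups, fp32_groups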
@@ -102,70 +91,47 @@ class FP16_Optimizer(object):
                         p.grad.detach_()
                         p.grad.zero_()

-    def _compute_grad_norm(self, fp16_grads_flat, norm_type=2):
-        """
-        Compute fp16 grad norm for later clipping(fused with update).
-        Internal accumulated in fp32.
-        Also fused in NaN check. Possibly other reduction needed for grad.
-
-        Args:
-            fp16_grads_flat (tensor): fp16 grad flattened
-            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
-                infinity norm.
-
-        Returns:
-            Total norm of the current fp16 gradients (viewed as a single vector).
-            Returns -1 if the most recently computed fp16 gradients overflowed
-        """
-        # TODO: Not most efficient with copy to cpu and sync
-        # only support 2-norm now
-        # for torch version <= 1.0.1, torch.norm with dtype will fail and fall back to cast
-        try:
-            norm = float(torch.norm(fp16_grads_flat, 2.0, dtype=torch.float32))
-        except TypeError as err:
-            norm = float(torch.norm(fp16_grads_flat.float(), 2.0))
-        if norm == float('inf') or norm == -float('inf') or norm != norm:
-            return -1
-        else:
-            return norm
-
     def step(self, closure=None):
         """
         Not supporting closure.
         """
-        # First compute norm for all group so we know if there is overflow
-        grads_groups_flat = []
+        fp16_grads = []
         norm_groups = []
         skip = False

-        for i, group in enumerate(self.fp16_groups):
-            grads_groups_flat.append(_flatten_dense_tensors([p.grad for p in group]))
-            norm_groups.append(self._compute_grad_norm(grads_groups_flat[i]))
-            if norm_groups[i] == -1: #TODO: early break
-                skip = True
+        for group in self.fp16_groups:
+            fp16_grad = []
+            for i, p in enumerate(group):
+                fp16_grad.append(p.grad)
+            fp16_grads.append(fp16_grad)
+
+        # nan check
+        self.overflow_buf.zero_()
+        for fp16_grad in fp16_grads:
+            if len(fp16_grad) > 0:
+                norm, norm_per_tensor = multi_tensor_applier(self.multi_tensor_l2norm,
+                                                             self.overflow_buf,
+                                                             [fp16_grad], True)
+                norm_groups.append(norm)
+                if self.overflow_buf.item() != 0:
+                    skip = True

         if skip:
             self._update_scale(skip)
             return

         # norm is in fact norm*cur_scale
-        self.optimizer.step(grads=[[g] for g in grads_groups_flat],
-                            output_params=[[p] for p in self.fp16_groups_flat],
+        self.optimizer.step(grads=fp16_grads,
+                            output_params=self.fp16_groups,
                             scale=self.cur_scale,
                             grad_norms=norm_groups)

-        # TODO: we probably don't need this? just to be safe
-        for i in range(len(norm_groups)):
-            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
-            for p, q in zip(self.fp16_groups[i], updated_params):
-                p.data = q.data
-
         self._update_scale(False)
         return

     def backward(self, loss):
         """
         :attr:`backward` performs the following steps:

         1. fp32_loss = loss.float()
         2. scaled_loss = fp32_loss*loss_scale
         3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
         ...
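The new step() replaces the flatten-and-torch.norm overflow check with a single fused L2-norm launch; the noop buffer passed to multi_tensor_applier is used as the inf/NaN flag, exactly as in the "# nan check" block above. A standalone sketch of that check (the gradient tensors here are illustrative; it assumes a CUDA apex build with the amp_C extension):

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

# illustrative fp16 gradients; in FP16_Optimizer.step these are the model params' .grad
grads = [torch.randn(1024, device='cuda', dtype=torch.half) for _ in range(4)]

overflow_buf = torch.cuda.IntTensor([0])            # set nonzero by the kernel on inf/NaN
norm, norm_per_tensor = multi_tensor_applier(amp_C.multi_tensor_l2norm,
                                             overflow_buf,
                                             [grads],
                                             True)   # also return per-tensor norms

if overflow_buf.item() != 0:
    print("overflow detected; the optimizer would skip this step and back off the loss scale")
else:
    print("total grad norm (still multiplied by the loss scale):", norm.item())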
apex/contrib/optimizers/fused_sgd.py   +211 -0   (new file, 0 → 100644, view file @ c19ee275)

import types
import torch
from torch.optim.optimizer import Optimizer, required

from apex.multi_tensor_apply import multi_tensor_applier

class FusedSGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).

    This version of fused SGD implements 2 fusions.
      * Fusion of the SGD update's elementwise operations
      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.

    :class:`apex.contrib.optimizers.FusedSGD` should be used without AMP.

    :class:`apex.contrib.optimizers.FusedSGD` only works in the case where all parameters require grad.

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    Example:
        model = ...
        model.half()
        optimizer = apex.contrib.optimizers.FusedSGD(model.parameters())
        # wrap with FP16_Optimizer
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        optimizer.zero_grad()
        ...
        optimizer.backward(loss)
        optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et. al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
                  v = \rho * v + g \\
                  p = p - lr * v

        where p, g, v and :math:`\rho` denote the parameters, gradient,
        velocity, and momentum respectively.

        This is in contrast to Sutskever et. al. and
        other frameworks which employ an update of the form

        .. math::
             v = \rho * v + lr * g \\
             p = p - v

        The Nesterov version is analogously modified.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False,
                 wd_after_momentum=False,
                 materialize_master_grads=True):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(FusedSGD, self).__init__(params, defaults)

        self.wd_after_momentum = wd_after_momentum

        if multi_tensor_applier.available:
            import amp_C
            # Skip buffer
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
            self.multi_tensor_sgd = amp_C.multi_tensor_sgd
        else:
            raise RuntimeError('apex.contrib.optimizers.FusedSGD requires cuda extensions')

    def __setstate__(self, state):
        super(FusedSGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def get_momentums(self, params):
        momentums = []
        first_run = True
        for p in params:
            param_state = self.state[p]
            # torch.optim.SGD initializes momentum in the main loop, we have
            # to do it here, and track whether or not we've done so, so that
            # momentum application can be skipped in the main kernel.
            if 'momentum_buffer' not in param_state:
                first_run = True
                buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                momentums.append(buf)
            else:
                first_run = False
                momentums.append(param_state['momentum_buffer'])
        return momentums, first_run

    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list of tensors, optional): weight gradient to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)
            output_params (list of tensors, optional): A reduced precision copy
                of the updated weights written out in addition to the regular
                updated weights. Have to be of same type as gradients. (default: None)
            scale (float, optional): factor to divide gradient tensor values
                by before applying to weights. (default: 1)
        """
        if hasattr(self, "_amp_stash"):
            raise RuntimeError('apex.contrib.optimizers.FusedSGD should not be used with AMP.')

        loss = None
        if closure is not None:
            loss = closure()

        if grads is None:
            raise RuntimeError('apex.contrib.optimizers.FusedSGD must be wrapped \
                               with apex.contrib.optimizers.FP16_Optimizer \
                               which provides grads.')
        # backward compatibility
        # assuming a list/generator of parameter means single group
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif type(grads[0]) != list:
            grads_group = [grads]
        else:
            grads_group = grads

        if output_params is None:
            raise RuntimeError('apex.contrib.optimizers.FusedSGD must be wrapped \
                               with apex.contrib.optimizers.FP16_Optimizer \
                               which provides output_params.')
        elif isinstance(output_params, types.GeneratorType):
            output_params_group = [output_params]
        elif type(output_params[0]) != list:
            output_params_group = [output_params]
        else:
            output_params_group = output_params

        for group, grads_this_group, output_params_this_group in zip(self.param_groups,
                                                                      grads_group,
                                                                      output_params_group):
            if grads_this_group is None or output_params_this_group is None:
                raise RuntimeError('apex.contrib.optimizers.FusedSGD only works \
                                    when all parameters require grad.')

            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            lr = group['lr']

            first_runs = [True, True]

            # output_params_this_group: original weights (either fp16 or fp32)
            # group['params']: master weights (fp32)

            # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
            # fp32, fp32, fp32, No
            fp32_grads = [g for (p, g) in zip(output_params_this_group,
                                              grads_this_group) if p.dtype == torch.float32]
            fp32_params = [p2 for (p1, p2) in zip(output_params_this_group,
                                                  group['params']) if p1.dtype == torch.float32]
            fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
            fp32_set = [fp32_grads, fp32_params, fp32_momentums]

            # fp16, fp32, fp32, Yes
            fp16_grads = [g for (p, g) in zip(output_params_this_group,
                                              grads_this_group) if p.dtype == torch.float16]
            fp32_from_fp16_params = [p2 for (p1, p2) in zip(output_params_this_group,
                                                            group['params']) if p1.dtype == torch.float16]
            fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
            fp16_params = [p1 for (p1, p2) in zip(output_params_this_group,
                                                  group['params']) if p1.dtype == torch.float16]
            fp16_set = [fp16_grads, fp32_from_fp16_params,
                        fp32_from_fp16_momentums, fp16_params]

            launch_sets = [fp16_set, fp32_set]

            for launch_set, first_run in zip(launch_sets, first_runs):
                assert len(launch_set[0]) == len(launch_set[1])
                assert len(launch_set[0]) == len(launch_set[2])
                if len(launch_set[0]) > 0:
                    multi_tensor_applier(
                        self.multi_tensor_sgd,
                        self._dummy_overflow_buf,
                        launch_set,
                        weight_decay,
                        momentum,
                        dampening,
                        lr,
                        nesterov,
                        first_run,
                        self.wd_after_momentum,
                        1.0/scale)

        return loss
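As the class docstring and the guards in step() make explicit, this FusedSGD is meant to be driven through apex.contrib.optimizers.FP16_Optimizer, which supplies grads and output_params. A minimal dynamic-loss-scaling training loop in that style follows; the model, shapes, and loss are illustrative and not part of the commit.

import torch
from apex.contrib.optimizers.fused_sgd import FusedSGD
from apex.contrib.optimizers.fp16_optimizer import FP16_Optimizer

model = torch.nn.Linear(128, 10).cuda().half()          # fp16 model, per the docstring example
optimizer = FusedSGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

for _ in range(10):                                      # illustrative loop
    x = torch.randn(32, 128, device='cuda', dtype=torch.half)
    loss = model(x).float().pow(2).mean()                # stand-in loss
    optimizer.zero_grad()
    optimizer.backward(loss)                             # scaled backward via the wrapper
    optimizer.step()                                     # wrapper calls FusedSGD.step(grads=..., output_params=..., scale=...)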