OpenDAS / apex

Commit be42aad5
Authored Dec 05, 2018 by Deyu Fu
Parent: b436213e

WIP: improve fused adam
Showing 5 changed files with 279 additions and 31 deletions (+279 / -31).
  apex/optimizers/__init__.py                      +1    -0
  apex/optimizers/csrc/fused_adam_cuda.cpp         +3    -3
  apex/optimizers/csrc/fused_adam_cuda_kernel.cu   +21   -9
  apex/optimizers/fp16_optimizer.py                +183  -0
  apex/optimizers/fused_adam.py                    +71   -19
apex/optimizers/__init__.py (+1 / -0)

 from .fused_adam import FusedAdam
+from .fp16_optimizer import FP16_Optimizer
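With this export in place, the package root offers both the fused optimizer and the new mixed-precision wrapper. A minimal sketch of the resulting import surface, assuming apex was built with the CUDA and C++ extensions (``--cuda_ext --cpp_ext``) so the fused_adam_cuda extension is importable:

    # Both names come straight from apex/optimizers/__init__.py after this commit.
    from apex.optimizers import FusedAdam, FP16_Optimizer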
apex/optimizers/csrc/fused_adam_cuda.cpp (+3 / -3)

 #include <torch/extension.h>

 // CUDA forward declaration
-void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode);
+void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);

 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

 // C++ interface
-void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode) {
+void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
         CHECK_INPUT(p)
         if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
         CHECK_INPUT(m);
 ...
@@ -20,7 +20,7 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
         AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
         AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");

-        fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode);
+        fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 ...
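The binding exposed as fused_adam_cuda.adam now takes two extra trailing arguments, bias_correction and decay. A hedged sketch of the call order as seen from Python (the tensors and values below are illustrative placeholders; the real call site is in fused_adam.py further down):

    # Sketch only: argument order of the extended binding.
    import torch
    import fused_adam_cuda  # extension built from fused_adam_cuda.cpp / fused_adam_cuda_kernel.cu

    p = torch.zeros(1024, device='cuda')         # fp32 parameters
    out_p = torch.tensor([], dtype=torch.float)  # optional reduced-precision copy; empty means "unused"
    m = torch.zeros_like(p)                      # first moment
    v = torch.zeros_like(p)                      # second moment
    g = torch.randn_like(p)                      # gradient
    fused_adam_cuda.adam(p, out_p, m, v, g,
                         1e-3,        # lr
                         0.9, 0.999,  # beta1, beta2
                         1e-8,        # eps
                         1.0,         # grad_scale
                         1,           # step
                         0,           # mode (where eps is applied)
                         1,           # bias_correction (new in this commit)
                         0.0)         # decay, i.e. weight decay (new in this commit)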
apex/optimizers/csrc/fused_adam_cuda_kernel.cu (+21 / -9)

 ...
@@ -28,7 +28,8 @@ __global__ void adam_cuda_kernel(
         const float grad_scale,
         const float step_size,
         const size_t tsize,
-        adamMode_t mode)
+        adamMode_t mode,
+        const float decay)
 {
         //Assuming 2D grids and 2D blocks
         const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
 ...
@@ -46,7 +47,8 @@ __global__ void adam_cuda_kernel(
                     denom = sqrtf(v[j] + eps);
                 else // Mode 1
                     denom = sqrtf(v[j]) + eps;
-                p[j] = p[j] - (step_size*m[j]/denom);
+                float update = (m[j]/denom) + (decay*p[j]);
+                p[j] = p[j] - (step_size*update);
                 if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
             }
         }
 ...
@@ -63,7 +65,9 @@ void fused_adam_cuda(
         float eps,
         float grad_scale,
         int step,
-        int mode)
+        int mode,
+        int bias_correction,
+        float decay)
 {
         //Get tensor size
         int tsize = p.numel();
 ...
@@ -72,15 +76,21 @@ void fused_adam_cuda(
         const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
         AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
         //Constants
-        const float bias_correction1 = 1 - std::pow(beta1, step);
-        const float bias_correction2 = 1 - std::pow(beta2, step);
-        const float step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
+        float step_size = 0;
+        if (bias_correction == 1) {
+            const float bias_correction1 = 1 - std::pow(beta1, step);
+            const float bias_correction2 = 1 - std::pow(beta2, step);
+            step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
+        }
+        else {
+            step_size = lr;
+        }

         cudaStream_t stream = at::cuda::getCurrentCUDAStream();

         if (g.type().scalarType() == at::ScalarType::Half) {
             //all other values should be fp32 for half gradients
             AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
             //dispatch is done on the gradient type
             AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.type(), "adam_cuda_kernel", ([&] {
                 using accscalar_t = at::acc_type<scalar_t, true>;
                 adam_cuda_kernel<accscalar_t, scalar_t><<<blocks, threadsPerBlock, 0, stream>>>(
 ...
@@ -95,7 +105,8 @@ void fused_adam_cuda(
                     grad_scale,
                     step_size,
                     tsize,
-                    (adamMode_t) mode);
+                    (adamMode_t) mode,
+                    decay);
             }));
         } else {
             AT_DISPATCH_FLOATING_TYPES(g.type(), "adam_cuda_kernel", ([&] {
 ...
@@ -111,7 +122,8 @@ void fused_adam_cuda(
                     grad_scale,
                     step_size,
                     tsize,
-                    (adamMode_t) mode);
+                    (adamMode_t) mode,
+                    decay);
             }));
         }
         THCudaCheck(cudaGetLastError());
 ...
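Taken together, the kernel changes make bias correction optional on the host side and fold a weight-decay term into each element's update. The sketch below restates that per-tensor arithmetic in plain PyTorch as an approximation only: the moment updates sit in the elided part of the kernel and are written here as standard Adam, and the function name adam_reference is purely illustrative.

    import math
    import torch

    def adam_reference(p, m, v, g, lr, beta1, beta2, eps, grad_scale,
                       step, mode, bias_correction, decay):
        # Rough restatement of adam_cuda_kernel's arithmetic for one tensor.
        scaled_grad = g / grad_scale
        m.mul_(beta1).add_(scaled_grad, alpha=1 - beta1)
        v.mul_(beta2).addcmul_(scaled_grad, scaled_grad, value=1 - beta2)
        if bias_correction == 1:
            step_size = lr * math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)
        else:
            step_size = lr
        if mode == 0:   # eps inside the square root
            denom = (v + eps).sqrt()
        else:           # mode 1: eps outside the square root
            denom = v.sqrt() + eps
        update = m / denom + decay * p      # weight decay folded into the step
        p.add_(update, alpha=-step_size)
        return p

Note that the decay term is multiplied by the same step_size as the Adam update, so this is a coupled (L2-style) weight decay rather than a fully decoupled one.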
apex/optimizers/fp16_optimizer.py (new file, +183, mode 0 → 100755)

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

class FP16_Optimizer(object):
    """
    :class:`FP16_Optimizer` is a cut-down version of apex.fp16_utils.FP16_Optimizer.
    Designed to be used in the same way, but supporting only the fused optimizers in apex.
    Refer to the apex.fp16_utils documentation for more information.

    Example::

        model = torch.nn.Linear(D_in, D_out).cuda().half()
        optimizer = apex.optimizers.FusedAdam(model.parameters())
        # Name the FP16_Optimizer instance to replace the existing optimizer
        # (recommended but not required):
        optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
        ...
        # loss.backward() becomes:
        optimizer.backward(loss)
        ...

    Example with dynamic loss scaling::

        ...
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        # optional arg to control dynamic loss scaling behavior
        # dynamic_loss_args={'scale_window' : 500})
        # Usually, dynamic_loss_args is not necessary.
    """

    def __init__(self,
                 init_optimizer,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True):
        # The fused optimizer does all the work. We need this layer for two reasons:
        # 1. maintain the same user API as apex.fp16_utils
        # 2. keep common stuff here in case we need to add new fused optimizers later
        # differences from apex.fp16_utils:
        # - assume all model params are fp16
        # - assume all params require grad
        # - flat by groups, not keeping state. TODO: remove state explicitly?
        # - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
        if not torch.cuda.is_available():
            raise SystemError("Cannot use fp16 without CUDA.")

        self.optimizer = init_optimizer

        # params flattened by group
        self.fp16_groups = []
        self.fp16_groups_flat = []
        self.fp32_groups_flat = []

        # loop to deal with groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # push this group to the list before modifying it
            self.fp16_groups.append(param_group['params'])
            # init fp16 weight buffer, flattened
            self.fp16_groups_flat.append(_flatten_dense_tensors([p.clone().detach() for p in self.fp16_groups[i]]))
            # set model fp16 weights to slices of the flattened buffer
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data
            # init master weights, flattened
            self.fp32_groups_flat.append(self.fp16_groups_flat[i].clone().float().detach())
            # modify optimizer to have flat master weights
            self.fp32_groups_flat[i].requires_grad = True  # keep this in case the internal optimizer uses it
            param_group['params'] = [self.fp32_groups_flat[i]]

        # we may have a way of fusing dynamic scale. Do not support for now
        if dynamic_loss_scale:
            if dynamic_loss_args is not None:
                raise SystemError("Do not support dynamic loss scale args for now.")
            self.dynamic_loss_scale = True
            self.cur_scale = 2**32
            self.cur_iter = 0
            self.last_overflow_iter = -1
            self.scale_factor = 2
            self.scale_window = 1000
        else:
            self.dynamic_loss_scale = False
            self.cur_iter = 0
            self.cur_scale = static_loss_scale

    def zero_grad(self, set_grads_to_None=True):
        """
        Zero FP16 parameter grads.
        """
        # FP32 grads should never exist.
        # For speed, set model fp16 grads to None by default
        for group in self.fp16_groups:
            for p in group:
                if set_grads_to_None:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

    def _compute_grad_norm(self, fp16_grads_flat, norm_type=2):
        """
        Compute the fp16 grad norm for later clipping (fused with the update).
        Internally accumulated in fp32.
        Also fused with the NaN check. Possibly other reductions needed for grads.

        Args:
            fp16_grads_flat (tensor): flattened fp16 grads
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.

        Returns:
            Total norm of the current fp16 gradients (viewed as a single vector).
            Returns -1 if the most recently computed fp16 gradients overflowed.
        """
        # TODO: currently using the pre-1.0 api; not most efficient with copy to cpu and sync
        norm = float(torch.norm(fp16_grads_flat, p=norm_type))
        if norm == float('inf') or norm == -float('inf') or norm != norm:
            return -1
        else:
            return norm

    def step(self, closure=None):
        """
        Not supporting closure.
        """
        # First compute the norm for all groups so we know if there is an overflow
        grads_groups_flat = []
        norm_groups = []
        skip = False
        for i, group in enumerate(self.fp16_groups):
            grads_groups_flat.append(_flatten_dense_tensors([p.grad for p in group]))
            norm_groups.append(self._compute_grad_norm(grads_groups_flat[i]))
            if norm_groups[i] == -1:  # TODO: early break
                skip = True

        if skip:
            self._update_scale(skip)
            return

        # norm is in fact norm*cur_scale
        self.optimizer.step(grads_group=[[g] for g in grads_groups_flat],
                            output_params_group=[[p] for p in self.fp16_groups_flat],
                            scale=self.cur_scale,
                            grad_norms=norm_groups)

        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

        self._update_scale(False)
        return

    def backward(self, loss):
        """
        :attr:`backward` performs the following conceptual steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
        """
        scaled_loss = (loss.float()) * self.cur_scale
        scaled_loss.backward()

    def _update_scale(self, skip):
        if self.dynamic_loss_scale:
            if skip:
                print("grad overflow on iteration", self.cur_iter)
                print("Using dynamic loss scale of", self.cur_scale)
                self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
                self.last_overflow_iter = self.cur_iter
            else:
                if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                    self.cur_scale *= self.scale_factor
        else:
            if skip:
                print("Grad overflow on iteration", self.cur_iter)
                print("Using static loss scale of", self.cur_scale)
        self.cur_iter += 1
        return
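Assembled from the docstring example above, a minimal training-step sketch of how this wrapper is meant to be driven (the layer sizes and synthetic data are illustrative placeholders, and apex must be built with the CUDA/C++ extensions):

    import torch
    from apex.optimizers import FusedAdam, FP16_Optimizer

    D_in, D_out = 64, 16
    model = torch.nn.Linear(D_in, D_out).cuda().half()
    optimizer = FP16_Optimizer(FusedAdam(model.parameters()), dynamic_loss_scale=True)

    for _ in range(10):
        inputs = torch.randn(32, D_in, device='cuda', dtype=torch.half)
        targets = torch.randn(32, D_out, device='cuda', dtype=torch.half)
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs).float(), targets.float())
        optimizer.backward(loss)   # replaces loss.backward(): scales the loss, then backprops
        optimizer.step()           # fused, flattened update; the step is skipped on overflow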
apex/optimizers/fused_adam.py (+71 / -19)

+import math
 import torch
 import fused_adam_cuda

-class FusedAdam(torch.optim.Adam):
+def warmup_cosine(x, warmup=0.002):
+    if x < warmup:
+        return x/warmup
+    return 0.5 * (1.0 + torch.cos(math.pi * x))
+
+def warmup_constant(x, warmup=0.002):
+    if x < warmup:
+        return x/warmup
+    return 1.0
+
+def warmup_linear(x, warmup=0.002):
+    if x < warmup:
+        return x/warmup
+    return 1.0 - x
+
+SCHEDULES = {
+    'warmup_cosine':warmup_cosine,
+    'warmup_constant':warmup_constant,
+    'warmup_linear':warmup_linear,
+}
+
+class FusedAdam(torch.optim.Optimizer):

     """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
     ``python setup.py install --cuda_ext --cpp_ext``.
 ...
@@ -20,8 +42,8 @@ class FusedAdam(torch.optim.Adam):
         amsgrad (boolean, optional): whether to use the AMSGrad variant of this
             algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in FusedAdam!
        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
            adds eps to the bias-corrected second moment estimate before
            evaluating square root instead of adding it to the square root of
            second moment estimate as in the original paper. (default: False)
 ...
@@ -31,24 +53,29 @@ class FusedAdam(torch.optim.Adam):
         https://openreview.net/forum?id=ryQu7f-RZ
     """

-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
-                 weight_decay=0, amsgrad=False, eps_inside_sqrt=False):
+    def __init__(self, params,
+                 lr=1e-3, warmup=-1, t_total=-1, schedule='warmup_linear',
+                 betas=(0.9, 0.999), eps=1e-8,
+                 eps_inside_sqrt=False, weight_decay=0., max_grad_norm=0., amsgrad=False):
         if amsgrad:
             raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
-        super(FusedAdam, self).__init__(params, lr, betas, eps, weight_decay, amsgrad)
+        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, betas=betas, eps=eps, weight_decay=weight_decay, max_grad_norm=max_grad_norm)
+        super(FusedAdam, self).__init__(params, defaults)
         self.eps_mode = 0 if eps_inside_sqrt else 1

-    def step(self, closure=None, grads=None, output_params=None, scale=1.):
+    def step(self, closure=None, grads_group=None, output_params_group=None, scale=1., grad_norms=None):
         """Performs a single optimization step.

         Arguments:
             closure (callable, optional): A closure that reevaluates the model
                 and returns the loss.
             grads (list of tensors, optional): weight gradient to use for the
                 optimizer update. If gradients have type torch.half, parameters
                 are expected to be in type torch.float. (default: None)
             output params (list of tensors, optional): A reduced precision copy
                 of the updated weights written out in addition to the regular
                 updated weights. Have to be of same type as gradients. (default: None)
             scale (float, optional): factor to divide gradient tensor values
                 by before applying to weights. (default: 1)
 ...
@@ -56,19 +83,34 @@ class FusedAdam(torch.optim.Adam):
         loss = None
         if closure is not None:
             loss = closure()

-        if grads is not None:
-            assert len(self.param_groups) == 1, "mixed precision optimizer works for a single group only"
-        for group in self.param_groups:
+        if grads_group is None:
+            grads_group = [None]*len(self.param_groups)
+        if output_params_group is None:
+            output_params_group = [None]*len(self.param_groups)
+        if grad_norms is None:
+            grad_norms = [None]*len(self.param_groups)
+
+        for group, grads, output_params, grad_norm in zip(self.param_groups, grads_group, output_params_group, grad_norms):
             if grads is None:
                 grads = [None]*len(group['params'])
             if output_params is None:
                 output_params = [None]*len(group['params'])
-            for p, grad, output_param in zip(group['params'], grads, output_params):
+
+            # compute combined scale factor for this group
+            combined_scale = scale
+            if group['max_grad_norm'] > 0:
+                # norm is in fact norm*scale
+                clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']
+                if clip > 1:
+                    combined_scale = clip * scale
+
+            for p, grad, output_param in zip(group['params'], grads, output_params):
                 #note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
                 if p.grad is None and grad is None:
                     continue
                 if grad is None:
                     grad = p.grad.data
                 if grad.is_sparse:
                     raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')
 ...
@@ -85,7 +127,16 @@ class FusedAdam(torch.optim.Adam):
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 beta1, beta2 = group['betas']

+                if group['t_total'] != -1:
+                    schedule_fct = SCHEDULES[group['schedule']]
+                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
+                    bias_correction = 0
+                else:
+                    lr_scheduled = group['lr']
+                    bias_correction = 1
+
                 state['step'] += 1

                 out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
                 fused_adam_cuda.adam(p.data,
                                      out_p,
 ...
@@ -96,8 +147,9 @@ class FusedAdam(torch.optim.Adam):
                                      beta1,
                                      beta2,
                                      group['eps'],
-                                     scale,
+                                     combined_scale,
                                      state['step'],
-                                     self.eps_mode)
+                                     self.eps_mode,
+                                     bias_correction,
+                                     group['weight_decay'])
         return loss
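The reworked constructor adds warmup scheduling, weight decay, and grad-norm clipping to the fused path. A hedged construction sketch using the new arguments (the model and hyperparameter values are illustrative only):

    import torch
    from apex.optimizers import FusedAdam

    model = torch.nn.Linear(8, 8).cuda()   # illustrative fp32 CUDA model
    optimizer = FusedAdam(model.parameters(),
                          lr=5e-5,
                          warmup=0.1,                # fraction of t_total spent warming up
                          t_total=100000,            # total steps; -1 disables the schedule
                          schedule='warmup_linear',  # a key of the SCHEDULES dict above
                          weight_decay=0.01,         # forwarded to the kernel as `decay`
                          max_grad_norm=1.0)         # enables the combined_scale clipping path

Per the diff, enabling the schedule (t_total != -1) also sets bias_correction to 0 for the kernel call, so the warmup schedule takes the place of Adam's bias correction.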