OpenDAS / apex · Commit 3bae8c83 (unverified)

Authored May 14, 2020 by Andrew Tulloch; committed by GitHub on May 14, 2020
Add FusedAdagrad (#822)
Parent: 9165b27f

Showing 6 changed files with 351 additions and 0 deletions (+351, -0):
  apex/optimizers/__init__.py               +1   -0
  apex/optimizers/fused_adagrad.py          +122 -0
  csrc/amp_C_frontend.cpp                   +13  -0
  csrc/multi_tensor_adagrad.cu              +100 -0
  setup.py                                  +1   -0
  tests/L0/run_optimizers/test_adagrad.py   +114 -0
apex/optimizers/__init__.py

@@ -2,3 +2,4 @@ from .fused_sgd import FusedSGD
 from .fused_adam import FusedAdam
 from .fused_novograd import FusedNovoGrad
 from .fused_lamb import FusedLAMB
+from .fused_adagrad import FusedAdagrad
\ No newline at end of file
apex/optimizers/fused_adagrad.py (new file, mode 100644)

import torch
from apex.multi_tensor_apply import multi_tensor_applier


class FusedAdagrad(torch.optim.Optimizer):
    """Implements Adagrad algorithm.

    Currently GPU-only. Requires Apex to be installed via
    ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.

    This version of fused Adagrad implements 2 fusions.
      * Fusion of the Adagrad update's elementwise operations
      * A multi-tensor apply launch that batches the elementwise updates applied to
        all the model's parameters into one or a few kernel launches.

    :class:`apex.optimizers.FusedAdagrad`'s usage is identical to any ordinary Pytorch optimizer::

        opt = apex.optimizers.FusedAdagrad(model.parameters(), lr = ....)
        ...
        opt.step()

    :class:`apex.optimizers.FusedAdagrad` may be used with or without Amp. If you wish to use
    :class:`FusedAdagrad` with Amp, you may choose any ``opt_level``::

        opt = apex.optimizers.FusedAdagrad(model.parameters(), lr = ....)
        model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
        ...
        opt.step()

    In general, ``opt_level="O1"`` is recommended.

    It has been proposed in `Adaptive Subgradient Methods for Online Learning
    and Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-10)
        adagrad_w_mode (boolean, optional): apply L2 regularization or weight decay.
            True for decoupled weight decay (also known as AdamW) (default: False)

    .. _Adaptive Subgradient Methods for Online Learning and Stochastic
        Optimization: http://jmlr.org/papers/v12/duchi11a.html
    """

    def __init__(self, params, lr=1e-2, eps=1e-10, weight_decay=0.,
                 set_grad_none=True, adagrad_w_mode=False):
        defaults = dict(lr=lr, eps=eps, weight_decay=weight_decay)
        super(FusedAdagrad, self).__init__(params, defaults)
        self.adagrad_w_mode = 1 if adagrad_w_mode else 0
        self.set_grad_none = set_grad_none

        if multi_tensor_applier.available:
            import amp_C
            # Skip buffer
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
            self.multi_tensor_adagrad = amp_C.multi_tensor_adagrad
        else:
            raise RuntimeError('apex.optimizers.FusedAdagrad requires cuda extensions')

    def zero_grad(self):
        if self.set_grad_none:
            for group in self.param_groups:
                for p in group['params']:
                    p.grad = None
        else:
            super(FusedAdagrad, self).zero_grad()

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            # create lists for multi-tensor apply
            g_16, p_16, h_16 = [], [], []
            g_32, p_32, h_32 = [], [], []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError('FusedAdagrad does not support sparse gradients')

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Sum of squared gradient values
                    state['sum'] = torch.zeros_like(p.data)

                if p.dtype == torch.float16:
                    g_16.append(p.grad.data)
                    p_16.append(p.data)
                    h_16.append(state['sum'])
                elif p.dtype == torch.float32:
                    g_32.append(p.grad.data)
                    p_32.append(p.data)
                    h_32.append(state['sum'])
                else:
                    raise RuntimeError('FusedAdagrad only supports fp16 and fp32.')

            if len(g_16) > 0:
                multi_tensor_applier(self.multi_tensor_adagrad,
                                     self._dummy_overflow_buf,
                                     [g_16, p_16, h_16],
                                     group['lr'],
                                     group['eps'],
                                     self.adagrad_w_mode,
                                     group['weight_decay'])
            if len(g_32) > 0:
                multi_tensor_applier(self.multi_tensor_adagrad,
                                     self._dummy_overflow_buf,
                                     [g_32, p_32, h_32],
                                     group['lr'],
                                     group['eps'],
                                     self.adagrad_w_mode,
                                     group['weight_decay'])

        return loss
\ No newline at end of file
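For reference, a minimal usage sketch of the new optimizer, not part of the commit; it assumes Apex was built with the --cpp_ext and --cuda_ext options, a CUDA device is available, and the model and data here are placeholders:

import torch
import apex

# Placeholder model; any CUDA model works the same way.
model = torch.nn.Linear(128, 10).cuda()
opt = apex.optimizers.FusedAdagrad(model.parameters(), lr=1e-2,
                                   eps=1e-10, weight_decay=1e-5)

for _ in range(10):
    inp = torch.randn(32, 128, device="cuda")
    target = torch.randint(0, 10, (32,), device="cuda")
    loss = torch.nn.functional.cross_entropy(model(inp), target)
    opt.zero_grad()
    loss.backward()
    opt.step()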
csrc/amp_C_frontend.cpp

@@ -66,6 +66,17 @@ void multi_tensor_adam_cuda(
     const int bias_correction,
     const float weight_decay);
 
+void multi_tensor_adagrad_cuda(
+    int chunk_size,
+    at::Tensor noop_flag,
+    std::vector<std::vector<at::Tensor>> tensor_lists,
+    const float lr,
+    const float epsilon,
+    const int mode,
+    const float weight_decay);
+
 void multi_tensor_novograd_cuda(
     int chunk_size,
     at::Tensor noop_flag,
     ...

@@ -112,6 +123,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Completes application of gradient to parameters for LAMB optimizer");
   m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
         "Compute and apply gradient update to parameters for Adam optimizer");
+  m.def("multi_tensor_adagrad", &multi_tensor_adagrad_cuda,
+        "Compute and apply gradient update to parameters for Adagrad optimizer");
   m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda,
         "Compute and apply gradient update to parameters for Adam optimizer");
   m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
     ...
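The binding above is what FusedAdagrad.step dispatches to. As an illustration only (the tensor names are made up), the bound function can also be driven directly through multi_tensor_applier, with trailing arguments matching the declaration above:

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

# Illustrative tensor lists: gradients, parameters, and Adagrad state sums.
params = [torch.randn(1024, device="cuda") for _ in range(3)]
grads = [torch.randn_like(p) for p in params]
sums = [torch.zeros_like(p) for p in params]
overflow_buf = torch.cuda.IntTensor([0])

# Trailing arguments follow the C++ signature: lr, epsilon, mode
# (0 = L2 regularization, 1 = AdamW-style weight decay), weight_decay.
multi_tensor_applier(amp_C.multi_tensor_adagrad, overflow_buf,
                     [grads, params, sums], 1e-2, 1e-10, 0, 0.0)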
csrc/multi_tensor_adagrad.cu (new file, mode 100644)

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "multi_tensor_apply.cuh"
#include "type_shim.h"

#define BLOCK_SIZE 1024
#define ILP 4

typedef enum {
  ADAGRAD_MODE_0 = 0, // L2 regularization mode.
  ADAGRAD_MODE_1 = 1, // AdamW-style weight decay.
} adagradMode_t;

using MATH_T = float;

template <typename T> struct AdagradFunctor {
  __device__ __forceinline__ void
  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl,
             const float epsilon, const float lr, adagradMode_t mode,
             const float weight_decay) {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    T *g = (T *)tl.addresses[0][tensor_loc];
    g += chunk_idx * chunk_size;
    T *p = (T *)tl.addresses[1][tensor_loc];
    p += chunk_idx * chunk_size;
    T *h = (T *)tl.addresses[2][tensor_loc];
    h += chunk_idx * chunk_size;

    n -= chunk_idx * chunk_size;

    // see note in multi_tensor_scale_kernel.cu
    for (int i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * ILP) {
      MATH_T r_g[ILP];
      MATH_T r_p[ILP];
      MATH_T r_h[ILP];
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          r_g[ii] = g[i];
          r_p[ii] = p[i];
          r_h[ii] = h[i];
        } else {
          r_g[ii] = MATH_T(0);
          r_p[ii] = MATH_T(0);
          r_h[ii] = MATH_T(0);
        }
      }
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        if (mode == ADAGRAD_MODE_0) { // L2
          r_g[ii] = r_g[ii] + weight_decay * r_p[ii];
          r_h[ii] = r_h[ii] + r_g[ii] * r_g[ii];
          r_p[ii] = r_p[ii] - lr * (r_g[ii] / (sqrtf(r_h[ii]) + epsilon));
        } else { // AdamW-style
          r_h[ii] = r_h[ii] + r_g[ii] * r_g[ii];
          r_p[ii] = r_p[ii] - lr * (r_g[ii] / (sqrtf(r_h[ii]) + epsilon) +
                                    weight_decay * r_p[ii]);
        }
      }
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          p[i] = r_p[ii];
          h[i] = r_h[ii];
        }
      }
    }
  }
};

void multi_tensor_adagrad_cuda(
    int chunk_size, at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
    const float epsilon, const int mode, const float weight_decay) {
  using namespace at;

  // Assume single type across p,g,h now
  DISPATCH_DOUBLE_FLOAT_AND_HALF(
      tensor_lists[0][0].scalar_type(), 0, "adagrad",
      multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                            AdagradFunctor<scalar_t_0>(), epsilon, lr,
                            (adagradMode_t)mode, weight_decay);)

  AT_CUDA_CHECK(cudaGetLastError());
}
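For readers checking the kernel math, a rough single-tensor Python equivalent of the two modes in AdagradFunctor; this is a sketch for verification and not part of the commit:

import torch

def adagrad_reference(p, g, h, lr, eps, weight_decay, mode):
    """Per-element update performed by AdagradFunctor, in plain PyTorch."""
    if mode == 0:  # ADAGRAD_MODE_0: L2 regularization folded into the gradient
        g = g + weight_decay * p
        h = h + g * g
        p = p - lr * (g / (h.sqrt() + eps))
    else:          # ADAGRAD_MODE_1: AdamW-style decoupled weight decay
        h = h + g * g
        p = p - lr * (g / (h.sqrt() + eps) + weight_decay * p)
    return p, h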
setup.py

@@ -115,6 +115,7 @@ if "--cuda_ext" in sys.argv:
                               'csrc/multi_tensor_lamb_stage_1.cu',
                               'csrc/multi_tensor_lamb_stage_2.cu',
                               'csrc/multi_tensor_adam.cu',
+                              'csrc/multi_tensor_adagrad.cu',
                               'csrc/multi_tensor_novograd.cu',
                               'csrc/multi_tensor_lamb.cu'],
                 extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
     ...
tests/L0/run_optimizers/test_adagrad.py (new file, mode 100644)

import unittest

import apex
import torch


class TestFusedAdagrad(unittest.TestCase):
    def setUp(self, max_abs_diff=1e-6, max_rel_diff=1, iters=7):
        self.max_abs_diff = max_abs_diff
        self.max_rel_diff = max_rel_diff
        self.iters = iters
        torch.cuda.manual_seed(9876)

    def tearDown(self):
        pass

    def gen_param_optim(self, tensors, adagrad_option):
        ref_param = []
        tst_param = []
        for tensor in tensors:
            ref_param.append(torch.nn.Parameter(tensor.clone()))
            tst_param.append(torch.nn.Parameter(tensor.clone()))

        ref_optim = torch.optim.Adagrad(ref_param, **adagrad_option)
        tst_optim = apex.optimizers.FusedAdagrad(tst_param, **adagrad_option)

        return (ref_param, tst_param, ref_optim, tst_optim)

    def gen_grad(self, ref_param, tst_param):
        for p_ref, p_tst in zip(ref_param, tst_param):
            p_ref.grad = torch.rand_like(p_ref)
            p_tst.grad = p_ref.grad

    def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
        half_grads = []
        for p_ref, _ in zip(ref_param, tst_param):
            half_grads.append(torch.rand_like(p_ref).half())
            p_ref.grad = half_grads[-1].float() / scale
        return half_grads

    def get_max_diff(self, ref_param, tst_param):
        max_abs_diff = max_rel_diff = 0
        for p_ref, p_tst in zip(ref_param, tst_param):
            max_abs_diff_p = (p_ref - p_tst).abs().max().item()
            max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()

            if max_abs_diff_p > max_abs_diff:
                max_abs_diff = max_abs_diff_p
            if max_rel_diff_p > max_rel_diff:
                max_rel_diff = max_rel_diff_p

        return max_abs_diff, max_rel_diff

    def gen_single_type_test(self, param_type=torch.float):
        nelem = 278011
        adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 1.0e-5}

        tensor = torch.rand(nelem, dtype=param_type, device="cuda")
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [tensor], adagrad_option
        )

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    @unittest.skip("PyTorch optimizer is not numerically correct for fp16")
    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    def test_multi_params(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 0}

        tensors = []
        for size in sizes:
            tensors.append(torch.rand(size, dtype=torch.float, device="cuda"))
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            tensors, adagrad_option
        )

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    def test_adagrad_option(self):
        nelem = 1
        adagrad_option = {"lr": 0.01, "eps": 3e-06, "weight_decay": 0}

        tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [tensor], adagrad_option
        )

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)