OpenDAS / apex · Commits

Commit 5b300119, authored Apr 28, 2020 by Kexin Yu
LAMB: global grad clipping & more flexibility in adaptive lr
Parent: 1f2aa915

Showing 4 changed files with 84 additions and 24 deletions (+84 −24):

  apex/optimizers/fused_lamb.py        +38  −3
  csrc/amp_C_frontend.cpp               +5  −2
  csrc/multi_tensor_lamb.cu            +24 −13
  csrc/multi_tensor_lamb_stage_2.cu    +17  −6
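For orientation before the per-file hunks: the global gradient norm is now computed once on the Python side, over all fp32 and fp16 gradients, and passed to the fused kernels together with max_grad_norm and the new use_nvlamb flag. Below is a minimal, illustrative Python sketch of the clipping factor the stage-1 kernel computes; the helper name is made up, and the division of each gradient by this factor happens in kernel code this diff leaves unchanged, so it is assumed here.

    import math
    import torch

    def global_grad_clip_factor(grads, max_grad_norm=1.0):
        # Illustrative helper (not part of apex): the factor the fused kernels
        # divide gradients by, so the effective global norm stays <= max_grad_norm.
        global_grad_norm = math.sqrt(sum(float(g.norm()) ** 2 for g in grads))
        # same expression as the kernel: norm > max ? norm / max : 1.0
        return global_grad_norm / max_grad_norm if global_grad_norm > max_grad_norm else 1.0

    grads = [torch.randn(4), torch.randn(3)]       # stand-ins for parameter grads
    factor = global_grad_clip_factor(grads, max_grad_norm=1.0)
    clipped = [g / factor for g in grads]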
apex/optimizers/fused_lamb.py

@@ -51,6 +51,8 @@ class FusedLAMB(torch.optim.Optimizer):
             method is called. (default: True)
         max_grad_norm (float, optional): value used to clip global grad norm
             (default: 1.0)
+        use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
+            weight decay parameter (default: False)
 
     .. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes:
         https://arxiv.org/abs/1904.00962
@@ -62,7 +64,7 @@ class FusedLAMB(torch.optim.Optimizer):
                  betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
                  amsgrad=False, adam_w_mode=True,
                  grad_averaging=True, set_grad_none=True,
-                 max_grad_norm=1.0):
+                 max_grad_norm=1.0, use_nvlamb=False):
         if amsgrad:
             raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
         defaults = dict(lr=lr, bias_correction=bias_correction,
@@ -72,6 +74,7 @@ class FusedLAMB(torch.optim.Optimizer):
         super(FusedLAMB, self).__init__(params, defaults)
         if multi_tensor_applier.available:
             import amp_C
+            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
             # Skip buffer
             self._dummy_overflow_buf = torch.cuda.IntTensor([0])
             self.multi_tensor_lamb = amp_C.multi_tensor_lamb
@@ -100,6 +103,34 @@ class FusedLAMB(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()
 
+        # create separate grad lists for fp32 and fp16 params
+        g_all_32, g_all_16 = [], []
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                if p.dtype == torch.float32:
+                    g_all_32.append(p.grad.data)
+                elif p.dtype == torch.float16:
+                    g_all_16.append(p.grad.data)
+                else:
+                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')
+
+        g_norm_32, g_norm_16 = 0.0, 0.0
+        # compute grad norm for two lists
+        if len(g_all_32) > 0:
+            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
+                                             self._dummy_overflow_buf,
+                                             [g_all_32], False)[0].item()
+        if len(g_all_16) > 0:
+            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
+                                             self._dummy_overflow_buf,
+                                             [g_all_16], False)[0].item()
+
+        # blend two grad norms to get global grad norm
+        global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)
+        max_grad_norm = self.defaults['max_grad_norm']
+
         for group in self.param_groups:
             bias_correction = 1 if group['bias_correction'] else 0
             beta1, beta2 = group['betas']
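A note on the blend in the hunk above: the L2 norm of a concatenation satisfies ||[a, b]||² = ||a||² + ||b||², so combining the fp32 and fp16 list norms as sqrt(n32² + n16²) yields the exact global gradient norm rather than an approximation. A quick illustrative check with stand-in tensors:

    import torch

    g32 = [torch.randn(5), torch.randn(3)]   # stand-ins for fp32 grads
    g16 = [torch.randn(4)]                   # stand-in for fp16 grads (kept fp32 here for exactness)
    n32 = torch.cat([g.flatten() for g in g32]).norm()
    n16 = torch.cat([g.flatten() for g in g16]).norm()
    blended = (n32 ** 2 + n16 ** 2).sqrt()
    full = torch.cat([g.flatten() for g in g32 + g16]).norm()
    assert torch.allclose(blended, full)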
@@ -156,7 +187,9 @@ class FusedLAMB(torch.optim.Optimizer):
                          group['weight_decay'],
                          grad_averaging,
                          self.adam_w_mode,
-                         group['max_grad_norm'])
+                         global_grad_norm,
+                         max_grad_norm,
+                         use_nvlamb)
             if(len(g_32) > 0):
                 multi_tensor_applier(self.multi_tensor_lamb,
                                      self._dummy_overflow_buf,
@@ -170,6 +203,8 @@ class FusedLAMB(torch.optim.Optimizer):
                          group['weight_decay'],
                          grad_averaging,
                          self.adam_w_mode,
-                         group['max_grad_norm'])
+                         global_grad_norm,
+                         max_grad_norm,
+                         use_nvlamb)
 
         return loss
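On the user side, both new knobs are plain constructor arguments of FusedLAMB. A hedged usage sketch, assuming apex is built with its CUDA extensions and using a placeholder model:

    import torch
    from apex.optimizers import FusedLAMB

    model = torch.nn.Linear(10, 10).cuda()            # placeholder model
    optimizer = FusedLAMB(model.parameters(),
                          lr=1e-3,
                          weight_decay=0.01,
                          max_grad_norm=1.0,          # clips the global grad norm
                          use_nvlamb=False)           # adaptive lr only where weight decay != 0

    loss = model(torch.randn(2, 10, device='cuda')).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()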
csrc/amp_C_frontend.cpp

@@ -51,7 +51,8 @@ void multi_tensor_lamb_stage2_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float step_size);
+  const float step_size,
+  at::optional<bool> use_nvlamb_python);
 
 void multi_tensor_adam_cuda(
   int chunk_size,
@@ -95,7 +96,9 @@ void multi_tensor_lamb_cuda(
   const float weight_decay,
   const int grad_averaging,
   const int mode,
-  const float max_grad_norm);
+  const float global_grad_norm,
+  const float max_grad_norm,
+  at::optional<bool> use_nvlamb_python);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
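Declaring the flag as at::optional<bool> lets the Python caller pass None or a bool, with the C++ side supplying the default; both host functions changed below resolve it the same way. A hypothetical Python restatement of that defaulting rule (the function name is illustrative, not an apex API):

    def resolve_use_nvlamb(use_nvlamb_python=None):
        # mirrors: use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false
        return use_nvlamb_python if use_nvlamb_python is not None else False

    assert resolve_use_nvlamb() is False
    assert resolve_use_nvlamb(True) is True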
csrc/multi_tensor_lamb.cu

@@ -41,8 +41,8 @@ struct LAMBStage1Functor
     const float epsilon,
     adamMode_t mode,
     const float decay,
-    float* global_grad_norm,
-    float max_global_grad_norm)
+    const float global_grad_norm,
+    const float max_global_grad_norm)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
@@ -52,7 +52,7 @@ struct LAMBStage1Functor
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
-    float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;
+    float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
 
     T* g = (T*)tl.addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;
@@ -150,7 +150,9 @@ struct LAMBStage2Functor
     TensorListMetadata<2>& tl,
     const float* per_tensor_param_norm,
     const float* per_tensor_update_norm,
-    const float learning_rate)
+    const float learning_rate,
+    const float decay,
+    bool use_nvlamb)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
@@ -161,9 +163,15 @@ struct LAMBStage2Functor
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
-    float param_norm = per_tensor_param_norm[tensor_num];
-    float update_norm = per_tensor_update_norm[tensor_num];
-    MATH_T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
+    MATH_T ratio = learning_rate;
+    // nvlamb: apply adaptive learning rate to all parameters
+    // otherwise, only apply to those with non-zero weight decay
+    if (use_nvlamb || (decay != 0.0))
+    {
+      float param_norm = per_tensor_param_norm[tensor_num];
+      float update_norm = per_tensor_update_norm[tensor_num];
+      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
+    }
 
     T* update = (T*)tl.addresses[0][tensor_loc];
     update += chunk_idx*chunk_size;
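The behavioral change in stage 2 above: the LAMB trust ratio (param_norm / update_norm scaling of the learning rate) is no longer applied unconditionally, but only when use_nvlamb is set or the parameter has non-zero weight decay. An illustrative Python restatement of that gating (names are made up):

    def lamb_trust_ratio(learning_rate, param_norm, update_norm, decay, use_nvlamb):
        # Illustrative restatement of the gating added to LAMBStage2Functor.
        ratio = learning_rate
        # nvlamb: apply the adaptive learning rate to all parameters;
        # otherwise, only to parameters with non-zero weight decay.
        if use_nvlamb or decay != 0.0:
            if update_norm != 0.0 and param_norm != 0.0:
                ratio = learning_rate * (param_norm / update_norm)
        return ratio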
@@ -221,12 +229,16 @@ void multi_tensor_lamb_cuda(
   const float weight_decay,
   const int grad_averaging,
   const int mode,
-  const float max_grad_norm)
+  const float global_grad_norm,
+  const float max_grad_norm,
+  at::optional<bool> use_nvlamb_python)
 {
   using namespace at;
   // Master weight and 32bit momentum(potentially changing) is not handled by this
   // So we assume every tensor are all in the same type
 
+  bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
+
   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
   if (bias_correction == 1) {
@@ -241,9 +253,6 @@ void multi_tensor_lamb_cuda(
-  std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
   std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);
 
-  // Compute global grad norm
-  auto grad_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, false);
 
   // Compute per tensor param norm
   auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
@@ -265,7 +274,7 @@ void multi_tensor_lamb_cuda(
         epsilon,
         (adamMode_t) mode,
         weight_decay,
-        std::get<0>(grad_norm_tuple).DATA_PTR<float>(),
+        global_grad_norm,
         max_grad_norm); )
 
   // Compute update norms
@@ -282,7 +291,9 @@ void multi_tensor_lamb_cuda(
         LAMBStage2Functor<scalar_t_0>(),
         std::get<1>(param_norm_tuple).DATA_PTR<float>(),
         std::get<1>(update_norm_tuple).DATA_PTR<float>(),
-        lr); )
+        lr,
+        weight_decay,
+        use_nvlamb); )
 
   AT_CUDA_CHECK(cudaGetLastError());
csrc/multi_tensor_lamb_stage_2.cu

@@ -24,7 +24,8 @@ struct LAMBStage2Functor
     TensorListMetadata<2>& tl,
     const float* per_tensor_param_norm,
     const float* per_tensor_update_norm,
-    const float learning_rate)
+    const float learning_rate,
+    bool use_nvlamb)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
@@ -35,9 +36,15 @@ struct LAMBStage2Functor
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
-    float param_norm = per_tensor_param_norm[tensor_num];
-    float update_norm = per_tensor_update_norm[tensor_num];
-    T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
+    MATH_T ratio = learning_rate;
+    // nvlamb: apply adaptive learning rate to all parameters
+    // otherwise, only apply to those with non-zero weight decay
+    if (use_nvlamb || (decay != 0.0))
+    {
+      float param_norm = per_tensor_param_norm[tensor_num];
+      float update_norm = per_tensor_update_norm[tensor_num];
+      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
+    }
 
     T* p = (T*)tl.addresses[0][tensor_loc];
     p += chunk_idx*chunk_size;
@@ -87,8 +94,11 @@ void multi_tensor_lamb_stage2_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float learning_rate)
+  const float learning_rate,
+  at::optional<bool> use_nvlamb_python)
 {
+  bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
+
   using namespace at;
 
   DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
@@ -101,7 +111,8 @@ void multi_tensor_lamb_stage2_cuda(
         LAMBStage2Functor<scalar_t_0, scalar_t_1>(),
         per_tensor_param_norm.DATA_PTR<float>(),
         per_tensor_update_norm.DATA_PTR<float>(),
-        learning_rate); ))
+        learning_rate,
+        use_nvlamb); ))
 
   AT_CUDA_CHECK(cudaGetLastError());