OpenDAS / apex · Commits

Commit 8abb6908 (unverified), authored May 19, 2020 by Kexin Yu, committed by GitHub on May 19, 2020

Merge pull request #819 from kexinyu/master

Use global gradient clipping in FusedLAMB & add option for using NVLAMB

Parents: 3bae8c83, bd6e66df
Showing 4 changed files with 94 additions and 24 deletions (+94 -24):

  apex/optimizers/fused_lamb.py       +42  -3
  csrc/amp_C_frontend.cpp             +6   -2
  csrc/multi_tensor_lamb.cu           +24  -13
  csrc/multi_tensor_lamb_stage_2.cu   +22  -6
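For users, the net effect of the change is that `max_grad_norm` now clips by the global gradient norm and a new `use_nvlamb` constructor flag is exposed. A minimal usage sketch (assuming apex is installed with its CUDA extensions; the model and tensor shapes are illustrative):

```python
import torch
from apex.optimizers import FusedLAMB

model = torch.nn.Linear(1024, 1024).cuda()

# max_grad_norm clips by the global gradient norm across all parameters;
# use_nvlamb=True applies the adaptive (trust-ratio) learning rate even to
# parameters whose weight decay is 0.0.
optimizer = FusedLAMB(model.parameters(),
                      lr=1e-3,
                      weight_decay=0.01,
                      max_grad_norm=1.0,
                      use_nvlamb=True)

loss = model(torch.randn(8, 1024, device='cuda')).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```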
apex/optimizers/fused_lamb.py
@@ -51,6 +51,8 @@ class FusedLAMB(torch.optim.Optimizer):
             method is called. (default: True)
         max_grad_norm (float, optional): value used to clip global grad norm
             (default: 1.0)
+        use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
+            weight decay parameter (default: False)
 
     .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
         https://arxiv.org/abs/1904.00962
@@ -62,7 +64,7 @@ class FusedLAMB(torch.optim.Optimizer):
                  betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
                  amsgrad=False, adam_w_mode=True,
                  grad_averaging=True, set_grad_none=True,
-                 max_grad_norm=1.0):
+                 max_grad_norm=1.0, use_nvlamb=False):
         if amsgrad:
             raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
         defaults = dict(lr=lr, bias_correction=bias_correction,
@@ -72,6 +74,7 @@ class FusedLAMB(torch.optim.Optimizer):
         super(FusedLAMB, self).__init__(params, defaults)
         if multi_tensor_applier.available:
             import amp_C
+            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
             # Skip buffer
             self._dummy_overflow_buf = torch.cuda.IntTensor([0])
             self.multi_tensor_lamb = amp_C.multi_tensor_lamb
@@ -80,6 +83,7 @@ class FusedLAMB(torch.optim.Optimizer):
         self.adam_w_mode = 1 if adam_w_mode else 0
         self.set_grad_none = set_grad_none
+        self.use_nvlamb = use_nvlamb
 
     def zero_grad(self):
         if self.set_grad_none:
@@ -100,6 +104,37 @@ class FusedLAMB(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()
 
+        # create separate grad lists for fp32 and fp16 params
+        g_all_32, g_all_16 = [], []
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                if p.dtype == torch.float32:
+                    g_all_32.append(p.grad.data)
+                elif p.dtype == torch.float16:
+                    g_all_16.append(p.grad.data)
+                else:
+                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')
+
+        g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(1, device='cuda')
+        # compute grad norm for two lists
+        if len(g_all_32) > 0:
+            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
+                                             self._dummy_overflow_buf,
+                                             [g_all_32], False)[0]
+        if len(g_all_16) > 0:
+            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
+                                             self._dummy_overflow_buf,
+                                             [g_all_16], False)[0]
+
+        # blend two grad norms to get global grad norm
+        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
+                                                self._dummy_overflow_buf,
+                                                [[g_norm_32, g_norm_16]],
+                                                False)[0].item()
+        max_grad_norm = self.defaults['max_grad_norm']
+
         for group in self.param_groups:
             bias_correction = 1 if group['bias_correction'] else 0
             beta1, beta2 = group['betas']
@@ -156,7 +191,9 @@ class FusedLAMB(torch.optim.Optimizer):
                                      group['weight_decay'],
                                      grad_averaging,
                                      self.adam_w_mode,
-                                     group['max_grad_norm'])
+                                     global_grad_norm,
+                                     max_grad_norm,
+                                     self.use_nvlamb)
             if(len(g_32) > 0):
                 multi_tensor_applier(self.multi_tensor_lamb,
                                      self._dummy_overflow_buf,
@@ -170,6 +207,8 @@ class FusedLAMB(torch.optim.Optimizer):
                                      group['weight_decay'],
                                      grad_averaging,
                                      self.adam_w_mode,
-                                     group['max_grad_norm'])
+                                     global_grad_norm,
+                                     max_grad_norm,
+                                     self.use_nvlamb)
 
         return loss
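For reference, the "blend two grad norms" step above is just a composition of L2 norms: the L2 norm of the two partial norms equals the L2 norm over the concatenation of all fp32 and fp16 gradients. A plain-PyTorch sketch of the same arithmetic (the helper name is illustrative; the real code uses the fused multi_tensor_l2norm kernel):

```python
import math
import torch

def blended_global_grad_norm(g_all_32, g_all_16):
    """Unfused equivalent of the norm blending done in FusedLAMB.step()."""
    # Per-dtype L2 norms over all gradients of that dtype.
    norm_32 = torch.norm(torch.cat([g.reshape(-1) for g in g_all_32])) if g_all_32 else torch.zeros(())
    norm_16 = torch.norm(torch.cat([g.reshape(-1).float() for g in g_all_16])) if g_all_16 else torch.zeros(())
    # Blending: sqrt(||g32||^2 + ||g16||^2) == L2 norm of all gradients together.
    return math.hypot(norm_32.item(), norm_16.item())
```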
csrc/amp_C_frontend.cpp
@@ -51,7 +51,9 @@ void multi_tensor_lamb_stage2_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float step_size);
+  const float lr,
+  const float weight_decay,
+  at::optional<bool> use_nvlamb_python);
 
 void multi_tensor_adam_cuda(
   int chunk_size,
@@ -106,7 +108,9 @@ void multi_tensor_lamb_cuda(
   const float weight_decay,
   const int grad_averaging,
   const int mode,
-  const float max_grad_norm);
+  const float global_grad_norm,
+  const float max_grad_norm,
+  at::optional<bool> use_nvlamb_python);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
csrc/multi_tensor_lamb.cu
@@ -52,8 +52,8 @@ struct LAMBStage1Functor
     const float epsilon,
     adamMode_t mode,
     const float decay,
-    float* global_grad_norm,
-    float max_global_grad_norm)
+    const float global_grad_norm,
+    const float max_global_grad_norm)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
@@ -63,7 +63,7 @@ struct LAMBStage1Functor
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
-    float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;
+    float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
 
     T* g = (T*)tl.addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;
@@ -239,7 +239,9 @@ struct LAMBStage2Functor
     TensorListMetadata<2>& tl,
     const float* per_tensor_param_norm,
     const float* per_tensor_update_norm,
-    const float learning_rate)
+    const float learning_rate,
+    const float decay,
+    bool use_nvlamb)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
@@ -250,9 +252,15 @@ struct LAMBStage2Functor
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
-    float param_norm = per_tensor_param_norm[tensor_num];
-    float update_norm = per_tensor_update_norm[tensor_num];
-    MATH_T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
+    MATH_T ratio = learning_rate;
+    // nvlamb: apply adaptive learning rate to all parameters
+    // otherwise, only apply to those with non-zero weight decay
+    if (use_nvlamb || (decay != 0.0)) {
+      float param_norm = per_tensor_param_norm[tensor_num];
+      float update_norm = per_tensor_update_norm[tensor_num];
+      ratio = (update_norm != 0.0f && param_norm != 0.0f)
+          ? learning_rate * (param_norm / update_norm) : learning_rate;
+    }
 
     T* update = (T*)tl.addresses[0][tensor_loc];
     update += chunk_idx*chunk_size;
@@ -334,12 +342,16 @@ void multi_tensor_lamb_cuda(
   const float weight_decay,
   const int grad_averaging,
   const int mode,
-  const float max_grad_norm)
+  const float global_grad_norm,
+  const float max_grad_norm,
+  at::optional<bool> use_nvlamb_python)
 {
   using namespace at;
   // Master weight and 32bit momentum(potentially changing) is not handled by this
   // So we assume every tensor are all in the same type
 
+  bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
+
   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
   if (bias_correction == 1) {
@@ -354,9 +366,6 @@ void multi_tensor_lamb_cuda(
   std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
   std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);
 
-  // Compute global grad norm
-  auto grad_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, false);
-
   // Compute per tensor param norm
   auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
@@ -378,7 +387,7 @@ void multi_tensor_lamb_cuda(
         epsilon,
         (adamMode_t) mode,
         weight_decay,
-        std::get<0>(grad_norm_tuple).DATA_PTR<float>(),
+        global_grad_norm,
         max_grad_norm); )
 
   // Compute update norms
@@ -395,7 +404,9 @@ void multi_tensor_lamb_cuda(
         LAMBStage2Functor<scalar_t_0>(),
         std::get<1>(param_norm_tuple).DATA_PTR<float>(),
         std::get<1>(update_norm_tuple).DATA_PTR<float>(),
-        lr); )
+        lr,
+        weight_decay,
+        use_nvlamb); )
 
   AT_CUDA_CHECK(cudaGetLastError());
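The clipping introduced in LAMBStage1Functor reduces to a single scalar factor computed from the precomputed global norm. A scalar Python restatement of that ternary (a sketch only; in the kernel the factor is used to rescale gradients so their effective global norm stays within max_grad_norm):

```python
def clip_factor(global_grad_norm: float, max_grad_norm: float) -> float:
    # Above the threshold the factor is norm / max; below it, 1.0 (no rescaling).
    return global_grad_norm / max_grad_norm if global_grad_norm > max_grad_norm else 1.0

# Example: a global norm of 4.0 against max_grad_norm=1.0 gives a factor of 4.0,
# and dividing every gradient by 4.0 brings the global norm back down to 1.0.
assert clip_factor(4.0, 1.0) == 4.0
assert clip_factor(0.5, 1.0) == 1.0
```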
csrc/multi_tensor_lamb_stage_2.cu
@@ -13,6 +13,8 @@
 #define BLOCK_SIZE 512
 #define ILP 4
 
+using MATH_T = float;
+
 // Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
 // It computes new parameter value.
 template<typename T, typename UPD_T>
@@ -24,7 +26,9 @@ struct LAMBStage2Functor
     TensorListMetadata<2>& tl,
     const float* per_tensor_param_norm,
     const float* per_tensor_update_norm,
-    const float learning_rate)
+    const float learning_rate,
+    const float decay,
+    bool use_nvlamb)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
@@ -35,9 +39,15 @@ struct LAMBStage2Functor
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
-    float param_norm = per_tensor_param_norm[tensor_num];
-    float update_norm = per_tensor_update_norm[tensor_num];
-    T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
+    MATH_T ratio = learning_rate;
+    // nvlamb: apply adaptive learning rate to all parameters
+    // otherwise, only apply to those with non-zero weight decay
+    if (use_nvlamb || (decay != 0.0)) {
+      float param_norm = per_tensor_param_norm[tensor_num];
+      float update_norm = per_tensor_update_norm[tensor_num];
+      ratio = (update_norm != 0.0f && param_norm != 0.0f)
+          ? learning_rate * (param_norm / update_norm) : learning_rate;
+    }
 
     T* p = (T*)tl.addresses[0][tensor_loc];
     p += chunk_idx*chunk_size;
@@ -87,8 +97,12 @@ void multi_tensor_lamb_stage2_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float learning_rate)
+  const float lr,
+  const float weight_decay,
+  at::optional<bool> use_nvlamb_python)
 {
+  bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
+
   using namespace at;
 
   DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
@@ -101,7 +115,9 @@ void multi_tensor_lamb_stage2_cuda(
       LAMBStage2Functor<scalar_t_0, scalar_t_1>(),
       per_tensor_param_norm.DATA_PTR<float>(),
       per_tensor_update_norm.DATA_PTR<float>(),
-      learning_rate); ))
+      lr,
+      weight_decay,
+      use_nvlamb); ))
 
   AT_CUDA_CHECK(cudaGetLastError());
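The ratio branch is identical in both LAMBStage2Functor variants above. Restated as scalar Python (names mirror the kernel arguments; this is a sketch, not the fused implementation):

```python
def lamb_update_ratio(learning_rate: float, param_norm: float, update_norm: float,
                      decay: float, use_nvlamb: bool) -> float:
    # Default behaviour: the adaptive (trust) ratio is applied only to
    # parameters with non-zero weight decay; use_nvlamb applies it everywhere.
    ratio = learning_rate
    if use_nvlamb or decay != 0.0:
        if update_norm != 0.0 and param_norm != 0.0:
            ratio = learning_rate * (param_norm / update_norm)
    return ratio

# With zero weight decay the parameter gets the plain learning rate
# unless use_nvlamb=True:
assert lamb_update_ratio(1e-3, 2.0, 4.0, decay=0.0, use_nvlamb=False) == 1e-3
assert lamb_update_ratio(1e-3, 2.0, 4.0, decay=0.0, use_nvlamb=True) == 1e-3 * 0.5
```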