OpenDAS / apex · Commit 8abb6908 (unverified)
Authored May 19, 2020 by Kexin Yu; committed by GitHub on May 19, 2020

Merge pull request #819 from kexinyu/master

Use global gradient clipping in FusedLAMB & add option for using NVLAMB

Parents: 3bae8c83, bd6e66df
4 changed files with 94 additions and 24 deletions (+94 -24):

  apex/optimizers/fused_lamb.py       +42  -3
  csrc/amp_C_frontend.cpp              +6  -2
  csrc/multi_tensor_lamb.cu           +24 -13
  csrc/multi_tensor_lamb_stage_2.cu   +22  -6
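For context, a minimal usage sketch (not part of this commit) of how the updated optimizer could be constructed once these changes are in place. The model and batch below are hypothetical placeholders; only the FusedLAMB arguments come from the diff itself, with max_grad_norm now interpreted as a global gradient-norm clip value.

    # Hypothetical usage sketch; requires apex built with its CUDA extensions.
    import torch
    from apex.optimizers import FusedLAMB

    model = torch.nn.Linear(1024, 1024).cuda()        # placeholder model
    optimizer = FusedLAMB(model.parameters(),
                          lr=1e-3,
                          weight_decay=0.01,
                          max_grad_norm=1.0,    # clip by the global grad norm
                          use_nvlamb=False)     # True enables the NVLAMB variant

    inp = torch.randn(8, 1024, device='cuda')         # placeholder batch
    loss = model(inp).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()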
apex/optimizers/fused_lamb.py  (view file @ 8abb6908)

@@ -51,6 +51,8 @@ class FusedLAMB(torch.optim.Optimizer):
             method is called. (default: True)
         max_grad_norm (float, optional): value used to clip global grad norm
             (default: 1.0)
+        use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
+            weight decay parameter (default: False)
 
     .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
         https://arxiv.org/abs/1904.00962
@@ -62,7 +64,7 @@ class FusedLAMB(torch.optim.Optimizer):
                  betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
                  amsgrad=False, adam_w_mode=True,
                  grad_averaging=True, set_grad_none=True,
-                 max_grad_norm=1.0):
+                 max_grad_norm=1.0, use_nvlamb=False):
         if amsgrad:
             raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
         defaults = dict(lr=lr, bias_correction=bias_correction,
@@ -72,6 +74,7 @@ class FusedLAMB(torch.optim.Optimizer):
         super(FusedLAMB, self).__init__(params, defaults)
         if multi_tensor_applier.available:
             import amp_C
+            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
             # Skip buffer
             self._dummy_overflow_buf = torch.cuda.IntTensor([0])
             self.multi_tensor_lamb = amp_C.multi_tensor_lamb
@@ -80,6 +83,7 @@ class FusedLAMB(torch.optim.Optimizer):
         self.adam_w_mode = 1 if adam_w_mode else 0
         self.set_grad_none = set_grad_none
+        self.use_nvlamb = use_nvlamb
 
     def zero_grad(self):
         if self.set_grad_none:
@@ -100,6 +104,37 @@ class FusedLAMB(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()
 
+        # create separate grad lists for fp32 and fp16 params
+        g_all_32, g_all_16 = [], []
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                if p.dtype == torch.float32:
+                    g_all_32.append(p.grad.data)
+                elif p.dtype == torch.float16:
+                    g_all_16.append(p.grad.data)
+                else:
+                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')
+
+        g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(1, device='cuda')
+        # compute grad norm for two lists
+        if len(g_all_32) > 0:
+            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
+                                             self._dummy_overflow_buf,
+                                             [g_all_32], False)[0]
+        if len(g_all_16) > 0:
+            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
+                                             self._dummy_overflow_buf,
+                                             [g_all_16], False)[0]
+
+        # blend two grad norms to get global grad norm
+        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
+                                                self._dummy_overflow_buf,
+                                                [[g_norm_32, g_norm_16]],
+                                                False)[0].item()
+        max_grad_norm = self.defaults['max_grad_norm']
+
         for group in self.param_groups:
             bias_correction = 1 if group['bias_correction'] else 0
             beta1, beta2 = group['betas']
@@ -156,7 +191,9 @@ class FusedLAMB(torch.optim.Optimizer):
                          group['weight_decay'],
                          grad_averaging,
                          self.adam_w_mode,
-                         group['max_grad_norm'])
+                         global_grad_norm,
+                         max_grad_norm,
+                         self.use_nvlamb)
             if(len(g_32) > 0):
                 multi_tensor_applier(self.multi_tensor_lamb,
                                      self._dummy_overflow_buf,
@@ -170,6 +207,8 @@ class FusedLAMB(torch.optim.Optimizer):
                          group['weight_decay'],
                          grad_averaging,
                          self.adam_w_mode,
-                         group['max_grad_norm'])
+                         global_grad_norm,
+                         max_grad_norm,
+                         self.use_nvlamb)
 
         return loss
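Informally, the new step() code above computes one L2 norm over all fp32 gradients, one over all fp16 gradients, and then takes the L2 norm of those two partial norms, which equals the L2 norm over every gradient element combined. Below is a plain-PyTorch sketch of that same blending, for illustration only; the commit itself uses the fused amp_C.multi_tensor_l2norm kernel, and the helper name here is hypothetical.

    # Illustrative only: equivalent of the blended global grad norm in step().
    import math
    import torch

    def global_grad_norm(params):
        # per-dtype partial norms, mirroring the fused implementation
        g32 = [p.grad for p in params if p.grad is not None and p.dtype == torch.float32]
        g16 = [p.grad for p in params if p.grad is not None and p.dtype == torch.float16]
        norm32 = torch.norm(torch.stack([g.norm() for g in g32])).item() if g32 else 0.0
        norm16 = torch.norm(torch.stack([g.float().norm() for g in g16])).item() if g16 else 0.0
        # blending the partial L2 norms with another L2 norm recovers the
        # overall norm: sqrt(norm32**2 + norm16**2)
        return math.hypot(norm32, norm16)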
csrc/amp_C_frontend.cpp  (view file @ 8abb6908)

@@ -51,7 +51,9 @@ void multi_tensor_lamb_stage2_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float step_size);
+  const float lr,
+  const float weight_decay,
+  at::optional<bool> use_nvlamb_python);
 
 void multi_tensor_adam_cuda(
   int chunk_size,
@@ -106,7 +108,9 @@ void multi_tensor_lamb_cuda(
   const float weight_decay,
   const int grad_averaging,
   const int mode,
-  const float max_grad_norm);
+  const float global_grad_norm,
+  const float max_grad_norm,
+  at::optional<bool> use_nvlamb_python);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
csrc/multi_tensor_lamb.cu  (view file @ 8abb6908)

@@ -52,8 +52,8 @@ struct LAMBStage1Functor
     const float epsilon,
     adamMode_t mode,
     const float decay,
-    float* global_grad_norm,
-    float max_global_grad_norm)
+    const float global_grad_norm,
+    const float max_global_grad_norm)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
@@ -63,7 +63,7 @@ struct LAMBStage1Functor
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
-    float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;
+    float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
 
     T* g = (T*)tl.addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;
@@ -239,7 +239,9 @@ struct LAMBStage2Functor
     TensorListMetadata<2>& tl,
     const float* per_tensor_param_norm,
     const float* per_tensor_update_norm,
-    const float learning_rate)
+    const float learning_rate,
+    const float decay,
+    bool use_nvlamb)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
@@ -250,9 +252,15 @@ struct LAMBStage2Functor
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
-    float param_norm = per_tensor_param_norm[tensor_num];
-    float update_norm = per_tensor_update_norm[tensor_num];
-    MATH_T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
+    MATH_T ratio = learning_rate;
+    // nvlamb: apply adaptive learning rate to all parameters
+    // otherwise, only apply to those with non-zero weight decay
+    if (use_nvlamb || (decay != 0.0))
+    {
+      float param_norm = per_tensor_param_norm[tensor_num];
+      float update_norm = per_tensor_update_norm[tensor_num];
+      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
+    }
 
     T* update = (T*)tl.addresses[0][tensor_loc];
     update += chunk_idx*chunk_size;
@@ -334,12 +342,16 @@ void multi_tensor_lamb_cuda(
   const float weight_decay,
   const int grad_averaging,
   const int mode,
-  const float max_grad_norm)
+  const float global_grad_norm,
+  const float max_grad_norm,
+  at::optional<bool> use_nvlamb_python)
 {
   using namespace at;
   // Master weight and 32bit momentum(potentially changing) is not handled by this
   // So we assume every tensor are all in the same type
 
+  bool use_nvlamb = (use_nvlamb_python.has_value()) ? use_nvlamb_python.value() : false;
+
   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
   if (bias_correction == 1) {
@@ -354,9 +366,6 @@ void multi_tensor_lamb_cuda(
   std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
   std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);
 
-  // Compute global grad norm
-  auto grad_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, false);
-
   // Compute per tensor param norm
   auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
@@ -378,7 +387,7 @@ void multi_tensor_lamb_cuda(
         epsilon,
         (adamMode_t) mode,
         weight_decay,
-        std::get<0>(grad_norm_tuple).DATA_PTR<float>(),
+        global_grad_norm,
         max_grad_norm); )
 
   // Compute update norms
@@ -395,7 +404,9 @@ void multi_tensor_lamb_cuda(
       LAMBStage2Functor<scalar_t_0>(),
       std::get<1>(param_norm_tuple).DATA_PTR<float>(),
       std::get<1>(update_norm_tuple).DATA_PTR<float>(),
-      lr); )
+      lr,
+      weight_decay,
+      use_nvlamb); )
 
   AT_CUDA_CHECK(cudaGetLastError());
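The stage-1 change above clips by rescaling: when the global gradient norm exceeds max_global_grad_norm, every gradient is divided by global_grad_norm / max_global_grad_norm; otherwise the divisor is 1.0 and gradients pass through unchanged. A small Python sketch of that scaling rule, illustrative only (the function name is hypothetical):

    # Illustrative only: the divisor implied by clipped_global_grad_norm
    # in LAMBStage1Functor.
    def clip_scale(global_grad_norm: float, max_global_grad_norm: float) -> float:
        """Return the divisor applied to every gradient element."""
        if global_grad_norm > max_global_grad_norm:
            return global_grad_norm / max_global_grad_norm
        return 1.0

    # e.g. a global norm of 4.0 with max_grad_norm=1.0 divides all grads by 4.0,
    # so the clipped gradient has global norm exactly max_global_grad_norm.
    assert clip_scale(4.0, 1.0) == 4.0
    assert clip_scale(0.5, 1.0) == 1.0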
csrc/multi_tensor_lamb_stage_2.cu  (view file @ 8abb6908)

@@ -13,6 +13,8 @@
 #define BLOCK_SIZE 512
 #define ILP 4
 
+using MATH_T = float;
+
 // Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
 // It computes new parameter value.
 template<typename T, typename UPD_T>
@@ -24,7 +26,9 @@ struct LAMBStage2Functor
     TensorListMetadata<2>& tl,
     const float* per_tensor_param_norm,
     const float* per_tensor_update_norm,
-    const float learning_rate)
+    const float learning_rate,
+    const float decay,
+    bool use_nvlamb)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
@@ -35,9 +39,15 @@ struct LAMBStage2Functor
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
-    float param_norm = per_tensor_param_norm[tensor_num];
-    float update_norm = per_tensor_update_norm[tensor_num];
-    T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
+    MATH_T ratio = learning_rate;
+    // nvlamb: apply adaptive learning rate to all parameters
+    // otherwise, only apply to those with non-zero weight decay
+    if (use_nvlamb || (decay != 0.0))
+    {
+      float param_norm = per_tensor_param_norm[tensor_num];
+      float update_norm = per_tensor_update_norm[tensor_num];
+      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
+    }
 
     T* p = (T*)tl.addresses[0][tensor_loc];
     p += chunk_idx*chunk_size;
@@ -87,8 +97,12 @@ void multi_tensor_lamb_stage2_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float learning_rate)
+  const float lr,
+  const float weight_decay,
+  at::optional<bool> use_nvlamb_python)
 {
+  bool use_nvlamb = (use_nvlamb_python.has_value()) ? use_nvlamb_python.value() : false;
+
   using namespace at;
 
   DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
@@ -101,7 +115,9 @@ void multi_tensor_lamb_stage2_cuda(
       LAMBStage2Functor<scalar_t_0, scalar_t_1>(),
       per_tensor_param_norm.DATA_PTR<float>(),
       per_tensor_update_norm.DATA_PTR<float>(),
-      learning_rate); ))
+      lr,
+      weight_decay,
+      use_nvlamb); ))
 
   AT_CUDA_CHECK(cudaGetLastError());
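Both stage-2 functors now gate the LAMB trust ratio the same way: with use_nvlamb the adaptive step size lr * (param_norm / update_norm) is applied to every parameter; without it, parameters whose weight decay is 0.0 keep the plain learning rate. A reference sketch of that rule in Python, illustrative only (the helper name is hypothetical):

    # Illustrative only: the per-tensor step-size rule in LAMBStage2Functor.
    def lamb_step_size(learning_rate: float, param_norm: float, update_norm: float,
                       decay: float, use_nvlamb: bool) -> float:
        ratio = learning_rate
        # nvlamb: apply the adaptive learning rate to all parameters;
        # otherwise only to parameters with non-zero weight decay.
        if use_nvlamb or decay != 0.0:
            if update_norm != 0.0 and param_norm != 0.0:
                ratio = learning_rate * (param_norm / update_norm)
        return ratio

    # With use_nvlamb=False and decay=0.0 the plain lr is used:
    assert lamb_step_size(1e-3, 2.0, 4.0, decay=0.0, use_nvlamb=False) == 1e-3
    # With use_nvlamb=True the trust ratio applies everywhere:
    assert lamb_step_size(1e-3, 2.0, 4.0, decay=0.0, use_nvlamb=True) == 1e-3 * 0.5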