Commit 5b300119 authored by Kexin Yu

LAMB: global grad clipping & more flexibility in adaptive lr

parent 1f2aa915
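For context, a minimal PyTorch sketch of what the new `step()` logic computes before invoking the fused kernel: the per-dtype gradient norms are blended into one global norm, and the stage-1 kernel later scales gradients only when that norm exceeds `max_grad_norm`. The helper name and argument names below are illustrative, not part of this commit.

```python
import math
import torch

def clipped_global_grad_norm(grads_fp32, grads_fp16, max_grad_norm=1.0):
    """Illustrative stand-in for the g_all_32 / g_all_16 handling added in step()."""
    # L2 norm of each dtype bucket (the commit uses amp_C.multi_tensor_l2norm for this).
    g_norm_32 = torch.norm(torch.stack([g.norm() for g in grads_fp32])).item() if grads_fp32 else 0.0
    g_norm_16 = torch.norm(torch.stack([g.float().norm() for g in grads_fp16])).item() if grads_fp16 else 0.0
    # Blend the two norms into a single global gradient norm.
    global_grad_norm = math.sqrt(g_norm_32 ** 2 + g_norm_16 ** 2)
    # The stage-1 kernel divides gradients by this factor only when the norm exceeds the cap.
    clip = global_grad_norm / max_grad_norm if global_grad_norm > max_grad_norm else 1.0
    return global_grad_norm, clip
```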
@@ -51,6 +51,8 @@ class FusedLAMB(torch.optim.Optimizer):
method is called. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm
(default: 1.0)
use_nvlamb (boolean, optional): Apply the adaptive learning rate to parameters
with 0.0 weight decay (default: False)
.. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
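A hedged usage sketch of the extended constructor, assuming Apex is installed with its CUDA extensions; the model is a placeholder:

```python
import torch
from apex.optimizers import FusedLAMB

model = torch.nn.Linear(1024, 1024).cuda()  # placeholder model
optimizer = FusedLAMB(model.parameters(),
                      lr=1e-3,
                      betas=(0.9, 0.999),
                      weight_decay=0.01,
                      max_grad_norm=1.0,    # threshold for the new global grad clipping
                      use_nvlamb=False)     # True: adaptive lr also for 0.0-weight-decay params
```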
@@ -62,7 +64,7 @@ class FusedLAMB(torch.optim.Optimizer):
betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
amsgrad=False, adam_w_mode=True,
grad_averaging=True, set_grad_none=True,
max_grad_norm=1.0):
max_grad_norm=1.0, use_nvlamb=False):
if amsgrad:
raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
@@ -72,6 +74,7 @@ class FusedLAMB(torch.optim.Optimizer):
super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available:
import amp_C
self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_lamb = amp_C.multi_tensor_lamb
@@ -100,6 +103,34 @@ class FusedLAMB(torch.optim.Optimizer):
if closure is not None:
loss = closure()
# create separate grad lists for fp32 and fp16 params
g_all_32, g_all_16 = [], []
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
if p.dtype == torch.float32:
g_all_32.append(p.grad.data)
elif p.dtype == torch.float16:
g_all_16.append(p.grad.data)
else:
raise RuntimeError('FusedLAMB only supports fp16 and fp32.')
g_norm_32, g_norm_16 = 0.0, 0.0
# compute grad norm for two lists
if len(g_all_32) > 0:
g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_32], False)[0].item()
if len(g_all_16) > 0:
g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_16], False)[0].item()
# blend two grad norms to get global grad norm
global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)
max_grad_norm = self.defaults['max_grad_norm']
for group in self.param_groups:
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']
@@ -156,7 +187,9 @@ class FusedLAMB(torch.optim.Optimizer):
group['weight_decay'],
grad_averaging,
self.adam_w_mode,
group['max_grad_norm'])
global_grad_norm,
max_grad_norm,
use_nvlamb)
if(len(g_32) > 0):
multi_tensor_applier(self.multi_tensor_lamb,
self._dummy_overflow_buf,
@@ -170,6 +203,8 @@ class FusedLAMB(torch.optim.Optimizer):
group['weight_decay'],
grad_averaging,
self.adam_w_mode,
group['max_grad_norm'])
global_grad_norm,
max_grad_norm,
use_nvlamb)
return loss
@@ -51,7 +51,8 @@ void multi_tensor_lamb_stage2_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float step_size);
const float step_size,
at::optional<bool> use_nvlamb_python);
void multi_tensor_adam_cuda(
int chunk_size,
@@ -95,7 +96,9 @@ void multi_tensor_lamb_cuda(
const float weight_decay,
const int grad_averaging,
const int mode,
const float max_grad_norm);
const float global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
@@ -41,8 +41,8 @@ struct LAMBStage1Functor
const float epsilon,
adamMode_t mode,
const float decay,
float* global_grad_norm,
float max_global_grad_norm)
const float global_grad_norm,
const float max_global_grad_norm)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
@@ -52,7 +52,7 @@ struct LAMBStage1Functor
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;
float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx*chunk_size;
@@ -150,7 +150,9 @@ struct LAMBStage2Functor
TensorListMetadata<2>& tl,
const float* per_tensor_param_norm,
const float* per_tensor_update_norm,
const float learning_rate)
const float learning_rate,
const float decay,
bool use_nvlamb)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
@@ -161,9 +163,15 @@ struct LAMBStage2Functor
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
MATH_T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
MATH_T ratio = learning_rate;
// nvlamb: apply adaptive learning rate to all parameters
// otherwise, only apply to those with non-zero weight decay
if (use_nvlamb || (decay != 0.0))
{
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
}
T* update = (T*)tl.addresses[0][tensor_loc];
update += chunk_idx*chunk_size;
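The ratio selection in LAMBStage2Functor above reduces to the following per-tensor rule; this Python sketch is illustrative only, with names mirroring the kernel:

```python
def lamb_ratio(learning_rate, param_norm, update_norm, decay, use_nvlamb):
    # NVLAMB applies the trust ratio to every parameter; otherwise it is
    # skipped for parameters whose weight decay is 0.0 (ratio stays at lr).
    if use_nvlamb or decay != 0.0:
        if update_norm != 0.0 and param_norm != 0.0:
            return learning_rate * (param_norm / update_norm)
    return learning_rate
```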
@@ -221,12 +229,16 @@ void multi_tensor_lamb_cuda(
const float weight_decay,
const int grad_averaging,
const int mode,
const float max_grad_norm)
const float global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python)
{
using namespace at;
// Master weights and 32-bit momentum (potentially changing) are not handled here,
// so we assume all tensors are of the same type.
bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
@@ -241,9 +253,6 @@ void multi_tensor_lamb_cuda(
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);
// Compute global grad norm
auto grad_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, false);
// Compute per tensor param norm
auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
@@ -265,7 +274,7 @@ void multi_tensor_lamb_cuda(
epsilon,
(adamMode_t) mode,
weight_decay,
std::get<0>(grad_norm_tuple).DATA_PTR<float>(),
global_grad_norm,
max_grad_norm); )
// Compute update norms
@@ -282,7 +291,9 @@ void multi_tensor_lamb_cuda(
LAMBStage2Functor<scalar_t_0>(),
std::get<1>(param_norm_tuple).DATA_PTR<float>(),
std::get<1>(update_norm_tuple).DATA_PTR<float>(),
lr); )
lr,
weight_decay,
use_nvlamb); )
AT_CUDA_CHECK(cudaGetLastError());
@@ -24,7 +24,8 @@ struct LAMBStage2Functor
TensorListMetadata<2>& tl,
const float* per_tensor_param_norm,
const float* per_tensor_update_norm,
const float learning_rate)
const float learning_rate,
bool use_nvlamb)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
@@ -35,9 +36,15 @@ struct LAMBStage2Functor
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
MATH_T ratio = learning_rate;
// nvlamb: apply adaptive learning rate to all parameters
// otherwise, only apply to those with non-zero weight decay
if (use_nvlamb || (decay != 0.0))
{
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
}
T* p = (T*)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
@@ -87,8 +94,11 @@ void multi_tensor_lamb_stage2_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float learning_rate)
const float learning_rate,
at::optional<bool> use_nvlamb_python)
{
bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
using namespace at;
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
@@ -101,7 +111,8 @@ void multi_tensor_lamb_stage2_cuda(
LAMBStage2Functor<scalar_t_0, scalar_t_1>(),
per_tensor_param_norm.DATA_PTR<float>(),
per_tensor_update_norm.DATA_PTR<float>(),
learning_rate); ))
learning_rate,
use_nvlamb); ))
AT_CUDA_CHECK(cudaGetLastError());