Commit 5b300119 authored by Kexin Yu's avatar Kexin Yu
Browse files

LAMB: global grad clipping & more flexibility in adaptive lr

parent 1f2aa915
...@@ -51,6 +51,8 @@ class FusedLAMB(torch.optim.Optimizer): ...@@ -51,6 +51,8 @@ class FusedLAMB(torch.optim.Optimizer):
method is called. (default: True) method is called. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm max_grad_norm (float, optional): value used to clip global grad norm
(default: 1.0) (default: 1.0)
use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
weight decay parameter (default: False)
.. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes: .. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962 https://arxiv.org/abs/1904.00962
...@@ -62,7 +64,7 @@ class FusedLAMB(torch.optim.Optimizer): ...@@ -62,7 +64,7 @@ class FusedLAMB(torch.optim.Optimizer):
betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
amsgrad=False, adam_w_mode=True, amsgrad=False, adam_w_mode=True,
grad_averaging=True, set_grad_none=True, grad_averaging=True, set_grad_none=True,
max_grad_norm=1.0): max_grad_norm=1.0, use_nvlamb=False):
if amsgrad: if amsgrad:
raise RuntimeError('FusedLAMB does not support the AMSGrad variant.') raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction, defaults = dict(lr=lr, bias_correction=bias_correction,
...@@ -72,6 +74,7 @@ class FusedLAMB(torch.optim.Optimizer): ...@@ -72,6 +74,7 @@ class FusedLAMB(torch.optim.Optimizer):
super(FusedLAMB, self).__init__(params, defaults) super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available: if multi_tensor_applier.available:
import amp_C import amp_C
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
# Skip buffer # Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0]) self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_lamb = amp_C.multi_tensor_lamb self.multi_tensor_lamb = amp_C.multi_tensor_lamb
...@@ -100,6 +103,34 @@ class FusedLAMB(torch.optim.Optimizer): ...@@ -100,6 +103,34 @@ class FusedLAMB(torch.optim.Optimizer):
if closure is not None: if closure is not None:
loss = closure() loss = closure()
# create separate grad lists for fp32 and fp16 params
g_all_32, g_all_16 = [], []
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
if p.dtype == torch.float32:
g_all_32.append(p.grad.data)
elif p.dytpe == torch.float16:
g_all_16.append(p.grad.data)
else:
raise RuntimeError('FusedLAMB only support fp16 and fp32.')
g_norm_32, g_norm_16 = 0.0, 0.0
# compute grad norm for two lists
if len(g_all_32) > 0:
g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_32], False)[0].item()
if len(g_all_16) > 0:
g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_16], False)[0].item()
# blend two grad norms to get global grad norm
global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)
max_grad_norm = self.defaults['max_grad_norm']
for group in self.param_groups: for group in self.param_groups:
bias_correction = 1 if group['bias_correction'] else 0 bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas'] beta1, beta2 = group['betas']
...@@ -156,7 +187,9 @@ class FusedLAMB(torch.optim.Optimizer): ...@@ -156,7 +187,9 @@ class FusedLAMB(torch.optim.Optimizer):
group['weight_decay'], group['weight_decay'],
grad_averaging, grad_averaging,
self.adam_w_mode, self.adam_w_mode,
group['max_grad_norm']) global_grad_norm,
max_grad_norm,
use_nvlamb)
if(len(g_32) > 0): if(len(g_32) > 0):
multi_tensor_applier(self.multi_tensor_lamb, multi_tensor_applier(self.multi_tensor_lamb,
self._dummy_overflow_buf, self._dummy_overflow_buf,
...@@ -170,6 +203,8 @@ class FusedLAMB(torch.optim.Optimizer): ...@@ -170,6 +203,8 @@ class FusedLAMB(torch.optim.Optimizer):
group['weight_decay'], group['weight_decay'],
grad_averaging, grad_averaging,
self.adam_w_mode, self.adam_w_mode,
group['max_grad_norm']) global_grad_norm,
max_grad_norm,
use_nvlamb)
return loss return loss
...@@ -51,7 +51,8 @@ void multi_tensor_lamb_stage2_cuda( ...@@ -51,7 +51,8 @@ void multi_tensor_lamb_stage2_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm, at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm, at::Tensor per_tensor_update_norm,
const float step_size); const float step_size,
at::optional<bool> use_nvlamb_python);
void multi_tensor_adam_cuda( void multi_tensor_adam_cuda(
int chunk_size, int chunk_size,
...@@ -95,7 +96,9 @@ void multi_tensor_lamb_cuda( ...@@ -95,7 +96,9 @@ void multi_tensor_lamb_cuda(
const float weight_decay, const float weight_decay,
const int grad_averaging, const int grad_averaging,
const int mode, const int mode,
const float max_grad_norm); const float global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_scale", &multi_tensor_scale_cuda, m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
......
...@@ -41,8 +41,8 @@ struct LAMBStage1Functor ...@@ -41,8 +41,8 @@ struct LAMBStage1Functor
const float epsilon, const float epsilon,
adamMode_t mode, adamMode_t mode,
const float decay, const float decay,
float* global_grad_norm, const float global_grad_norm,
float max_global_grad_norm) const float max_global_grad_norm)
{ {
// I'd like this kernel to propagate infs/nans. // I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1) // if(*noop_gmem == 1)
...@@ -52,7 +52,7 @@ struct LAMBStage1Functor ...@@ -52,7 +52,7 @@ struct LAMBStage1Functor
int chunk_idx = tl.block_to_chunk[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc]; int n = tl.sizes[tensor_loc];
float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f; float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
T* g = (T*)tl.addresses[0][tensor_loc]; T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx*chunk_size; g += chunk_idx*chunk_size;
...@@ -150,7 +150,9 @@ struct LAMBStage2Functor ...@@ -150,7 +150,9 @@ struct LAMBStage2Functor
TensorListMetadata<2>& tl, TensorListMetadata<2>& tl,
const float* per_tensor_param_norm, const float* per_tensor_param_norm,
const float* per_tensor_update_norm, const float* per_tensor_update_norm,
const float learning_rate) const float learning_rate,
const float decay,
bool use_nvlamb)
{ {
// I'd like this kernel to propagate infs/nans. // I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1) // if(*noop_gmem == 1)
...@@ -161,9 +163,15 @@ struct LAMBStage2Functor ...@@ -161,9 +163,15 @@ struct LAMBStage2Functor
int chunk_idx = tl.block_to_chunk[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc]; int n = tl.sizes[tensor_loc];
float param_norm = per_tensor_param_norm[tensor_num]; MATH_T ratio = learning_rate;
float update_norm = per_tensor_update_norm[tensor_num]; // nvlamb: apply adaptive learning rate to all parameters
MATH_T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate; // otherwise, only apply to those with non-zero weight decay
if (use_nvlamb || (decay != 0.0))
{
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
}
T* update = (T*)tl.addresses[0][tensor_loc]; T* update = (T*)tl.addresses[0][tensor_loc];
update += chunk_idx*chunk_size; update += chunk_idx*chunk_size;
...@@ -221,12 +229,16 @@ void multi_tensor_lamb_cuda( ...@@ -221,12 +229,16 @@ void multi_tensor_lamb_cuda(
const float weight_decay, const float weight_decay,
const int grad_averaging, const int grad_averaging,
const int mode, const int mode,
const float max_grad_norm) const float global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python)
{ {
using namespace at; using namespace at;
// Master weight and 32bit momentum(potentially changing) is not handled by this // Master weight and 32bit momentum(potentially changing) is not handled by this
// So we assume every tensor are all in the same type // So we assume every tensor are all in the same type
bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
// Handle bias correction mode // Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f; float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) { if (bias_correction == 1) {
...@@ -241,9 +253,6 @@ void multi_tensor_lamb_cuda( ...@@ -241,9 +253,6 @@ void multi_tensor_lamb_cuda(
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1); std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2); std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);
// Compute global grad norm
auto grad_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, false);
// Compute per tensor param norm // Compute per tensor param norm
auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true); auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
...@@ -265,7 +274,7 @@ void multi_tensor_lamb_cuda( ...@@ -265,7 +274,7 @@ void multi_tensor_lamb_cuda(
epsilon, epsilon,
(adamMode_t) mode, (adamMode_t) mode,
weight_decay, weight_decay,
std::get<0>(grad_norm_tuple).DATA_PTR<float>(), global_grad_norm,
max_grad_norm); ) max_grad_norm); )
// Compute update norms // Compute update norms
...@@ -282,7 +291,9 @@ void multi_tensor_lamb_cuda( ...@@ -282,7 +291,9 @@ void multi_tensor_lamb_cuda(
LAMBStage2Functor<scalar_t_0>(), LAMBStage2Functor<scalar_t_0>(),
std::get<1>(param_norm_tuple).DATA_PTR<float>(), std::get<1>(param_norm_tuple).DATA_PTR<float>(),
std::get<1>(update_norm_tuple).DATA_PTR<float>(), std::get<1>(update_norm_tuple).DATA_PTR<float>(),
lr); ) lr,
weight_decay,
use_nvlamb); )
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
......
...@@ -24,7 +24,8 @@ struct LAMBStage2Functor ...@@ -24,7 +24,8 @@ struct LAMBStage2Functor
TensorListMetadata<2>& tl, TensorListMetadata<2>& tl,
const float* per_tensor_param_norm, const float* per_tensor_param_norm,
const float* per_tensor_update_norm, const float* per_tensor_update_norm,
const float learning_rate) const float learning_rate,
bool use_nvlamb)
{ {
// I'd like this kernel to propagate infs/nans. // I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1) // if(*noop_gmem == 1)
...@@ -35,9 +36,15 @@ struct LAMBStage2Functor ...@@ -35,9 +36,15 @@ struct LAMBStage2Functor
int chunk_idx = tl.block_to_chunk[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc]; int n = tl.sizes[tensor_loc];
float param_norm = per_tensor_param_norm[tensor_num]; MATH_T ratio = learning_rate;
float update_norm = per_tensor_update_norm[tensor_num]; // nvlamb: apply adaptive learning rate to all parameters
T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate; // otherwise, only apply to those with non-zero weight decay
if (use_nvlamb || (decay != 0.0))
{
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
}
T* p = (T*)tl.addresses[0][tensor_loc]; T* p = (T*)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size; p += chunk_idx*chunk_size;
...@@ -87,8 +94,11 @@ void multi_tensor_lamb_stage2_cuda( ...@@ -87,8 +94,11 @@ void multi_tensor_lamb_stage2_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm, at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm, at::Tensor per_tensor_update_norm,
const float learning_rate) const float learning_rate,
at::optional<bool> use_nvlamb_python)
{ {
bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
using namespace at; using namespace at;
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
...@@ -101,7 +111,8 @@ void multi_tensor_lamb_stage2_cuda( ...@@ -101,7 +111,8 @@ void multi_tensor_lamb_stage2_cuda(
LAMBStage2Functor<scalar_t_0, scalar_t_1>(), LAMBStage2Functor<scalar_t_0, scalar_t_1>(),
per_tensor_param_norm.DATA_PTR<float>(), per_tensor_param_norm.DATA_PTR<float>(),
per_tensor_update_norm.DATA_PTR<float>(), per_tensor_update_norm.DATA_PTR<float>(),
learning_rate); )) learning_rate,
use_nvlamb); ))
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment