Unverified Commit 8abb6908 authored by Kexin Yu, committed by GitHub

Merge pull request #819 from kexinyu/master

Use global gradient clipping in FusedLAMB & add option for using NVLAMB
parents 3bae8c83 bd6e66df
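
For orientation, here is a minimal usage sketch of the new option (the model and data are placeholders; this assumes apex is installed with its fused CUDA extensions):

    import torch
    from apex.optimizers import FusedLAMB

    model = torch.nn.Linear(1024, 1024).cuda()   # placeholder model

    # max_grad_norm feeds the new global gradient clipping; use_nvlamb=True also
    # applies the adaptive (trust-ratio) learning rate to parameters with 0.0 weight decay.
    optimizer = FusedLAMB(model.parameters(), lr=1e-3, weight_decay=0.01,
                          max_grad_norm=1.0, use_nvlamb=True)

    loss = model(torch.randn(32, 1024, device='cuda')).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()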
@@ -51,6 +51,8 @@ class FusedLAMB(torch.optim.Optimizer):
            method is called. (default: True)
        max_grad_norm (float, optional): value used to clip global grad norm
            (default: 1.0)
+        use_nvlamb (boolean, optional): Apply the adaptive learning rate to
+            parameters with 0.0 weight decay as well (default: False)
    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
@@ -62,7 +64,7 @@ class FusedLAMB(torch.optim.Optimizer):
                 betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
                 amsgrad=False, adam_w_mode=True,
                 grad_averaging=True, set_grad_none=True,
-                 max_grad_norm=1.0):
+                 max_grad_norm=1.0, use_nvlamb=False):
        if amsgrad:
            raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
        defaults = dict(lr=lr, bias_correction=bias_correction,
@@ -72,6 +74,7 @@ class FusedLAMB(torch.optim.Optimizer):
        super(FusedLAMB, self).__init__(params, defaults)
        if multi_tensor_applier.available:
            import amp_C
+            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
            # Skip buffer
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
            self.multi_tensor_lamb = amp_C.multi_tensor_lamb
@@ -80,6 +83,7 @@ class FusedLAMB(torch.optim.Optimizer):
        self.adam_w_mode = 1 if adam_w_mode else 0
        self.set_grad_none = set_grad_none
+        self.use_nvlamb = use_nvlamb

    def zero_grad(self):
        if self.set_grad_none:
@@ -100,6 +104,37 @@ class FusedLAMB(torch.optim.Optimizer):
        if closure is not None:
            loss = closure()
+        # create separate grad lists for fp32 and fp16 params
+        g_all_32, g_all_16 = [], []
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                if p.dtype == torch.float32:
+                    g_all_32.append(p.grad.data)
+                elif p.dtype == torch.float16:
+                    g_all_16.append(p.grad.data)
+                else:
+                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')
+
+        g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(1, device='cuda')
+        # compute grad norm for two lists
+        if len(g_all_32) > 0:
+            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
+                                             self._dummy_overflow_buf,
+                                             [g_all_32], False)[0]
+        if len(g_all_16) > 0:
+            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
+                                             self._dummy_overflow_buf,
+                                             [g_all_16], False)[0]
+
+        # blend two grad norms to get global grad norm
+        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
+                                                self._dummy_overflow_buf,
+                                                [[g_norm_32, g_norm_16]],
+                                                False)[0].item()
+        max_grad_norm = self.defaults['max_grad_norm']
        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
@@ -156,7 +191,9 @@ class FusedLAMB(torch.optim.Optimizer):
                                     group['weight_decay'],
                                     grad_averaging,
                                     self.adam_w_mode,
-                                     group['max_grad_norm'])
+                                     global_grad_norm,
+                                     max_grad_norm,
+                                     self.use_nvlamb)
            if(len(g_32) > 0):
                multi_tensor_applier(self.multi_tensor_lamb,
                                     self._dummy_overflow_buf,
@@ -170,6 +207,8 @@ class FusedLAMB(torch.optim.Optimizer):
                                     group['weight_decay'],
                                     grad_averaging,
                                     self.adam_w_mode,
-                                     group['max_grad_norm'])
+                                     global_grad_norm,
+                                     max_grad_norm,
+                                     self.use_nvlamb)

        return loss
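
The two-stage reduction above works because the L2 norm of the per-dtype norms equals the L2 norm of all gradients taken together, i.e. sqrt(||g_32||^2 + ||g_16||^2). A plain-PyTorch sketch of the same quantity, without the fused multi_tensor_l2norm kernels (the helper name is hypothetical):

    import torch

    def reference_global_grad_norm(g_all_32, g_all_16):
        # unfused equivalent of the blended norm computed in step() above
        sq_32 = sum(g.float().norm(2).item() ** 2 for g in g_all_32)
        sq_16 = sum(g.float().norm(2).item() ** 2 for g in g_all_16)
        return (sq_32 + sq_16) ** 0.5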
@@ -51,7 +51,9 @@ void multi_tensor_lamb_stage2_cuda(
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_param_norm,
  at::Tensor per_tensor_update_norm,
-  const float step_size);
+  const float lr,
+  const float weight_decay,
+  at::optional<bool> use_nvlamb_python);

void multi_tensor_adam_cuda(
  int chunk_size,
@@ -106,7 +108,9 @@ void multi_tensor_lamb_cuda(
  const float weight_decay,
  const int grad_averaging,
  const int mode,
-  const float max_grad_norm);
+  const float global_grad_norm,
+  const float max_grad_norm,
+  at::optional<bool> use_nvlamb_python);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
...
@@ -52,8 +52,8 @@ struct LAMBStage1Functor
    const float epsilon,
    adamMode_t mode,
    const float decay,
-    float* global_grad_norm,
-    float max_global_grad_norm)
+    const float global_grad_norm,
+    const float max_global_grad_norm)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
@@ -63,7 +63,7 @@ struct LAMBStage1Functor
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

-    float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;
+    float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;

    T* g = (T*)tl.addresses[0][tensor_loc];
    g += chunk_idx*chunk_size;
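
Because the norm is now computed once on the Python side and passed in as a plain float, stage 1 only has to derive the clipping factor; every gradient element is then divided by it, which is ordinary global-norm clipping (the same effect as torch.nn.utils.clip_grad_norm_ with max_norm=max_grad_norm, up to the small eps that utility adds). A scalar sketch of the factor, with a hypothetical helper name:

    def clip_factor(global_grad_norm, max_grad_norm):
        # divisor applied to every gradient inside LAMBStage1Functor
        return global_grad_norm / max_grad_norm if global_grad_norm > max_grad_norm else 1.0

    assert clip_factor(4.0, 1.0) == 4.0   # norm 4.0 -> gradients shrunk by 4x
    assert clip_factor(0.5, 1.0) == 1.0   # already below the threshold -> unchanged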
@@ -239,7 +239,9 @@ struct LAMBStage2Functor
    TensorListMetadata<2>& tl,
    const float* per_tensor_param_norm,
    const float* per_tensor_update_norm,
-    const float learning_rate)
+    const float learning_rate,
+    const float decay,
+    bool use_nvlamb)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
@@ -250,9 +252,15 @@ struct LAMBStage2Functor
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

-    float param_norm = per_tensor_param_norm[tensor_num];
-    float update_norm = per_tensor_update_norm[tensor_num];
-    MATH_T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
+    MATH_T ratio = learning_rate;
+    // nvlamb: apply adaptive learning rate to all parameters
+    // otherwise, only apply to those with non-zero weight decay
+    if (use_nvlamb || (decay != 0.0))
+    {
+      float param_norm = per_tensor_param_norm[tensor_num];
+      float update_norm = per_tensor_update_norm[tensor_num];
+      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
+    }

    T* update = (T*)tl.addresses[0][tensor_loc];
    update += chunk_idx*chunk_size;
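
This branch is the behavioural core of the NVLAMB option: by default the trust ratio ||w|| / ||update|| only rescales parameters that actually have weight decay, while use_nvlamb applies it to every parameter. A scalar sketch of the logic above, with a hypothetical helper name mirroring the kernel:

    def stage2_scaling(learning_rate, param_norm, update_norm, decay, use_nvlamb):
        ratio = learning_rate
        if use_nvlamb or decay != 0.0:
            if update_norm != 0.0 and param_norm != 0.0:
                ratio = learning_rate * (param_norm / update_norm)
        return ratio   # per-tensor step size used when applying the update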
@@ -334,12 +342,16 @@ void multi_tensor_lamb_cuda(
  const float weight_decay,
  const int grad_averaging,
  const int mode,
-  const float max_grad_norm)
+  const float global_grad_norm,
+  const float max_grad_norm,
+  at::optional<bool> use_nvlamb_python)
{
  using namespace at;
  // Master weight and 32bit momentum(potentially changing) is not handled by this
  // So we assume every tensor are all in the same type
+  bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;

  // Handle bias correction mode
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
  if (bias_correction == 1) {
@@ -354,9 +366,6 @@ void multi_tensor_lamb_cuda(
  std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
  std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);

-  // Compute global grad norm
-  auto grad_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, false);

  // Compute per tensor param norm
  auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
@@ -378,7 +387,7 @@ void multi_tensor_lamb_cuda(
        epsilon,
        (adamMode_t) mode,
        weight_decay,
-        std::get<0>(grad_norm_tuple).DATA_PTR<float>(),
+        global_grad_norm,
        max_grad_norm); )

  // Compute update norms
@@ -395,7 +404,9 @@ void multi_tensor_lamb_cuda(
      LAMBStage2Functor<scalar_t_0>(),
      std::get<1>(param_norm_tuple).DATA_PTR<float>(),
      std::get<1>(update_norm_tuple).DATA_PTR<float>(),
-      lr); )
+      lr,
+      weight_decay,
+      use_nvlamb); )

  AT_CUDA_CHECK(cudaGetLastError());
...
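
For reference, the two stages together implement the LAMB rule from the paper linked in the docstring: stage 1 forms the clipped, optionally weight-decayed Adam-style update u_t, and stage 2 rescales it by the trust ratio. Schematically, with lambda the weight decay and eta the learning rate:

    u_t = \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda w_t, \qquad
    w_{t+1} = w_t - \eta \, \frac{\lVert w_t \rVert}{\lVert u_t \rVert} \, u_t

With the change above, the trust-ratio factor is replaced by 1 whenever either norm is zero, or when the parameter has no weight decay and use_nvlamb is off.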
@@ -13,6 +13,8 @@
#define BLOCK_SIZE 512
#define ILP 4

+using MATH_T = float;
+
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template<typename T, typename UPD_T>
@@ -24,7 +26,9 @@ struct LAMBStage2Functor
    TensorListMetadata<2>& tl,
    const float* per_tensor_param_norm,
    const float* per_tensor_update_norm,
-    const float learning_rate)
+    const float learning_rate,
+    const float decay,
+    bool use_nvlamb)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
@@ -35,9 +39,15 @@ struct LAMBStage2Functor
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

-    float param_norm = per_tensor_param_norm[tensor_num];
-    float update_norm = per_tensor_update_norm[tensor_num];
-    T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
+    MATH_T ratio = learning_rate;
+    // nvlamb: apply adaptive learning rate to all parameters
+    // otherwise, only apply to those with non-zero weight decay
+    if (use_nvlamb || (decay != 0.0))
+    {
+      float param_norm = per_tensor_param_norm[tensor_num];
+      float update_norm = per_tensor_update_norm[tensor_num];
+      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
+    }

    T* p = (T*)tl.addresses[0][tensor_loc];
    p += chunk_idx*chunk_size;
@@ -87,8 +97,12 @@ void multi_tensor_lamb_stage2_cuda(
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_param_norm,
  at::Tensor per_tensor_update_norm,
-  const float learning_rate)
+  const float lr,
+  const float weight_decay,
+  at::optional<bool> use_nvlamb_python)
{
+  bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;

  using namespace at;

  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
@@ -101,7 +115,9 @@ void multi_tensor_lamb_stage2_cuda(
      LAMBStage2Functor<scalar_t_0, scalar_t_1>(),
      per_tensor_param_norm.DATA_PTR<float>(),
      per_tensor_update_norm.DATA_PTR<float>(),
-      learning_rate); ))
+      lr,
+      weight_decay,
+      use_nvlamb); ))

  AT_CUDA_CHECK(cudaGetLastError());
...