Unverified Commit 8abb6908 authored by Kexin Yu's avatar Kexin Yu Committed by GitHub
Browse files

Merge pull request #819 from kexinyu/master

Use global gradient clipping in FusedLAMB & add option for using NVLAMB
parents 3bae8c83 bd6e66df
......@@ -51,6 +51,8 @@ class FusedLAMB(torch.optim.Optimizer):
method is called. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm
(default: 1.0)
use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
weight decay parameter (default: False)
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
......@@ -62,7 +64,7 @@ class FusedLAMB(torch.optim.Optimizer):
betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
amsgrad=False, adam_w_mode=True,
grad_averaging=True, set_grad_none=True,
max_grad_norm=1.0):
max_grad_norm=1.0, use_nvlamb=False):
if amsgrad:
raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
......@@ -72,6 +74,7 @@ class FusedLAMB(torch.optim.Optimizer):
super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available:
import amp_C
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_lamb = amp_C.multi_tensor_lamb
......@@ -80,6 +83,7 @@ class FusedLAMB(torch.optim.Optimizer):
self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none
self.use_nvlamb = use_nvlamb
def zero_grad(self):
if self.set_grad_none:
......@@ -100,6 +104,37 @@ class FusedLAMB(torch.optim.Optimizer):
if closure is not None:
loss = closure()
# create separate grad lists for fp32 and fp16 params
g_all_32, g_all_16 = [], []
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
if p.dtype == torch.float32:
g_all_32.append(p.grad.data)
elif p.dytpe == torch.float16:
g_all_16.append(p.grad.data)
else:
raise RuntimeError('FusedLAMB only support fp16 and fp32.')
g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(1, device='cuda')
# compute grad norm for two lists
if len(g_all_32) > 0:
g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_32], False)[0]
if len(g_all_16) > 0:
g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_16], False)[0]
# blend two grad norms to get global grad norm
global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[[g_norm_32, g_norm_16]],
False)[0].item()
max_grad_norm = self.defaults['max_grad_norm']
for group in self.param_groups:
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']
......@@ -156,7 +191,9 @@ class FusedLAMB(torch.optim.Optimizer):
group['weight_decay'],
grad_averaging,
self.adam_w_mode,
group['max_grad_norm'])
global_grad_norm,
max_grad_norm,
self.use_nvlamb)
if(len(g_32) > 0):
multi_tensor_applier(self.multi_tensor_lamb,
self._dummy_overflow_buf,
......@@ -170,6 +207,8 @@ class FusedLAMB(torch.optim.Optimizer):
group['weight_decay'],
grad_averaging,
self.adam_w_mode,
group['max_grad_norm'])
global_grad_norm,
max_grad_norm,
self.use_nvlamb)
return loss
......@@ -51,7 +51,9 @@ void multi_tensor_lamb_stage2_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float step_size);
const float lr,
const float weight_decay,
at::optional<bool> use_nvlamb_python);
void multi_tensor_adam_cuda(
int chunk_size,
......@@ -106,7 +108,9 @@ void multi_tensor_lamb_cuda(
const float weight_decay,
const int grad_averaging,
const int mode,
const float max_grad_norm);
const float global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
......
......@@ -52,8 +52,8 @@ struct LAMBStage1Functor
const float epsilon,
adamMode_t mode,
const float decay,
float* global_grad_norm,
float max_global_grad_norm)
const float global_grad_norm,
const float max_global_grad_norm)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
......@@ -63,7 +63,7 @@ struct LAMBStage1Functor
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;
float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx*chunk_size;
......@@ -239,7 +239,9 @@ struct LAMBStage2Functor
TensorListMetadata<2>& tl,
const float* per_tensor_param_norm,
const float* per_tensor_update_norm,
const float learning_rate)
const float learning_rate,
const float decay,
bool use_nvlamb)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
......@@ -250,9 +252,15 @@ struct LAMBStage2Functor
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
MATH_T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
MATH_T ratio = learning_rate;
// nvlamb: apply adaptive learning rate to all parameters
// otherwise, only apply to those with non-zero weight decay
if (use_nvlamb || (decay != 0.0))
{
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
}
T* update = (T*)tl.addresses[0][tensor_loc];
update += chunk_idx*chunk_size;
......@@ -334,12 +342,16 @@ void multi_tensor_lamb_cuda(
const float weight_decay,
const int grad_averaging,
const int mode,
const float max_grad_norm)
const float global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python)
{
using namespace at;
// Master weight and 32bit momentum(potentially changing) is not handled by this
// So we assume every tensor are all in the same type
bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
......@@ -354,9 +366,6 @@ void multi_tensor_lamb_cuda(
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);
// Compute global grad norm
auto grad_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, false);
// Compute per tensor param norm
auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
......@@ -378,7 +387,7 @@ void multi_tensor_lamb_cuda(
epsilon,
(adamMode_t) mode,
weight_decay,
std::get<0>(grad_norm_tuple).DATA_PTR<float>(),
global_grad_norm,
max_grad_norm); )
// Compute update norms
......@@ -395,7 +404,9 @@ void multi_tensor_lamb_cuda(
LAMBStage2Functor<scalar_t_0>(),
std::get<1>(param_norm_tuple).DATA_PTR<float>(),
std::get<1>(update_norm_tuple).DATA_PTR<float>(),
lr); )
lr,
weight_decay,
use_nvlamb); )
AT_CUDA_CHECK(cudaGetLastError());
......
......@@ -13,6 +13,8 @@
#define BLOCK_SIZE 512
#define ILP 4
using MATH_T = float;
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template<typename T, typename UPD_T>
......@@ -24,7 +26,9 @@ struct LAMBStage2Functor
TensorListMetadata<2>& tl,
const float* per_tensor_param_norm,
const float* per_tensor_update_norm,
const float learning_rate)
const float learning_rate,
const float decay,
bool use_nvlamb)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
......@@ -35,9 +39,15 @@ struct LAMBStage2Functor
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
MATH_T ratio = learning_rate;
// nvlamb: apply adaptive learning rate to all parameters
// otherwise, only apply to those with non-zero weight decay
if (use_nvlamb || (decay != 0.0))
{
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
}
T* p = (T*)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
......@@ -87,8 +97,12 @@ void multi_tensor_lamb_stage2_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float learning_rate)
const float lr,
const float weight_decay,
at::optional<bool> use_nvlamb_python)
{
bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
using namespace at;
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
......@@ -101,7 +115,9 @@ void multi_tensor_lamb_stage2_cuda(
LAMBStage2Functor<scalar_t_0, scalar_t_1>(),
per_tensor_param_norm.DATA_PTR<float>(),
per_tensor_update_norm.DATA_PTR<float>(),
learning_rate); ))
lr,
weight_decay,
use_nvlamb); ))
AT_CUDA_CHECK(cudaGetLastError());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment