Commit 32d2c4e2 authored by Kexin Yu

clip gradients globally, rather than per group

parent 8405d436
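For orientation: this change replaces per-group gradient clipping with a single clipping factor derived from the gradient norm over all parameter groups. A minimal sketch of the idea in plain PyTorch (the helper name `clip_grads_globally` and the eager-mode norms are illustrative only; the fused kernel changed below does this work on the GPU):

```python
import math
import torch

def clip_grads_globally(grads, max_grad_norm):
    # Global L2 norm: root of the sum of squared per-tensor norms,
    # taken over gradients from *all* parameter groups at once.
    global_norm = math.sqrt(sum(float(g.float().norm(2)) ** 2 for g in grads))
    # One shared scale factor; gradients shrink only when the global norm
    # exceeds the cap, unlike per-group clipping where each group is
    # rescaled against its own norm.
    scale = max_grad_norm / global_norm if global_norm > max_grad_norm else 1.0
    for g in grads:
        g.mul_(scale)
    return global_norm
```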
@@ -13,6 +13,7 @@ void multi_tensor_lamb_cuda(
   const float weight_decay,
   const int grad_averaging,
   const int mode,
+  const float global_grad_norm,
   const float max_grad_norm);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -227,6 +227,7 @@ void multi_tensor_lamb_cuda(
   const float weight_decay,
   const int grad_averaging,
   const int mode,
+  const float global_grad_norm,
   const float max_grad_norm)
 {
   using namespace at;
@@ -247,9 +248,6 @@ void multi_tensor_lamb_cuda(
   std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
   std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);

-  // Compute global grad norm
-  auto grad_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, false);
-
   // Compute per tensor param norm
   auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
@@ -271,7 +269,7 @@ void multi_tensor_lamb_cuda(
             epsilon,
             (adamMode_t) mode,
             weight_decay,
-            std::get<0>(grad_norm_tuple).DATA_PTR<float>(),
+            global_grad_norm,
             max_grad_norm); )

   // Compute update norms
...
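With this change the CUDA entry point no longer computes the gradient norm itself (once per invocation, i.e. per group); it now receives a host-side `global_grad_norm` scalar computed in Python. The diff does not show how the kernel consumes it, but presumably it derives a single clipping divisor roughly as in this hedged Python sketch (not the actual kernel code):

```python
def clipping_divisor(global_grad_norm: float, max_grad_norm: float) -> float:
    # Hedged guess at the kernel-side logic: gradients are divided by this
    # factor, so clipping only engages when the global norm exceeds the cap.
    if max_grad_norm > 0.0 and global_grad_norm > max_grad_norm:
        return global_grad_norm / max_grad_norm
    return 1.0
```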
 import torch
 import importlib
+import math
 from apex.multi_tensor_apply import multi_tensor_applier
 class FusedLAMB(torch.optim.Optimizer):
@@ -100,6 +101,30 @@ class FusedLAMB(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()

+        # create separate grad lists for fp32 and fp16 params
+        g_all_32, g_all_16 = [], []
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is not None:
+                    if p.dtype == torch.float32:
+                        g_all_32.append(p.grad.data)
+                    elif p.dtype == torch.float16:
+                        g_all_16.append(p.grad.data)
+                    else:
+                        raise RuntimeError('FusedLAMB only support fp16 and fp32.')
+
+        # compute grad norm for two lists
+        g_norm_32, _ = multi_tensor_applier(self.multi_tensor_l2norm,
+                                            self._dummy_overflow_buf,
+                                            [g_all_32], False)
+        g_norm_16, _ = multi_tensor_applier(self.multi_tensor_l2norm,
+                                            self._dummy_overflow_buf,
+                                            [g_all_16], False)
+
+        # blend two grad norms to get global grad norm
+        global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)
+        max_grad_norm = self.defaults['max_grad_norm']
+
         for group in self.param_groups:
             bias_correction = 1 if group['bias_correction'] else 0
             beta1, beta2 = group['betas']
@@ -156,7 +181,8 @@ class FusedLAMB(torch.optim.Optimizer):
                                      group['weight_decay'],
                                      grad_averaging,
                                      self.adam_w_mode,
-                                     group['max_grad_norm'])
+                                     global_grad_norm,
+                                     max_grad_norm)
             if(len(g_32) > 0):
                 multi_tensor_applier(self.multi_tensor_lamb,
                                      self._dummy_overflow_buf,
@@ -170,6 +196,7 @@ class FusedLAMB(torch.optim.Optimizer):
                                      group['weight_decay'],
                                      grad_averaging,
                                      self.adam_w_mode,
-                                     group['max_grad_norm'])
+                                     global_grad_norm,
+                                     max_grad_norm)

         return loss
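The blended norm computed in `step()` above is mathematically the L2 norm over all gradients at once, since sqrt(||g_fp32||² + ||g_fp16||²) equals the norm of the concatenated tensors. A small self-contained check of that identity in plain PyTorch (the tensors below are made up for illustration and stand in for the fp32/fp16 gradient lists):

```python
import math
import torch

# Hypothetical gradients standing in for the fp32 and fp16 lists built in step().
g_all_32 = [torch.randn(10), torch.randn(7)]
g_all_16 = [torch.randn(5).half()]

# Per-dtype L2 norms, analogous to the two multi_tensor_l2norm calls.
g_norm_32 = math.sqrt(sum(float(g.float().norm(2)) ** 2 for g in g_all_32))
g_norm_16 = math.sqrt(sum(float(g.float().norm(2)) ** 2 for g in g_all_16))

# Blending the two norms reproduces the global norm over every gradient element.
global_grad_norm = math.sqrt(g_norm_32 ** 2 + g_norm_16 ** 2)
reference = torch.cat([g.float().flatten() for g in g_all_32 + g_all_16]).norm(2)
assert abs(global_grad_norm - float(reference)) < 1e-4
```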