Unverified commit 36c9e904, authored by Kexin Yu, committed by GitHub

Merge pull request #851 from kexinyu/master

make FusedLAMB async
parents 87aca22a 2be773d3
@@ -132,7 +132,7 @@ class FusedLAMB(torch.optim.Optimizer):
         global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
                                                 self._dummy_overflow_buf,
                                                 [[g_norm_32, g_norm_16]],
-                                                False)[0].item()
+                                                False)[0]
         max_grad_norm = self.defaults['max_grad_norm']
         for group in self.param_groups:
......
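The only Python-side change is dropping .item(), so global_grad_norm stays a one-element CUDA tensor instead of being copied back to the host. As a rough illustration of why that matters (not apex code; buffer and function names are made up): reading the scalar on the host blocks the CPU until the norm kernel has finished, which is exactly the stall this PR removes.

#include <cuda_runtime.h>

// Hypothetical setup: d_norm points at a single float in device memory that
// an earlier, asynchronously launched L2-norm kernel writes into.

// Blocking pattern (roughly what ".item()" amounts to on a CUDA tensor):
// the copy cannot return until the work producing d_norm has completed,
// so the calling thread stalls here.
float read_norm_on_host(const float* d_norm) {
    float h_norm = 0.0f;
    cudaMemcpy(&h_norm, d_norm, sizeof(float), cudaMemcpyDeviceToHost);
    return h_norm;
}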
@@ -42,7 +42,7 @@ void multi_tensor_lamb_stage1_cuda(
   const float beta1,
   const float beta2,
   const float epsilon,
-  const float global_grad_norm,
+  at::Tensor global_grad_norm,
   const float max_global_grad_norm);

 void multi_tensor_lamb_stage2_cuda(
@@ -108,7 +108,7 @@ void multi_tensor_lamb_cuda(
   const float weight_decay,
   const int grad_averaging,
   const int mode,
-  const float global_grad_norm,
+  at::Tensor global_grad_norm,
   const float max_grad_norm,
   at::optional<bool> use_nvlamb_python);
......
@@ -52,7 +52,7 @@ struct LAMBStage1Functor
     const float epsilon,
     adamMode_t mode,
     const float decay,
-    const float global_grad_norm,
+    const float* global_grad_norm,
     const float max_global_grad_norm)
   {
     // I'd like this kernel to propagate infs/nans.
@@ -63,7 +63,7 @@ struct LAMBStage1Functor
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];

-    float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
+    float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;

     T* g = (T*)tl.addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;
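The functor now receives the norm as const float* and dereferences it on the device, so the value never has to reach the host. A trimmed-down sketch of that pattern (not apex code; it drops the multi-tensor chunking and all of the real LAMB bookkeeping, keeping only the clipping expression from the diff):

#include <cuda_runtime.h>

// Minimal analogue of the functor change: the global norm arrives as a
// device pointer and the clipping divisor is computed inside the kernel.
__global__ void lamb_style_clip(float* g, int n,
                                const float* global_grad_norm,
                                float max_global_grad_norm) {
    // Same expression as the diff: divide by norm/max when the norm exceeds
    // the cap, otherwise divide by 1.0f (i.e. leave gradients unchanged).
    float clipped = (*global_grad_norm) > max_global_grad_norm
                        ? (*global_grad_norm) / max_global_grad_norm
                        : 1.0f;
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) g[i] /= clipped;
}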
@@ -342,7 +342,7 @@ void multi_tensor_lamb_cuda(
   const float weight_decay,
   const int grad_averaging,
   const int mode,
-  const float global_grad_norm,
+  at::Tensor global_grad_norm,
   const float max_grad_norm,
   at::optional<bool> use_nvlamb_python)
 {
@@ -387,7 +387,7 @@ void multi_tensor_lamb_cuda(
         epsilon,
         (adamMode_t) mode,
         weight_decay,
-        global_grad_norm,
+        global_grad_norm.DATA_PTR<float>(),
         max_grad_norm); )

   // Compute update norms
......
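On the host side, multi_tensor_lamb_cuda now accepts the norm as an at::Tensor and forwards only its device pointer into the kernel launch; DATA_PTR in the diff appears to be apex's compatibility macro for Tensor::data_ptr, so the sketch below calls data_ptr<float>() directly. This is a hypothetical launcher, not the apex one, and it assumes the lamb_style_clip kernel from the previous sketch is defined in the same file; the launch configuration and tensor dtypes (float32 CUDA tensors) are illustrative assumptions.

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

// Defined in the previous sketch (same translation unit).
__global__ void lamb_style_clip(float* g, int n,
                                const float* global_grad_norm,
                                float max_global_grad_norm);

// Hypothetical launcher mirroring the new signature: the 1-element CUDA
// tensor holding the global norm is passed through as-is, and only its raw
// device pointer reaches the kernel, so no .item() / host sync is needed.
void clip_with_tensor_norm(at::Tensor grads,
                           at::Tensor global_grad_norm,  // 1-element float CUDA tensor
                           float max_grad_norm) {
    const int n = static_cast<int>(grads.numel());
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    lamb_style_clip<<<blocks, threads, 0, stream>>>(
        grads.data_ptr<float>(),             // gradient buffer on the GPU
        n,
        global_grad_norm.data_ptr<float>(),  // device pointer, never read on host
        max_grad_norm);
}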
@@ -118,12 +118,13 @@ void multi_tensor_lamb_stage1_cuda(
   const float beta1,
   const float beta2,
   const float epsilon,
-  const float global_grad_norm,
+  at::Tensor global_grad_norm,
   const float max_global_grad_norm)
 {
   using namespace at;

-  float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
+  const float* g_grad_norm = global_grad_norm.DATA_PTR<float>();
+  float clipped_global_grad_norm = *(g_grad_norm) > max_global_grad_norm ? *(g_grad_norm) / max_global_grad_norm : 1.0f;
   float next_step = float(step+1);
   float beta1_correction = 1.0f - std::pow(beta1, next_step);
   float beta2_correction = 1.0f - std::pow(beta2, next_step);
......