Commit f54cc1c9 authored by Kexin Yu

make fused LAMB async

parent 8abb6908
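The change that makes the optimizer step asynchronous is small: previously the global gradient norm was pulled back to the host as a Python float via .item(), which blocks the CPU until the L2-norm kernel has finished; after this commit the norm stays on the GPU as a one-element tensor (at::Tensor in the C++ signatures), and the LAMB kernels consume it directly from device memory. The hunks below show the Python call site and the three C++/CUDA signatures that change; illustrative sketches of the host-side and kernel-side patterns follow the hunks.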
@@ -132,7 +132,7 @@ class FusedLAMB(torch.optim.Optimizer):
global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[[g_norm_32, g_norm_16]],
-False)[0].item()
+False)[0]
max_grad_norm = self.defaults['max_grad_norm']
for group in self.param_groups:
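On the Python side, dropping .item() is what removes the synchronization point: .item() copies the one-element norm tensor to the CPU and blocks until the L2-norm kernel has completed, while keeping the result as a GPU tensor lets the following LAMB launches be queued immediately. Below is a minimal CUDA sketch of the same host-side difference; the kernels and names are hypothetical stand-ins, not the apex kernels.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for multi_tensor_l2norm: writes a single scalar into device memory.
__global__ void fake_norm_kernel(float* out) { *out = 3.0f; }

// Stand-in for the LAMB update: reads the norm straight from device memory,
// so the host never has to wait for it.
__global__ void fake_update_kernel(const float* global_grad_norm,
                                   float max_grad_norm, float* param) {
  float clip = fmaxf(*global_grad_norm / max_grad_norm, 1.0f);
  *param /= clip;
}

int main() {
  float *d_norm, *d_param;
  cudaMalloc(&d_norm, sizeof(float));
  cudaMalloc(&d_param, sizeof(float));
  float one = 1.0f;
  cudaMemcpy(d_param, &one, sizeof(float), cudaMemcpyHostToDevice);

  fake_norm_kernel<<<1, 1>>>(d_norm);

  // Synchronous pattern (what the removed .item() amounts to): copying the
  // norm back to the host blocks until fake_norm_kernel has finished.
  //   float h_norm;
  //   cudaMemcpy(&h_norm, d_norm, sizeof(float), cudaMemcpyDeviceToHost);

  // Asynchronous pattern (this commit): hand the device pointer straight to
  // the next kernel; both launches just queue on the stream and return.
  fake_update_kernel<<<1, 1>>>(d_norm, 1.0f, d_param);

  float h_param;
  cudaMemcpy(&h_param, d_param, sizeof(float), cudaMemcpyDeviceToHost);
  printf("param after clipped update: %f\n", h_param);  // 1 / 3 = 0.333...
  cudaFree(d_norm);
  cudaFree(d_param);
  return 0;
}
```

The only point of the sketch is when the host blocks: with the commented-out copy the CPU waits for the norm, while with the pointer handoff both kernels are simply enqueued on the stream.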
@@ -52,7 +52,7 @@ struct LAMBStage1Functor
const float epsilon,
adamMode_t mode,
const float decay,
-const float global_grad_norm,
+at::Tensor global_grad_norm,
const float max_global_grad_norm)
{
// I'd like this kernel to propagate infs/nans.
@@ -342,7 +342,7 @@ void multi_tensor_lamb_cuda(
const float weight_decay,
const int grad_averaging,
const int mode,
-const float global_grad_norm,
+at::Tensor global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python)
{
@@ -118,7 +118,7 @@ void multi_tensor_lamb_stage1_cuda(
const float beta1,
const float beta2,
const float epsilon,
-const float global_grad_norm,
+at::Tensor global_grad_norm,
const float max_global_grad_norm)
{
using namespace at;
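On the CUDA side, each of the three signatures above now takes the norm as an at::Tensor rather than a host float. Device code cannot use at::Tensor directly, so a common pattern is for the host-side launcher to extract the tensor's float data pointer and let the kernel or functor dereference it, computing the clipping ratio from a value that never leaves the GPU. The sketch below shows that kernel-side pattern with hypothetical names; it is not the apex implementation.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in for the LAMB stage-1 clipping: the functor carries a
// raw device pointer to the global grad norm (on the host this would come
// from something like global_grad_norm.data_ptr<float>()) and reads it
// inside device code, so the scalar never travels back to the CPU.
struct ClipByGlobalNormFunctor {
  const float* global_grad_norm;   // one-element, device-resident
  float max_global_grad_norm;

  __device__ void operator()(float* grad, int n) const {
    float norm = *global_grad_norm;
    float clip = norm > max_global_grad_norm ? norm / max_global_grad_norm : 1.0f;
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) grad[i] /= clip;
  }
};

// Trivial launcher standing in for multi_tensor_apply.
__global__ void apply_kernel(ClipByGlobalNormFunctor f, float* grad, int n) {
  f(grad, n);
}

int main() {
  const int n = 8;
  float h_grad[n] = {1, 2, 3, 4, 5, 6, 7, 8};
  float *d_grad, *d_norm;
  cudaMalloc(&d_grad, n * sizeof(float));
  cudaMalloc(&d_norm, sizeof(float));
  cudaMemcpy(d_grad, h_grad, n * sizeof(float), cudaMemcpyHostToDevice);

  float norm = 4.0f;  // pretend an earlier L2-norm kernel wrote this
  cudaMemcpy(d_norm, &norm, sizeof(float), cudaMemcpyHostToDevice);

  // The functor is built around the device pointer; no host read of the norm.
  ClipByGlobalNormFunctor f{d_norm, 1.0f};
  apply_kernel<<<1, 32>>>(f, d_grad, n);

  cudaMemcpy(h_grad, d_grad, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("first clipped grad: %f\n", h_grad[0]);  // 1 / 4 = 0.25
  cudaFree(d_grad);
  cudaFree(d_norm);
  return 0;
}
```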