Commit 2be773d3 authored by Kexin Yu's avatar Kexin Yu
Browse files

fix function signature

parent cf918ac1
......@@ -42,7 +42,7 @@ void multi_tensor_lamb_stage1_cuda(
const float beta1,
const float beta2,
const float epsilon,
const float global_grad_norm,
at::Tensor global_grad_norm,
const float max_global_grad_norm);
void multi_tensor_lamb_stage2_cuda(
......@@ -108,7 +108,7 @@ void multi_tensor_lamb_cuda(
const float weight_decay,
const int grad_averaging,
const int mode,
const float global_grad_norm,
at::Tensor global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment