Commit c8bcfff8 authored by Kexin Yu
Browse files

fix function signature for LAMBStage2Functor

parent 5b300119
......@@ -51,7 +51,8 @@ void multi_tensor_lamb_stage2_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float step_size,
const float lr,
const float weight_decay,
at::optional<bool> use_nvlamb_python);
void multi_tensor_adam_cuda(
......
......@@ -13,6 +13,8 @@
#define BLOCK_SIZE 512
#define ILP 4
using MATH_T = float;
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template<typename T, typename UPD_T>
......@@ -25,6 +27,7 @@ struct LAMBStage2Functor
const float* per_tensor_param_norm,
const float* per_tensor_update_norm,
const float learning_rate,
const float decay,
bool use_nvlamb)
{
// I'd like this kernel to propagate infs/nans.
......@@ -94,7 +97,8 @@ void multi_tensor_lamb_stage2_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float learning_rate,
const float lr,
const float weight_decay,
at::optional<bool> use_nvlamb_python)
{
bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
......@@ -111,7 +115,8 @@ void multi_tensor_lamb_stage2_cuda(
LAMBStage2Functor<scalar_t_0, scalar_t_1>(),
per_tensor_param_norm.DATA_PTR<float>(),
per_tensor_update_norm.DATA_PTR<float>(),
learning_rate,
lr,
weight_decay,
use_nvlamb); ))
AT_CUDA_CHECK(cudaGetLastError());
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment