Commit c8bcfff8 authored by Kexin Yu

fix function signature for LAMBStage2Functor

parent 5b300119
@@ -51,7 +51,8 @@ void multi_tensor_lamb_stage2_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float step_size,
+  const float lr,
+  const float weight_decay,
   at::optional<bool> use_nvlamb_python);
 
 void multi_tensor_adam_cuda(
...
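This hunk replaces the single step_size argument with separate lr and weight_decay arguments, so the declaration matches the hyperparameters the kernel now consumes. A minimal sketch of a call site under the new signature (the surrounding variable names and the two leading arguments are assumed for illustration, not shown in this commit):

    // Hypothetical caller; chunk_size and noop_flag follow the usual
    // multi_tensor_apply conventions, and the norms come from stage 1.
    multi_tensor_lamb_stage2_cuda(
        chunk_size,             // elements per processed chunk
        noop_flag,              // global inf/nan skip flag
        tensor_lists,           // parameter and update tensors
        per_tensor_param_norm,  // ||w|| for each tensor
        per_tensor_update_norm, // ||update|| for each tensor
        lr,                     // learning rate (previously step_size)
        weight_decay,           // newly threaded-through decay coefficient
        use_nvlamb);            // optional: apply trust ratio to all tensors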
@@ -13,6 +13,8 @@
 #define BLOCK_SIZE 512
 #define ILP 4
 
+using MATH_T = float;
+
 // Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
 // It computes new parameter value.
 template<typename T, typename UPD_T>
...
@@ -25,6 +27,7 @@ struct LAMBStage2Functor
     const float* per_tensor_param_norm,
     const float* per_tensor_update_norm,
     const float learning_rate,
+    const float decay,
     bool use_nvlamb)
   {
     // I'd like this kernel to propagate infs/nans.
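The functor's operator() gains the decay value alongside the learning rate. As a rough reference for what stage 2 computes with these inputs, here is a plain host-side sketch; the trust-ratio condition is an assumption based on the NVLAMB variant, and the real kernel vectorizes over ILP elements inside multi_tensor_apply rather than looping like this:

    #include <cstddef>

    using MATH_T = float;

    // Illustrative reference of the per-element stage-2 update (assumed
    // formulation, not copied from this commit).
    void lamb_stage2_reference(float* param, const float* update, std::size_t n,
                               float param_norm, float update_norm,
                               float learning_rate, float decay, bool use_nvlamb)
    {
      MATH_T ratio = learning_rate;
      // Assumption: NVLAMB applies the layer-wise trust ratio to every tensor,
      // while plain LAMB applies it only where weight decay is active, which
      // is why the functor needs to see the decay value at all.
      if (use_nvlamb || decay != 0.0f)
        ratio = (param_norm != 0.0f && update_norm != 0.0f)
                    ? learning_rate * (param_norm / update_norm)
                    : learning_rate;
      for (std::size_t i = 0; i < n; ++i)
        param[i] -= ratio * update[i];
    }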
@@ -94,7 +97,8 @@ void multi_tensor_lamb_stage2_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float learning_rate,
+  const float lr,
+  const float weight_decay,
   at::optional<bool> use_nvlamb_python)
 {
   bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
...
@@ -111,7 +115,8 @@ void multi_tensor_lamb_stage2_cuda(
       LAMBStage2Functor<scalar_t_0, scalar_t_1>(),
       per_tensor_param_norm.DATA_PTR<float>(),
       per_tensor_update_norm.DATA_PTR<float>(),
-      learning_rate,
+      lr,
+      weight_decay,
       use_nvlamb); ))
 
   AT_CUDA_CHECK(cudaGetLastError());
...
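Because the Python extension binds this function by address, the registration itself does not change, but Python-side callers must now pass lr and weight_decay as two separate arguments. A sketch of the registration, assuming the usual amp_C pybind frontend pattern (binding name and docstring assumed, not part of this commit):

    #include <torch/extension.h>

    // Declaration as in the header hunk above; the added parameter flows
    // through the function pointer automatically.
    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
      m.def("multi_tensor_lamb_stage2",
            &multi_tensor_lamb_stage2_cuda,
            "Completes application of the LAMB update to parameters");
    }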