Commit 47e3367f authored by Michael Carilli
Browse files

Allow multi_tensor_lamb to update fp16 params

parent 04667139
......@@ -100,20 +100,20 @@ void multi_tensor_lamb_stage1_cuda(
 float beta1_correction = 1.0f - std::pow(beta1, next_step);
 float beta2_correction = 1.0f - std::pow(beta2, next_step);
 DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
-using accscalar_t_0 = acc_type<scalar_t_0, true>;
+DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
 multi_tensor_apply<5>(
 BLOCK_SIZE,
 chunk_size,
 noop_flag,
 tensor_lists,
-LAMBStage1Functor<scalar_t_0, accscalar_t_0>(),
+LAMBStage1Functor<scalar_t_0, scalar_t_1>(),
 per_tensor_decay.data<float>(),
 beta1,
 beta2,
 beta1_correction,
 beta2_correction,
 epsilon,
-clipped_global_grad_norm); )
+clipped_global_grad_norm); ))
 AT_CUDA_CHECK(cudaGetLastError());
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment