"examples/sampling/graphbolt/pyg/node_classification.py" did not exist on "494d4cd73f44e50800f8e0caeb8f5d3f7e6d449e"
Commit 47e3367f authored by Michael Carilli's avatar Michael Carilli
Browse files

Allow multi_tensor_lamb to update fp16 params

parent 04667139
...@@ -100,20 +100,20 @@ void multi_tensor_lamb_stage1_cuda( ...@@ -100,20 +100,20 @@ void multi_tensor_lamb_stage1_cuda(
float beta1_correction = 1.0f - std::pow(beta1, next_step); float beta1_correction = 1.0f - std::pow(beta1, next_step);
float beta2_correction = 1.0f - std::pow(beta2, next_step); float beta2_correction = 1.0f - std::pow(beta2, next_step);
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1", DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
using accscalar_t_0 = acc_type<scalar_t_0, true>; DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
multi_tensor_apply<5>( multi_tensor_apply<5>(
BLOCK_SIZE, BLOCK_SIZE,
chunk_size, chunk_size,
noop_flag, noop_flag,
tensor_lists, tensor_lists,
LAMBStage1Functor<scalar_t_0, accscalar_t_0>(), LAMBStage1Functor<scalar_t_0, scalar_t_1>(),
per_tensor_decay.data<float>(), per_tensor_decay.data<float>(),
beta1, beta1,
beta2, beta2,
beta1_correction, beta1_correction,
beta2_correction, beta2_correction,
epsilon, epsilon,
clipped_global_grad_norm); ) clipped_global_grad_norm); ))
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment