Commit abc991da authored by Kexin Yu

fix dtype

parent f54cc1c9
@@ -52,7 +52,7 @@ struct LAMBStage1Functor
     const float epsilon,
     adamMode_t mode,
     const float decay,
-    at::Tensor global_grad_norm,
+    const float global_grad_norm,
     const float max_global_grad_norm)
   {
     // I'd like this kernel to propagate infs/nans.
@@ -387,7 +387,7 @@ void multi_tensor_lamb_cuda(
       epsilon,
       (adamMode_t) mode,
       weight_decay,
-      global_grad_norm,
+      global_grad_norm.data(),
       max_grad_norm); )
   // Compute update norms
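Note on the signature change above (a simplified sketch, not apex's actual multi_tensor_apply machinery): LAMBStage1Functor is invoked from inside a CUDA kernel, so its arguments have to be plain scalars or raw device pointers; an at::Tensor cannot be consumed from device code, which is presumably why the parameter becomes const float and the host-side caller unpacks the tensor before the launch. The toy example below illustrates the pattern with a hypothetical ScaleFunctor and apply kernel (names invented for illustration).

// Sketch under assumptions: ScaleFunctor and apply() are illustrative names,
// not part of apex. It shows why a device-side functor takes a plain float
// (here `scale`) rather than an at::Tensor.
#include <cstdio>
#include <cuda_runtime.h>

struct ScaleFunctor {
  // Kernel arguments must be trivially copyable; a plain float qualifies,
  // while an at::Tensor does not, and its methods are host-only anyway.
  __device__ void operator()(float* out, const float* in, int n, const float scale) const {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = in[i] * scale;
  }
};

template <typename Functor, typename... Args>
__global__ void apply(Functor f, Args... args) {
  f(args...);
}

int main() {
  const int n = 8;
  float h_in[n], h_out[n];
  for (int i = 0; i < n; ++i) h_in[i] = float(i);

  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

  // The scalar is fixed on the host before the launch, mirroring how the
  // patched code hands global_grad_norm to the functor as a float.
  const float scale = 0.5f;
  apply<<<1, 32>>>(ScaleFunctor{}, d_out, d_in, n, scale);
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);

  for (int i = 0; i < n; ++i) printf("%g ", h_out[i]);
  printf("\n");
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}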
@@ -123,7 +123,8 @@ void multi_tensor_lamb_stage1_cuda(
 {
   using namespace at;
-  float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
+  auto g_grad_norm = global_grad_norm.data();
+  float clipped_global_grad_norm = g_grad_norm > max_global_grad_norm ? g_grad_norm / max_global_grad_norm : 1.0f;
   float next_step = float(step+1);
   float beta1_correction = 1.0f - std::pow(beta1, next_step);
   float beta2_correction = 1.0f - std::pow(beta2, next_step);
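For reference, the clipping factor computed in the last hunk is the usual global-norm rule: when the global gradient norm exceeds max_global_grad_norm, the factor is norm / max and the stage-1 kernel can use it to rescale gradients; otherwise the factor is 1 and gradients pass through unscaled. A standalone sketch with made-up values:

// Minimal standalone sketch of the clipping factor shown above; the values in
// main() are hypothetical and only exercise the two branches.
#include <cstdio>

static float clip_factor(float global_grad_norm, float max_global_grad_norm) {
  // Same expression as clipped_global_grad_norm in the patched host code.
  return global_grad_norm > max_global_grad_norm
             ? global_grad_norm / max_global_grad_norm
             : 1.0f;
}

int main() {
  printf("%g\n", clip_factor(8.0f, 2.0f));  // 4 -> gradients would be rescaled by 1/4
  printf("%g\n", clip_factor(1.5f, 2.0f));  // 1 -> gradients would be left unchanged
  return 0;
}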