Commit 2c3f3d9a authored by Kexin Yu's avatar Kexin Yu
Browse files

at::Tensor::data_ptr()

parent abc991da
......@@ -387,7 +387,7 @@ void multi_tensor_lamb_cuda(
epsilon,
(adamMode_t) mode,
weight_decay,
global_grad_norm.data(),
global_grad_norm.data_ptr<float>(),
max_grad_norm); )
// Compute update norms
......
......@@ -123,7 +123,7 @@ void multi_tensor_lamb_stage1_cuda(
{
using namespace at;
auto g_grad_norm = global_grad_norm.data();
auto g_grad_norm = global_grad_norm.data_ptr<float>();
float clipped_global_grad_norm = g_grad_norm > max_global_grad_norm ? g_grad_norm / max_global_grad_norm : 1.0f;
float next_step = float(step+1);
float beta1_correction = 1.0f - std::pow(beta1, next_step);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment