Commit 90729bc8 authored by Kexin Yu
Browse files

fix parameter type

parent 32d2c4e2
......@@ -41,7 +41,7 @@ struct LAMBStage1Functor
const float epsilon,
adamMode_t mode,
const float decay,
float* global_grad_norm,
float global_grad_norm,
float max_global_grad_norm)
{
// I'd like this kernel to propagate infs/nans.
......@@ -52,7 +52,7 @@ struct LAMBStage1Functor
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;
float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx*chunk_size;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment