Commit 1529608b authored by PanZezhong's avatar PanZezhong
Browse files

issue/6/fix type convertion

parent a23c4d13
......@@ -22,7 +22,7 @@ __device__ void rmsnormBlock(
// Thread_0 computes RMS=1/sqrt(ss/dim+epsilon) and stores in shared memory
__shared__ Tcompute rms;
if (threadIdx.x == 0) {
rms = Tdata(rsqrtf(ss / Tcompute(dim) + epsilon));
rms = Tcompute(rsqrtf(ss / Tcompute(dim) + epsilon));
}
__syncthreads();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment