Unverified commit 468b1b70, authored by pengxin99 and committed by GitHub
Browse files

RMSNorm epsilon refine in the example (#1243)

* Fix division by zero in RMS normalization

* Fix rsqrt calculation to avoid division by zero
parent 6882bd50
......@@ -21,7 +21,7 @@ def rms_norm_splitk(M, N, blk_m, blk_k):
A_local[i, j] += A_shared[i, j] * A_shared[i, j]
T.reduce_sum(A_local, A_powsum, dim=1)
for i in T.Parallel(blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
for k in range(num_k_step):
# reverse, better cache hit rate
......@@ -51,7 +51,7 @@ def rms_norm(M, N, blk_m):
A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
T.reduce_sum(A_pow_local, A_powsum, dim=1)
for i in T.Parallel(blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
for i, j in T.Parallel(blk_m, N):
A_local[i, j] *= A_powsum[i]
T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :])
......
......@@ -22,7 +22,7 @@ def rms_norm_splitk(M, N, blk_m, blk_k):
A_local[i, j] += A_shared[i, j] * A_shared[i, j]
T.reduce_sum(A_local, A_powsum, dim=1)
for i in T.Parallel(blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
for k in range(num_k_step):
# reverse, better cache hit rate
......@@ -51,7 +51,7 @@ def rms_norm(M, N, blk_m):
A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
T.reduce_sum(A_pow_local, A_powsum, dim=1)
for i in T.Parallel(blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
for i, j in T.Parallel(blk_m, N):
A_local[i, j] *= A_powsum[i]
T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment