Commit 85e411c8 authored by Yu Cheng, committed by LeiWang1999

[Refactor] Optimize RMS normalization kernel in rms_norm.py (#333)

- Introduced a new local fragment (A_pow_local) to hold the squared values.
- Updated the RMS computation to square values out of a register fragment (A_local) rather than re-reading shared memory, improving memory efficiency.
- Refactored the final scaling step to operate on the local fragment instead of shared memory, so the result is copied to global memory directly from registers.
- Added a pass_configs option (tl.disable_tma_lower) to kernel compilation for explicit control over TMA lowering.

These changes enhance the efficiency and clarity of the RMS normalization implementation.
parent 9e5a757e
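The kernel implements row-wise RMS normalization: each row of A is scaled by rsqrt(mean(A[i, :]**2)). The body of ref_program lies outside the hunks below, so the following is only a plausible sketch of the reference semantics, assuming a PyTorch-based reference and mirroring the kernel's arithmetic (note that the kernel adds its 1e-12 epsilon after the rsqrt, not inside it):

import torch

def ref_program(x: torch.Tensor) -> torch.Tensor:
    # Sketch only: scale each row by rsqrt(mean(x^2)); the epsilon is added
    # after the rsqrt to match the kernel's A_powsum computation.
    scale = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True)) + 1e-12
    return (x * scale).to(x.dtype)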
@@ -40,18 +40,20 @@ def rms_norm(M, N, blk_m):
     def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx:
             A_shared = T.alloc_shared((blk_m, N), dtype)
+            A_pow_local = T.alloc_fragment((blk_m, N), dtype)
             A_local = T.alloc_fragment((blk_m, N), dtype)
             A_powsum = T.alloc_fragment((blk_m,), dtype)
 
             T.copy(A[bx * blk_m:(bx + 1) * blk_m, :], A_shared)
+            T.copy(A_shared, A_local)
             for i, j in T.Parallel(blk_m, N):
-                A_local[i, j] = A_shared[i, j] * A_shared[i, j]
-            T.reduce_sum(A_local, A_powsum, dim=1)
+                A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
+            T.reduce_sum(A_pow_local, A_powsum, dim=1)
             for i in T.Parallel(blk_m):
                 A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
             for i, j in T.Parallel(blk_m, N):
-                A_shared[i, j] *= A_powsum[i]
-            T.copy(A_shared, B[bx * blk_m:(bx + 1) * blk_m, :])
+                A_local[i, j] *= A_powsum[i]
+            T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :])
 
     return main
@@ -63,7 +65,12 @@ def ref_program(x):
 if __name__ == "__main__":
     M, N, blk_m, blk_k = 8192, 8192, 1, 512
     program = rms_norm(M, N, blk_m)
-    kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
+    kernel = tilelang.compile(
+        program,
+        out_idx=-1,
+        target="cuda",
+        execution_backend="cython",
+        pass_configs={"tl.disable_tma_lower": True})
     profiler = kernel.get_profiler()
     profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
     print("All checks pass.")