"packaging/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "a78d0d83d0a499fe8480d7a9f493676e746c4699"
Commit fe0de672 authored by Yu Cheng's avatar Yu Cheng Committed by GitHub
Browse files

[Dev][Bugfix] Add RMS Normalization Kernels and Fix Reduce Bug (#188)

* [Dev][Bugfix] Add RMS Normalization Kernels and Fix Reduce Bug

- Implement two RMS normalization implementations in TileLang:
  * `rms_norm_splitk`: Split-K reduction approach for large matrices
  * `rms_norm`: Full reduction kernel with simplified implementation
- Add reference implementation using PyTorch for validation
- Include performance benchmarking for both kernel variants
- Demonstrate flexible block size and matrix size configurations

* [Examples] Simplify RMS Normalization Kernel Compilation

- Remove commented-out code for split-K RMS normalization
- Simplify kernel compilation by removing explicit TMA lowering configuration
- Update copyright header to Tile-AI Corporation
- Streamline main script for RMS normalization example
parent d34601ab
import torch
import tilelang
import tilelang.language as T
def rms_norm_splitk(M, N, blk_m, blk_k):
    """Build a split-K RMSNorm kernel.

    Each thread block normalizes ``blk_m`` rows of the (M, N) input,
    streaming the row in ``blk_k``-wide tiles: first pass accumulates
    per-row sum of squares, second pass rescales and writes the output.

    Args:
        M, N: input/output matrix shape.
        blk_m: rows handled per thread block.
        blk_k: tile width of the split-K reduction.

    Returns:
        A ``T.prim_func`` computing B = A * rsqrt(mean(A^2, axis=1) + eps).
    """
    dtype = "float"

    @T.prim_func
    def main(A: T.Buffer((M, N), dtype), B: T.Buffer((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx:
            A_shared = T.alloc_shared((blk_m, blk_k), dtype)
            A_local = T.alloc_fragment((blk_m, blk_k), dtype)
            A_powsum = T.alloc_fragment((blk_m,), dtype)

            num_k_step = T.ceildiv(N, blk_k)
            T.clear(A_local)
            # Pass 1: accumulate sum of squares tile by tile.
            for k in range(num_k_step):
                T.copy(A[bx * blk_m, k * blk_k], A_shared)
                for i, j in T.Parallel(blk_m, blk_k):
                    A_local[i, j] += A_shared[i, j] * A_shared[i, j]
            T.reduce_sum(A_local, A_powsum, dim=1)
            for i in T.Parallel(blk_m):
                # BUGFIX: epsilon belongs INSIDE the rsqrt, matching
                # ref_program's rsqrt(mean + eps). The original
                # `rsqrt(mean) + eps` produces inf/NaN on an all-zero row.
                A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
            # Pass 2: rescale and store. Iterate tiles in reverse so the
            # last tile read in pass 1 is reused first (better cache hit rate).
            for k in range(num_k_step):
                T.copy(A[bx * blk_m, (num_k_step - 1 - k) * blk_k], A_shared)
                for i, j in T.Parallel(blk_m, blk_k):
                    A_shared[i, j] *= A_powsum[i]
                T.copy(A_shared, B[bx * blk_m, (num_k_step - 1 - k) * blk_k])

    return main
def rms_norm(M, N, blk_m):
    """Build a full-row RMSNorm kernel (no split-K).

    Each thread block loads ``blk_m`` complete rows of the (M, N) input
    into shared memory, reduces the sum of squares in one shot, then
    rescales in place and writes the result.

    Args:
        M, N: input/output matrix shape.
        blk_m: rows handled per thread block.

    Returns:
        A ``T.prim_func`` computing B = A * rsqrt(mean(A^2, axis=1) + eps).
    """
    dtype = "float"

    @T.prim_func
    def main(A: T.Buffer((M, N), dtype), B: T.Buffer((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx:
            A_shared = T.alloc_shared((blk_m, N), dtype)
            A_local = T.alloc_fragment((blk_m, N), dtype)
            A_powsum = T.alloc_fragment((blk_m,), dtype)

            T.copy(A[bx * blk_m:(bx + 1) * blk_m, :], A_shared)
            for i, j in T.Parallel(blk_m, N):
                A_local[i, j] = A_shared[i, j] * A_shared[i, j]
            T.reduce_sum(A_local, A_powsum, dim=1)
            for i in T.Parallel(blk_m):
                # BUGFIX: epsilon belongs INSIDE the rsqrt, matching
                # ref_program's rsqrt(mean + eps). The original
                # `rsqrt(mean) + eps` produces inf/NaN on an all-zero row.
                A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
            for i, j in T.Parallel(blk_m, N):
                A_shared[i, j] *= A_powsum[i]
            T.copy(A_shared, B[bx * blk_m:(bx + 1) * blk_m, :])

    return main
def ref_program(x):
    """PyTorch reference RMSNorm: x * rsqrt(mean(x**2, last dim) + eps)."""
    mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(mean_sq + 1e-12)
if __name__ == "__main__":
    # Problem and tile sizes: 8192x8192 matrix, one row per thread block.
    M, N, blk_m, blk_k = 8192, 8192, 1, 512

    program = rms_norm(M, N, blk_m)
    kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
    profiler = kernel.get_profiler()

    # Validate the kernel against the PyTorch reference before timing.
    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    print("All checks pass.")

    ref_ms = profiler.do_bench(ref_program, warmup=500)
    print("Ref: {:.2f} ms".format(ref_ms))
    tl_ms = profiler.do_bench(profiler.mod, warmup=500)
    print("Tile-lang: {:.2f} ms".format(tl_ms))
...@@ -24,7 +24,8 @@ struct MinOp { ...@@ -24,7 +24,8 @@ struct MinOp {
} }
}; };
template <class Reducer, int threads, int scale> struct AllReduce { template <class Reducer, int threads, int scale, int all_threads = threads>
struct AllReduce {
static_assert(threads == 1024 or threads == 512 or threads == 256 or static_assert(threads == 1024 or threads == 512 or threads == 256 or
threads == 128 or threads == 64 or threads == 32 or threads == 128 or threads == 64 or threads == 32 or
threads == 16 or threads == 8 or threads == 4 or threads == 2); threads == 16 or threads == 8 or threads == 4 or threads == 2);
...@@ -50,9 +51,9 @@ template <class Reducer, int threads, int scale> struct AllReduce { ...@@ -50,9 +51,9 @@ template <class Reducer, int threads, int scale> struct AllReduce {
static TL_DEVICE T run_hopper(T x, T *red_buf = nullptr) { static TL_DEVICE T run_hopper(T x, T *red_buf = nullptr) {
constexpr int offset = threads / 2; constexpr int offset = threads / 2;
if constexpr (offset >= 32) { if constexpr (offset >= 32) {
asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(threads)); asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(all_threads));
red_buf[threadIdx.x] = x; red_buf[threadIdx.x] = x;
asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(threads)); asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
x = Reducer()(x, red_buf[threadIdx.x ^ offset]); x = Reducer()(x, red_buf[threadIdx.x ^ offset]);
} else { } else {
x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset))); x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
...@@ -60,7 +61,8 @@ template <class Reducer, int threads, int scale> struct AllReduce { ...@@ -60,7 +61,8 @@ template <class Reducer, int threads, int scale> struct AllReduce {
if constexpr (offset == scale) { if constexpr (offset == scale) {
return x; return x;
} else { } else {
return AllReduce<Reducer, offset, scale>::run_hopper(x, red_buf); return AllReduce<Reducer, offset, scale, all_threads>::run_hopper(
x, red_buf);
} }
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment