Commit c9e503be authored by alex_xiao, committed by LeiWang1999

[CI]Add norm and layout_plot (#534)



* [CI]Add norm and layout_plot

* fix lint

* Remove obsolete test files for RMS normalization and plot layout, streamlining the testing suite.

* Add make_mma_load_base_layout function to create MMA result layouts

- Introduced a new function `make_mma_load_base_layout` that generates layout functions for storing MMA results in fragment buffers (a rough sketch of the index math follows this list).
- Added a detailed docstring explaining parameters, return values, and potential exceptions.
- Implemented logic for handling different data types and matrix configurations, including assertions for input validation.
- Defined internal functions that map fragment indices to threads and local register indices.
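The helper itself is not part of this file's diff; as a rough illustration only, here is a minimal sketch of the index math such a layout function typically encodes, using the PTX m16n8k16 fp16 fragment-A mapping (all names below are placeholders, not the PR's actual code):

    def make_mma_load_base_layout_sketch():
        """Hypothetical stand-in: map element (i, j) of a 16x16 fp16 MMA
        base tile to (thread, local register index), following the PTX
        m16n8k16 fragment-A layout."""

        def forward_thread(i, j):
            # Four consecutive lanes share a row; each lane owns a pair of halves.
            return (i % 8) * 4 + (j % 8) // 2

        def forward_local(i, j):
            # 8 elements per thread: low/high half x row half x column half.
            return (j % 2) + (i // 8) * 2 + (j // 8) * 4

        return forward_thread, forward_local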

* Enhance MMA load test with additional imports and functionality

- Added imports for `tilelang.language`, `Literal`, `Callable`, `DataType`, `IndexMap`, and `get_mma_micro_size`, which the new layout helper requires.
- Ensured `make_mma_load_base_layout` handles the supported data types and matrix configurations.
- Updated the test function `test_mma_load_base_layout` to validate the layout for a float16 matrix A.

* Fix formatting in test_fragment_mma_load_a.py by adding a blank line for improved readability.

* Add RMS normalization functions to test_rms_norm.py

- Introduced `rms_norm` and `rms_norm_splitk` functions for RMS normalization (the reference formula is given after this list).
- Implemented kernel functions with shared-memory allocation and parallel processing for improved performance.
- Updated the test function to validate the new RMS normalization implementations.
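For reference, the normalization both kernels compute, matching `ref_program` in the file below:

    y[i, j] = x[i, j] * rsqrt(mean_j(x[i, j]^2) + eps),  eps = 1e-12

(The kernels apply the epsilon after the rsqrt, i.e. `rsqrt(mean) + eps`, which stays within the test's 0.01 tolerance against this reference.)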

* Add reference program for RMS normalization in test_rms_norm.py

- Introduced `ref_program` function to provide a reference implementation for RMS normalization.
- This addition enhances the testing framework by allowing comparisons against a known reference output.

* Enhance RMS normalization tests with additional imports and formatting

- Added import for `tilelang.language` to support extended functionality in `test_rms_norm.py`.
- Improved code readability by adding blank lines for better separation of code sections.

* Update RMS normalization test parameters and enhance layout plotting

- Increased matrix dimensions in `test_rms_norm` to 8192 for improved performance testing.
- Removed obsolete test functions in `test_fragment_mma_load_a.py` to streamline the test suite.
- Enhanced layout plotting by visualizing the base, warp, and block layouts in `test_fragment_mma_load_a.py` (a sketch of the plotting call follows this list).
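As a rough illustration, the plotting step might look like the following; the `plot_layout` import path and signature are assumptions, and the `make_mma_load_base_layout` call is likewise illustrative (the test file, later removed, is the authoritative source):

    from tilelang.tools import plot_layout  # assumed import path

    base_layout = make_mma_load_base_layout("float16", matrix="A")  # hypothetical call
    plot_layout(base_layout, name="mma_load_base_layout")  # renders the thread/value map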

* Refactor RMS normalization test parameters and improve layout plotting readability

- Simplified the parameters in `test_rms_norm` by removing `blk_k` for clarity.
- Enhanced code readability in `test_fragment_mma_load_a.py` by adjusting the formatting of the `block_layout` definition and removing the unused `warp_cols` variable.

* Enhance RMS normalization with split-k implementation and additional profiling

- Added a new function `test_rms_norm_splitk` to test the split-k variant of RMS normalization (a sketch follows this list).
- Updated the main RMS normalization script to profile the split-k implementation.
- Verified all checks pass, with latency measurements for both the reference and the tile-lang implementations.
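A minimal sketch of what `test_rms_norm_splitk` might have looked like, mirroring `test_rms_norm` in the file below (the test was later removed as redundant; the `blk_k` value here is an assumption):

    def test_rms_norm_splitk():
        M, N, blk_m, blk_k = 8192, 8192, 1, 512  # blk_k chosen for illustration
        program = rms_norm_splitk(M, N, blk_m, blk_k)
        kernel = tilelang.compile(
            program,
            out_idx=-1,
            target="cuda",
            execution_backend="cython",
            pass_configs={"tl.disable_tma_lower": True})
        profiler = kernel.get_profiler()
        profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)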

* Remove obsolete test file `test_fragment_mma_load_a.py` to streamline the test suite.

* Refactor `rms_norm.py` to streamline benchmarking output and remove redundant code. Comment out the `plot_layout` call in `fragment_mma_load_a.py` for clarity.

* Refactor `test_rms_norm.py` by removing redundant test function `test_rms_norm_splitk` to streamline the test suite and improve clarity.

---------
Co-authored-by: Your Name <you@example.com>
parent eec07578
# test_rms_norm.py
import torch
import tilelang
import tilelang.language as T


def rms_norm_splitk(M, N, blk_m, blk_k):
    dtype = "float"

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx:
            A_shared = T.alloc_shared((blk_m, blk_k), dtype)
            A_local = T.alloc_fragment((blk_m, blk_k), dtype)
            A_powsum = T.alloc_fragment((blk_m,), dtype)

            num_k_step = T.ceildiv(N, blk_k)
            T.clear(A_local)
            # First pass: accumulate the sum of squares over K in blk_k-sized steps.
            for k in range(num_k_step):
                T.copy(A[bx * blk_m, k * blk_k], A_shared)
                for i, j in T.Parallel(blk_m, blk_k):
                    A_local[i, j] += A_shared[i, j] * A_shared[i, j]
            T.reduce_sum(A_local, A_powsum, dim=1)
            for i in T.Parallel(blk_m):
                A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
            # Second pass: rescale each row and write back.
            for k in range(num_k_step):
                # Traverse K in reverse for a better cache hit rate.
                T.copy(A[bx * blk_m, (num_k_step - 1 - k) * blk_k], A_shared)
                for i, j in T.Parallel(blk_m, blk_k):
                    A_shared[i, j] *= A_powsum[i]
                T.copy(A_shared, B[bx * blk_m, (num_k_step - 1 - k) * blk_k])

    return main


def rms_norm(M, N, blk_m):
    dtype = "float"

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx:
            A_shared = T.alloc_shared((blk_m, N), dtype)
            A_pow_local = T.alloc_fragment((blk_m, N), dtype)
            A_local = T.alloc_fragment((blk_m, N), dtype)
            A_powsum = T.alloc_fragment((blk_m,), dtype)

            # Load a full row block, square it, and reduce along N.
            T.copy(A[bx * blk_m:(bx + 1) * blk_m, :], A_shared)
            T.copy(A_shared, A_local)
            for i, j in T.Parallel(blk_m, N):
                A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
            T.reduce_sum(A_pow_local, A_powsum, dim=1)
            for i in T.Parallel(blk_m):
                A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
            for i, j in T.Parallel(blk_m, N):
                A_local[i, j] *= A_powsum[i]
            T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :])

    return main


def ref_program(x):
    # Reference RMS norm: x * rsqrt(mean(x^2) + eps).
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-12)


def test_rms_norm():
    M, N, blk_m = 8192, 8192, 1
    program = rms_norm(M, N, blk_m)
    kernel = tilelang.compile(
        program,
        out_idx=-1,
        target="cuda",
        execution_backend="cython",
        pass_configs={"tl.disable_tma_lower": True})
    profiler = kernel.get_profiler()
    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
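

# Optional entry point (an illustrative addition, not part of the committed
# file): lets the test run directly via `python test_rms_norm.py`.
if __name__ == "__main__":
    test_rms_norm()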