    [CI]Add norm and layout_plot (#534) · c9e503be
    alex_xiao authored
    
    
    * [CI]Add norm and layout_plot
    
    * fix lint
    
    * Remove obsolete test files for RMS normalization and plot layout, streamlining the testing suite.
    
    * Add make_mma_load_base_layout function to create MMA result layouts
    
    - Introduced a new function `make_mma_load_base_layout` for generating layout functions for storing MMA results in fragment buffers.
    - Added detailed docstring explaining parameters, return values, and potential exceptions.
    - Implemented logic for handling different data types and matrix configurations, including assertions for input validation.
    - Defined internal functions for mapping fragment indices to threads and local indices, enhancing the layout functionality.
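For orientation, the fp16 A-operand fragment of PTX `mma.sync.aligned.m16n8k16` gives each of a warp's 32 lanes 8 elements of a 16×16 tile. The sketch below is a plain-Python model of that mapping, written from the PTX ISA's fragment-layout description rather than taken from this commit. The helpers `fragment_to_tile` and `tile_to_fragment` are hypothetical and are not the tilelang API; the real `make_mma_load_base_layout` may differ in details such as how it handles other dtypes or transposed operands.

```python
from __future__ import annotations

# Hypothetical model of the PTX mma.m16n8k16 fp16 A-fragment layout.
# Not tilelang's make_mma_load_base_layout; names are illustrative only.
WARP_SIZE = 32
TILE_M, TILE_K = 16, 16
LOCAL_SIZE = TILE_M * TILE_K // WARP_SIZE  # 8 fp16 elements per lane


def fragment_to_tile(lane: int, local: int) -> tuple[int, int]:
    """Map (lane, local register index) to (row, col) in the 16x16 A tile."""
    group_id = lane // 4          # "groupID" in the PTX docs
    thread_in_group = lane % 4    # "threadID_in_group"
    row = group_id + 8 * ((local // 2) % 2)                      # a2,a3,a6,a7 sit 8 rows lower
    col = thread_in_group * 2 + (local % 2) + 8 * (local // 4)   # a4..a7 sit 8 cols right
    return row, col


def tile_to_fragment(row: int, col: int) -> tuple[int, int]:
    """Inverse mapping: (row, col) -> (lane, local register index)."""
    lane = (row % 8) * 4 + (col % 8) // 2
    local = (col % 2) + 2 * (row // 8) + 4 * (col // 8)
    return lane, local


# Round-trip check: every tile element is owned by exactly one (lane, local).
seen = set()
for lane in range(WARP_SIZE):
    for local in range(LOCAL_SIZE):
        r, c = fragment_to_tile(lane, local)
        assert tile_to_fragment(r, c) == (lane, local)
        seen.add((r, c))
assert len(seen) == TILE_M * TILE_K
```

The forward map is the "fragment index to thread and local index" relation read in the direction a kernel uses when storing; the inverse is what a layout function keyed on tile coordinates would compute.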
    
    * Enhance MMA load test with additional imports and functionality
    
    - Added imports for `tilelang.language`, `Literal`, `Callable`, `DataType`, `IndexMap`, and `get_mma_micro_size` to support extended functionality.
    - Improved the `make_mma_load_base_layout` function by ensuring it can handle various data types and configurations.
    - Updated the test function `test_mma_load_base_layout` to validate the layout for float16 matrix A.
    
    * Fix formatting in `test_fragment_mma_load_a.py` by adding a blank line for improved readability.
    
    * Add RMS normalization functions to `test_rms_norm.py`
    
    - Introduced `rms_norm` and `rms_norm_splitk` functions for RMS normalization, enhancing the testing capabilities.
    - Implemented kernel functions with shared memory allocation and parallel processing for improved performance.
    - Updated the test function to validate the new RMS normalization implementations.
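The split-k idea can be modelled in NumPy: instead of reducing a whole row at once, the sum of squares is accumulated over fixed-width column chunks and the row is normalized afterward. This is a sketch under assumptions: the chunk width `blk_k`, the `eps` value, the absence of a learned scale, and the name `rms_norm_splitk_model` are all illustrative, and the tilelang kernel's shared-memory staging and thread mapping are not reproduced.

```python
import numpy as np


def rms_norm_splitk_model(x: np.ndarray, blk_k: int = 128, eps: float = 1e-12) -> np.ndarray:
    """Row-wise RMS norm with the reduction split into column chunks of width blk_k.

    Hypothetical model of a split-k RMS norm; not the tilelang kernel itself.
    """
    m, n = x.shape
    assert n % blk_k == 0, "this sketch assumes N is a multiple of blk_k"
    x32 = x.astype(np.float32)                 # accumulate in fp32
    sum_sq = np.zeros((m, 1), dtype=np.float32)
    for k0 in range(0, n, blk_k):              # one pass per column chunk
        chunk = x32[:, k0:k0 + blk_k]
        sum_sq += np.sum(chunk * chunk, axis=1, keepdims=True)
    inv_rms = 1.0 / np.sqrt(sum_sq / n + eps)
    return (x32 * inv_rms).astype(x.dtype)


# The chunked reduction should match a single full-row reduction.
x = np.random.default_rng(0).standard_normal((4, 512)).astype(np.float32)
full = x / np.sqrt(np.mean(x * x, axis=1, keepdims=True) + 1e-12)
np.testing.assert_allclose(rms_norm_splitk_model(x), full, rtol=1e-5, atol=1e-6)
```

Splitting the reduction lets a kernel stream a long row through a fixed-size buffer; the result is unchanged up to floating-point summation order.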
    
    * Add reference program for RMS normalization in `test_rms_norm.py`
    
    - Introduced `ref_program` function to provide a reference implementation for RMS normalization.
    - This addition enhances the testing framework by allowing comparisons against a known reference output.
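A reference of this kind is usually a few lines of framework code. The NumPy version below is an assumption about its shape (row-wise mean of squares, `eps = 1e-12`, no learned weight), not a copy of the committed `ref_program`; `ref_rms_norm` is a hypothetical name.

```python
import numpy as np


def ref_rms_norm(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Reference RMS norm: y = x / sqrt(mean(x**2, axis=-1) + eps).

    Sketch only; eps and the absence of a weight vector are assumptions.
    """
    x32 = x.astype(np.float32)
    inv_rms = 1.0 / np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)
    return (x32 * inv_rms).astype(x.dtype)


# RMS of [3, 4] is sqrt(12.5) ~= 3.5355, so y ~= [[0.8485, 1.1314]].
x = np.array([[3.0, 4.0]], dtype=np.float32)
y = ref_rms_norm(x)
```

In a test, a kernel's output would then be compared against this reference with tolerances suited to the kernel's dtype (looser for fp16 than for fp32).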
    
    * Enhance RMS normalization tests with additional imports and formatting
    
    - Added import for `tilelang.language` to support extended functionality in `test_rms_norm.py`.
    - Improved code readability by adding blank lines for better separation of code sections.
    
    * Update RMS normalization test parameters and enhance layout plotting
    
    - Increased matrix dimensions in `test_rms_norm` to 8192 for improved performance testing.
    - Removed obsolete test functions in `test_fragment_mma_load_a.py` to streamline the test suite.
    - Enhanced layout plotting functionality by ensuring proper visualization of base, warp, and block layouts in `test_fragment_mma_load_a.py`.
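As a text stand-in for what a layout plot conveys, the sketch below prints, for each cell of a 16×16 tile, the lane that owns it. It assumes the PTX `mma.m16n8k16` fp16 A-fragment layout; `owner_lane` and `render_lane_grid` are hypothetical helpers and are not the tilelang plotting utility.

```python
def owner_lane(row: int, col: int) -> int:
    """Lane owning (row, col) under the assumed m16n8k16 fp16 A-fragment layout."""
    return (row % 8) * 4 + (col % 8) // 2


def render_lane_grid(rows: int = 16, cols: int = 16) -> str:
    """Return a text grid with the owning lane id in each cell."""
    lines = []
    for r in range(rows):
        lines.append(" ".join(f"{owner_lane(r, c):2d}" for c in range(cols)))
    return "\n".join(lines)


print(render_lane_grid())
```

In the printed grid, rows `r` and `r + 8` are identical, and each lane owns two adjacent columns in each 8-column half, which is the pattern a warp- or block-level plot would tile across larger regions.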
    
    * Refactor RMS normalization test parameters and improve layout plotting readability
    
    - Simplified the parameters in `test_rms_norm` by removing `blk_k` for clarity.
    - Enhanced code readability in `test_fragment_mma_load_a.py` by adjusting the formatting of the `block_layout` definition and removing the unused `warp_cols` variable.
    
    * Enhance RMS normalization with split-k implementation and additional profiling
    
    - Added a new function `test_rms_norm_splitk` to test the split-k variant of RMS normalization.
    - Updated the main RMS normalization script to include profiling for the split-k implementation.
    - Ensured all checks pass with appropriate latency measurements for both reference and tile-lang implementations.
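For a memory-bound kernel like RMS norm, latency is often easier to compare as effective bandwidth. The helper below is a back-of-the-envelope sketch, assuming the kernel reads the M×N input once and writes the M×N output once (ignoring per-row statistics) and assuming fp16 elements by default; `effective_bandwidth_gbps` is hypothetical, and the latency below is a placeholder, not a measurement from this commit.

```python
def effective_bandwidth_gbps(m: int, n: int, latency_ms: float, bytes_per_elem: int = 2) -> float:
    """Effective DRAM bandwidth in GB/s for one read and one write of an m x n tensor."""
    bytes_moved = 2 * m * n * bytes_per_elem  # read input + write output
    return bytes_moved / (latency_ms * 1e-3) / 1e9


# 8192 x 8192 fp16: 2 * 8192 * 8192 * 2 bytes = 268,435,456 bytes (256 MiB) per call.
print(f"{effective_bandwidth_gbps(8192, 8192, latency_ms=1.0):.1f} GB/s")  # placeholder latency -> 268.4 GB/s
```

Dividing the result by the device's peak bandwidth gives a rough utilization figure for the reference and tile-lang runs alike.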
    
    * Remove obsolete test file `test_fragment_mma_load_a.py` to streamline the test suite.
    
    * Refactor `rms_norm.py` to streamline benchmarking output and remove redundant code. Comment out the `plot_layout` call in `fragment_mma_load_a.py` for clarity.
    
    * Refactor `test_rms_norm.py` by removing redundant test function `test_rms_norm_splitk` to streamline the test suite and improve clarity.
    
    ---------
    Co-authored-by: Your Name <you@example.com>
rms_norm.py 2.88 KB