    [Refactor] Update barrier functions and add new example for GEMM with warp specialization (#456) · a91bc2a9
    Lei Wang authored
    * Add example for warp specialization with flash attention
    
    * Introduced a new example script `example_warp_specialize_flashmla.py` demonstrating flash attention using warp specialization in TileLang.
    * Implemented the `flashattn` function with shared memory allocation and memory barrier synchronization for improved performance.
    * Added a reference program for validation against PyTorch's implementation, including profiling for latency and performance metrics.
    * Removed the outdated `example_warp_specialize_mla.py` to streamline examples and focus on the new implementation.
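
For context, the core of flash attention is an online softmax. Keys and values are processed block by block, and a running maximum and running denominator are rescaled whenever a larger score appears. The sketch below is a minimal pure-Python illustration for a single query row, assuming float inputs; `naive_attention`, `blockwise_attention` and `block_size` are hypothetical names. It is not the TileLang kernel, only the math such a kernel implements, checked against a naive softmax in the same spirit as the example's reference program.

```python
import math


def naive_attention(q, keys, values):
    """Reference: softmax(q . k_j / sqrt(d))-weighted sum of the value rows."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim_v = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(dim_v)]


def blockwise_attention(q, keys, values, block_size=2):
    """Online-softmax attention for one query row, one K/V block at a time."""
    d = len(q)
    dim_v = len(values[0])
    running_max = -math.inf  # largest score seen so far
    running_sum = 0.0        # softmax denominator relative to running_max
    acc = [0.0] * dim_v      # unnormalized output relative to running_max
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale the previously accumulated state to the new maximum
        # (exp(-inf) == 0.0 on the first block, so the empty state stays zero).
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            for i in range(dim_v):
                acc[i] += w * v[i]
        running_max = new_max
    return [a / running_sum for a in acc]
```

The result does not depend on `block_size` beyond floating-point rounding, which is what allows a kernel to stream K and V tiles through shared memory.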
    
    * Add memory barrier functions to builtin.py
    
    * Introduced `barrier_wait` and `barrier_arrive` functions for memory barrier synchronization.
    * Enhanced documentation with detailed docstrings for both functions, clarifying their usage and parameters.
    * The `barrier_wait` function serves as a wrapper for `mbarrier_wait_parity`, supporting parity values 0 and 1.
    * Improved code organization and readability by adding blank lines for better separation of logical sections.
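
A parity wait can be pictured with a toy model. The sketch below is single-threaded, hypothetical and not TileLang's or CUDA's implementation; `MBarrierModel`, `arrive` and `test_wait` are illustrative names. It assumes the barrier is initialized with an expected arrival count, that completing a phase flips the barrier's parity bit, and that a waiter is never more than one phase behind. Under those assumptions, waiting on parity `p` succeeds once the phase with parity `p` has completed, which is the behaviour a wrapper restricted to parity values 0 and 1 relies on.

```python
class MBarrierModel:
    """Toy model of a phase-tracking memory barrier (hypothetical)."""

    def __init__(self, expected_arrivals):
        self.expected = expected_arrivals
        self.pending = expected_arrivals
        self.phase_parity = 0  # parity of the phase currently in progress

    def arrive(self):
        """Register one arrival; the last arrival completes the phase."""
        self.pending -= 1
        if self.pending == 0:
            self.phase_parity ^= 1          # next phase has the opposite parity
            self.pending = self.expected    # re-arm for the next phase

    def test_wait(self, parity):
        """Return True once the phase with the given parity has completed.

        A real barrier_wait would spin on this check instead of returning.
        """
        if parity not in (0, 1):
            raise ValueError("parity must be 0 or 1")
        return self.phase_parity != parity
```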
    
* Enhance code readability by adding blank lines in `example_warp_specialize_flashmla.py` and `builtin.py`
    
    * Added blank lines to improve code organization and separation of logical sections in `example_warp_specialize_flashmla.py`.
    * Included blank lines in `builtin.py` around the `wait_wgmma` and `barrier_wait` functions for better readability.
    
    * [Refactor] Update barrier functions and add new example for GEMM with warp specialization
    
    * Refactored the memory barrier calls in `example_warp_specialize_flashmla.py` to use the new `barrier_wait` and `barrier_arrive` functions for improved clarity and consistency.
    * Introduced a new example script `example_warp_specialize_gemm_copy_gemm_0_1.py` demonstrating matrix multiplication with warp specialization and shared memory allocation.
    * Enhanced the `layout.cc` and `elem.cc` files to improve structural equality checks and error handling in copy operations.
    * Updated `warpgroup.py` to refine thread ID calculations for better performance in warp specialization scenarios.
    * Added new shuffle operations in `builtin.py` for enhanced functionality in parallel computations.
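
Warp-specialized kernels typically split threads into warp groups by thread index and cycle producer and consumer roles through a ring of shared-memory buffers, each guarded by its own barrier. The sketch below shows the index arithmetic under stated assumptions: warp groups of 128 threads (4 warps of 32), and a pipeline whose iteration `k` uses slot `k % num_stages`. Both `warpgroup_coords` and `stage_and_parity` are hypothetical helpers, and the exact partitioning in `warpgroup.py` may differ.

```python
WARPGROUP_SIZE = 128  # assumption: 4 warps x 32 threads per warp group


def warpgroup_coords(thread_id):
    """Split a flat thread index into (warp group id, index within the group)."""
    return thread_id // WARPGROUP_SIZE, thread_id % WARPGROUP_SIZE


def stage_and_parity(k, num_stages):
    """Buffer slot and barrier parity for pipeline iteration k.

    Slot k % num_stages is reused every num_stages iterations, and each reuse
    waits on the next phase of that slot's barrier, so the parity toggles
    once per full pass over the ring of buffers.
    """
    return k % num_stages, (k // num_stages) % 2
```

With `num_stages = 2`, iterations 0 through 5 map to `(0, 0), (1, 0), (0, 1), (1, 1), (0, 0), (1, 0)`, so a single parity bit per slot is enough to tell consecutive uses of the same buffer apart.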
    
    * lint fix
    
    * Update loop variable checks in SIMT loop and buffer region validation
    
    * Modified checks in `elem.cc` to ensure loop variable sizes are less than or equal to source and destination range sizes for better error handling.
    * Adjusted assertions in `copy.py` to reflect the updated logic, allowing for more flexible region extent comparisons and improved error messaging.
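
The relaxed rule can be expressed as a small validation function. This is a hypothetical Python sketch, not the C++ check in `elem.cc` or the assertion in `copy.py`. It assumes that loop dimensions line up one-to-one with source and destination region dimensions (the rank check is an assumption of the sketch), and it requires each loop extent to be less than or equal to both matching region extents, reporting the offending dimension in the error message.

```python
def check_copy_extents(loop_extents, src_extents, dst_extents):
    """Validate a SIMT copy loop against its source and destination regions.

    Hypothetical helper: each loop extent must be <= the matching source
    and destination extents.
    """
    # Assumption of this sketch: dimensions correspond one-to-one.
    if not (len(loop_extents) == len(src_extents) == len(dst_extents)):
        raise ValueError("rank mismatch between loop, source and destination")
    for dim, (loop, src, dst) in enumerate(zip(loop_extents, src_extents, dst_extents)):
        if loop > src or loop > dst:
            raise ValueError(
                f"dim {dim}: loop extent {loop} exceeds source extent {src} "
                f"or destination extent {dst}"
            )
```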
    
    * lint fix
    
    * test fix