    [Enhancement] Add new examples for warp specialization and TMA integration (#448) · b5faf25a
    Lei Wang authored
    * [Refactor] Update KernelLaunch to clarify CPU and GPU kernel launch logic
    
    * Added comments to distinguish between CPU and GPU kernel launch sections for better code readability.
    * Changed the creation of empty blocks to use a consistent "root" identifier, enhancing clarity in frame management.
    
    * [Refactor] Rename operations for consistency in lower_hopper_intrin and related files
    
    * Updated function names from CamelCase to snake_case for better consistency across the codebase.
    * Refactored calls to `CreateTMADescriptorOp`, `CreateListofMBarrierOp`, and similar functions to their new names: `create_tma_descriptor`, `create_list_of_mbarrier`, etc.
    * Adjusted corresponding test cases to reflect these changes, ensuring compatibility with the new naming conventions.
    
    * [Refactor] Rename operations to snake_case for consistency
    
    * Updated function names from CamelCase to snake_case across various files, including `CreateTMADescriptorOp` to `create_tma_descriptor`, `GetMBarrierOp` to `get_mbarrier`, and others.
    * Adjusted corresponding calls and definitions in the codebase to reflect these naming changes, ensuring uniformity and improved readability.
    * Enhanced layout inference and loop partitioning logic to accommodate the new naming conventions.
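Downstream scripts may still call the old CamelCase names. A small source rewrite can apply the renames listed above. The sketch below is a hypothetical migration helper and is not part of TileLang. Its table covers only the three renames named in this commit, so any other renamed operation would have to be added by hand.

```python
import re

# Old CamelCase op names -> new snake_case names, as listed in this commit.
# Only these three pairs are stated; extend the table for other renamed ops.
RENAMES = {
    "CreateTMADescriptorOp": "create_tma_descriptor",
    "CreateListofMBarrierOp": "create_list_of_mbarrier",
    "GetMBarrierOp": "get_mbarrier",
}

# Match whole identifiers only, so e.g. "MyGetMBarrierOpX" is left alone.
_PATTERN = re.compile(r"\b(" + "|".join(map(re.escape, RENAMES)) + r")\b")


def migrate_source(text: str) -> str:
    """Hypothetical helper: rewrite old op names in a script's source text."""
    return _PATTERN.sub(lambda m: RENAMES[m.group(1)], text)


if __name__ == "__main__":
    before = "desc = CreateTMADescriptorOp(...)\nbar = GetMBarrierOp(0)"
    print(migrate_source(before))
```

The rewrite is purely textual. It does not distinguish code from strings or comments, so review the diff before committing it.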
    
    * [Feature] Introduce Warp Specialization and Eliminate Storage Sync for MBarrier
    
    * Added a new example `gemm_ws.py` demonstrating matrix multiplication with warp specialization using TileLang.
    * Implemented `WarpSpecializeFrame` and `WarpSpecialize` functionality to manage warp group indices in TIR frames.
    * Introduced `EliminateStorageSyncForMBarrier` transformation to optimize storage synchronization in mbarrier regions.
    * Enhanced the TileLang API with new methods for retrieving block and thread extents.
    * Updated the `LowerAndLegalize` and `OptimizeForTarget` functions to incorporate the new transformation.
    * Improved layout inference and kernel launch logic for better performance and clarity.
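Warp specialization splits a thread block into warp groups with different roles. Typically one producer group issues TMA loads while one or more consumer groups run the GEMM. On Hopper, a warp group is four consecutive warps (128 threads).

The sketch below illustrates only the index arithmetic behind that split. `warp_group_role` and its `producer_group` parameter are hypothetical; they are not the signature of TileLang's `WarpSpecialize`.

```python
WARP_SIZE = 32
WARPS_PER_GROUP = 4                              # Hopper warp group = 4 warps
THREADS_PER_GROUP = WARP_SIZE * WARPS_PER_GROUP  # 128 threads


def warp_group_index(thread_idx: int) -> int:
    """Warp group that a thread (threadIdx.x in a 1-D block) belongs to."""
    return thread_idx // THREADS_PER_GROUP


def warp_group_role(thread_idx: int, producer_group: int = 0) -> str:
    """Hypothetical role assignment: one producer group, the rest consume."""
    if warp_group_index(thread_idx) == producer_group:
        return "producer"
    return "consumer"


if __name__ == "__main__":
    num_threads = 256  # two warp groups
    roles = [warp_group_role(t) for t in range(num_threads)]
    print(roles.count("producer"), roles.count("consumer"))  # 128 128
```

Passing `producer_group=1` flips the layout, so the second warp group issues the loads and the first one computes.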
    
    * [Refactor] Clean up code formatting and improve readability
    
    * Added blank lines for better separation of code blocks in `gemm_ws.py`, `phase.py`, `kernel.py`, and `warpgroup.py`.
    * Reformatted the `tilelang.compile` call in `gemm_ws.py` for improved clarity.
    * Updated comments in `warpgroup.py` to clarify the availability of the `WarpSpecialize` function for NVIDIA GPUs.
    * Ensured consistent spacing and formatting across multiple files to enhance overall code readability.
    
    * lint fix
    
    * [Refactor] Update mbarrier functions for improved clarity and consistency
    
    * Refactored `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to accept explicit parameters for better readability.
    * Updated calls in `gemm_ws.py` to use the new function signatures, enhancing code clarity.
    * Adjusted `warpgroup.py` to remove an unused thread extent variable, streamlining the code.
    * Added detailed docstrings to clarify usage examples for memory barrier functions.
    
    * Added blank lines in `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` for improved code readability and separation of logical sections.
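The parity argument of `mbarrier_wait_parity` follows the hardware mbarrier's phase bit. Each time the expected number of arrivals is reached, the current phase completes and the phase bit flips. A wait on a given parity returns once the phase with that parity has completed.

The pure-Python model below illustrates this behavior. It also shows the common way a pipelined loop derives the buffer stage and the wait parity from the iteration count. The class and function names are illustrative and are not TileLang's API.

```python
from typing import Tuple


class MBarrierModel:
    """Single-threaded model of an mbarrier's arrival count and phase bit."""

    def __init__(self, expected_arrivals: int):
        self.expected = expected_arrivals
        self.pending = expected_arrivals
        self.phase = 0  # parity of the phase currently in progress

    def arrive(self) -> None:
        self.pending -= 1
        if self.pending == 0:      # phase complete: flip the bit, re-arm count
            self.phase ^= 1
            self.pending = self.expected

    def phase_done(self, parity: int) -> bool:
        """True once the phase with this parity has completed, i.e. the
        condition a wait on that parity blocks for."""
        return self.phase != parity


def stage_and_parity(k: int, num_stages: int) -> Tuple[int, int]:
    """Ring-buffer slot used by iteration k, and the parity to wait on:
    slot k % num_stages is being used for the (k // num_stages)-th time."""
    return k % num_stages, (k // num_stages) & 1


if __name__ == "__main__":
    print([stage_and_parity(k, 2) for k in range(4)])
    # [(0, 0), (1, 0), (0, 1), (1, 1)]
```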
    
    * [Feature] Add examples for warp specialization and TMA barrier integration
    
    * Introduced three new example scripts: `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, and `example_warp_specialize_mla.py` demonstrating matrix multiplication with warp specialization and TMA barriers.
    * Implemented kernel functions with shared memory allocation and memory barrier synchronization for improved performance.
    * Enhanced the TileLang API with new methods for compiling and testing kernels in Python using PyTorch.
    * Updated `phase.py` to include TMA barrier injection in the optimization pipeline.
    * Improved documentation and comments for better clarity on usage and functionality.
    
    * [Feature] Add example for warp specialization in GEMM with TMA barriers
    
    * Introduced a new example script `example_warp_specialize_gemm_stage2.py` demonstrating matrix multiplication using warp specialization and TMA barriers.
    * Implemented a kernel function with shared memory allocation and memory barrier synchronization for enhanced performance.
    * Included functionality to compile the kernel into a PyTorch-compatible function and validate its correctness against PyTorch's reference implementation.
    * Enhanced documentation and comments for clarity on usage and functionality.
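These examples share one structure: a producer fills a small ring of shared-memory buffers, and a consumer multiplies out of it. That structure can be mimicked on the CPU to check the synchronization logic.

The sketch below is an analogy only, with these stand-ins:
- Python threads play the role of warp groups.
- `threading.Semaphore` objects play the role of the "full" and "empty" mbarriers.
- Plain lists play the role of shared memory.

The result is checked against a naive matrix product, much as the examples check against PyTorch.

```python
import threading


def pipelined_matmul(A, B, block_k=2, num_stages=2):
    """C = A @ B, with K tiles passed from a producer to a consumer thread
    through a ring of `num_stages` buffers (CPU analogy of a WS GEMM)."""
    M, K, N = len(A), len(B), len(B[0])
    num_tiles = (K + block_k - 1) // block_k
    buffers = [None] * num_stages
    full = [threading.Semaphore(0) for _ in range(num_stages)]   # slot filled
    empty = [threading.Semaphore(1) for _ in range(num_stages)]  # slot free
    C = [[0.0] * N for _ in range(M)]

    def producer():
        for t in range(num_tiles):
            s = t % num_stages
            empty[s].acquire()                 # wait until consumer freed slot s
            ks = range(t * block_k, min((t + 1) * block_k, K))
            a_tile = [[A[i][k] for k in ks] for i in range(M)]
            b_tile = [B[k][:] for k in ks]
            buffers[s] = (a_tile, b_tile)      # stands in for a TMA load
            full[s].release()                  # signal: slot s is ready

    def consumer():
        for t in range(num_tiles):
            s = t % num_stages
            full[s].acquire()                  # wait for the producer's data
            a_tile, b_tile = buffers[s]
            for i in range(M):
                for j in range(N):
                    C[i][j] += sum(a * b_tile[kk][j]
                                   for kk, a in enumerate(a_tile[i]))
            empty[s].release()                 # hand slot s back to producer

    threads = [threading.Thread(target=producer),
               threading.Thread(target=consumer)]
    for th in threads:
        th.start()
    for th in threads:
        th.join()
    return C


def reference_matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]


if __name__ == "__main__":
    A = [[1, 2, 3], [4, 5, 6]]
    B = [[7, 8], [9, 10], [11, 12]]
    assert pipelined_matmul(A, B) == reference_matmul(A, B)
    print("pipelined result matches reference")
```

Each `empty` semaphore starts at 1 because every slot is free before the first load. This mirrors how a producer can issue its first `num_stages` loads without waiting on the consumer.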
    
    * lint fix
    
    * [Feature] Implement WarpSpecializedDetector for TMA and MBarrier Detection
    
    * Added the `WarpSpecializedDetector` class to identify the presence of TMA operations and memory barrier operations within a given TIR statement.
    * Enhanced the `WarpSpecialized` pass to utilize the detector, allowing for conditional substitution based on the detection results.
    * Improved code organization by including necessary headers and utilizing the `IRVisitorWithAnalyzer` for analysis.
    * This addition aims to optimize warp specialization by ensuring that only relevant functions are transformed, enhancing performance and correctness.
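The detector is a read-only visitor. It walks a statement once and records whether any TMA or mbarrier operation appears; the pass then decides what to do from those flags.

The sketch below models that idea on a toy IR of nested tuples. The node shapes and the op-name sets are assumptions chosen for illustration; `tma_load` is a hypothetical name, and the mbarrier names reuse the snake_case builtins mentioned in this commit. This is not TileLang's C++ `WarpSpecializedDetector`.

```python
from dataclasses import dataclass

# Assumed op-name sets for the toy IR; adjust to the real builtins in use.
TMA_OPS = {"create_tma_descriptor", "tma_load"}
MBARRIER_OPS = {"create_list_of_mbarrier", "get_mbarrier",
                "mbarrier_arrive", "mbarrier_wait_parity"}


@dataclass
class Detection:
    has_tma: bool = False
    has_mbarrier: bool = False


def detect(stmt, result=None) -> Detection:
    """Recursively visit a toy IR node: ("call", name, *args) or any
    tuple/list of child nodes; every other value is a leaf."""
    if result is None:
        result = Detection()
    if isinstance(stmt, tuple) and stmt and stmt[0] == "call":
        name = stmt[1]
        result.has_tma |= name in TMA_OPS
        result.has_mbarrier |= name in MBARRIER_OPS
        children = stmt[2:]
    elif isinstance(stmt, (tuple, list)):
        children = stmt
    else:
        return result
    for child in children:
        detect(child, result)
    return result


if __name__ == "__main__":
    body = ("seq", [
        ("call", "mbarrier_wait_parity", ("call", "get_mbarrier", 0), 0),
        ("for", "k", 4, ("call", "tma_load", "A_shared")),
    ])
    print(detect(body))  # Detection(has_tma=True, has_mbarrier=True)
```

A pass can branch on these flags, for example to leave alone a function that already manages TMA and mbarriers by hand. The branching policy itself is left out of the sketch.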
    
    * lint fix
    
    * [Feature] Add new examples for warp specialization and TMA integration
    
    * Introduced multiple new example scripts demonstrating warp specialization techniques, including `example_warp_specialize_flashmla.py`, `example_warp_specialize_gemm_barrierpipe_stage2.py`, `example_warp_specialize_gemm_copy_0_gemm_1.py`, `example_warp_specialize_gemm_copy_1_gemm_0.py`, and `example_warp_specialize_gemm_softpipe_stage2.py`.
    * Each example showcases matrix multiplication with warp specialization and TMA barriers, implementing kernel functions with shared memory allocation and memory barrier synchronization for enhanced performance.
    * Added a test suite in `test_example_warp_specialize.py` to validate the functionality of the new examples.
    * Updated the TileLang API to support these examples and improve kernel compilation and testing processes.
    * Removed outdated example scripts to streamline the codebase and make it clearer which examples are current.
    
    * lint fix
    
    * Remove outdated example scripts for warp specialization and TMA integration to streamline the codebase. This includes `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, `example_warp_specialize_gemm_stage2.py`, and `example_warp_specialize_mla.py`, which are no longer needed following recent updates and improvements in the TileLang API.