    [Warp Specialize] Implicit Warp Specialize Programming Model (#605) · e2d25ba8
    Lei Wang authored
    * [Enhancement] Improve memory access condition checks in GlobalMemChecker
    
    - Updated the bound checks in GlobalMemChecker to call `CanProve` with symbolic-bound proof strength, making memory access validation more accurate.
    - Both the upper- and lower-bound conditions are now evaluated with this stronger proof strength, making the memory access analysis more robust.
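
    As a rough illustration of what checking "both upper and lower bound conditions" means, the pure-Python sketch below bounds an affine index `base + stride * i` and tests it against a buffer extent. `Interval`, `affine_bound`, and `access_is_safe` are hypothetical names; the actual checker reasons about TVM symbolic expressions through `CanProve`, not about concrete integers.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Interval:
    """Closed integer interval [lo, hi], a stand-in for a symbolic bound."""
    lo: int
    hi: int


def affine_bound(base: int, stride: int, loop_extent: int) -> Interval:
    """Bound of base + stride * i for i in [0, loop_extent); loop_extent >= 1."""
    last = base + stride * (loop_extent - 1)
    return Interval(min(base, last), max(base, last))


def access_is_safe(index: Interval, buffer_extent: int) -> bool:
    """An access is safe only if both bounds hold: 0 <= lo and hi < extent."""
    lower_ok = index.lo >= 0
    upper_ok = index.hi < buffer_extent
    return lower_ok and upper_ok


# Hypothetical usage: a 128-iteration loop reading a 128-element buffer.
print(access_is_safe(affine_bound(0, 1, 128), 128))  # True
print(access_is_safe(affine_bound(1, 1, 128), 128))  # False: index 128 is out of bounds
```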
    
    * lintfix
    
    * [Enhancement] Add legality checks for shared memory and global range in LowerBulkCopy
    
    - Implemented checks to ensure that the shared memory range and global range are legal during the bulk copy operation.
    - Added assertions to validate that the extents of global and shared ranges match, improving the robustness of memory access validation in the LowerBulkCopy function.
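
    A minimal sketch of the two legality checks, assuming each range is a list of `(min, extent)` pairs and assuming unit-extent dimensions are ignored when comparing extents (so a `(1, 64, 64)` global tile may match a `(64, 64)` shared tile). `range_is_legal`, `squeeze_extents`, and `check_bulk_copy` are hypothetical helpers, not functions from the code base; the sketch raises `ValueError` instead of using `assert`, which Python strips under `-O`.

```python
from typing import List, Sequence, Tuple

Range = Sequence[Tuple[int, int]]  # one (min, extent) pair per dimension


def range_is_legal(rng: Range, shape: Sequence[int]) -> bool:
    """Legal if ranks agree and every [min, min + extent) fits its dimension."""
    if len(rng) != len(shape):
        return False
    return all(lo >= 0 and ext > 0 and lo + ext <= dim
               for (lo, ext), dim in zip(rng, shape))


def squeeze_extents(rng: Range) -> List[int]:
    """Drop unit extents so tiles of different rank can be compared."""
    return [ext for _, ext in rng if ext != 1]


def check_bulk_copy(global_rng: Range, global_shape: Sequence[int],
                    shared_rng: Range, shared_shape: Sequence[int]) -> None:
    """Raise ValueError if a bulk copy between the two ranges is not legal."""
    if not range_is_legal(global_rng, global_shape):
        raise ValueError("global range is out of bounds")
    if not range_is_legal(shared_rng, shared_shape):
        raise ValueError("shared range is out of bounds")
    if squeeze_extents(global_rng) != squeeze_extents(shared_rng):
        raise ValueError("global and shared extents do not match")


# Hypothetical usage: copy a 64x64 tile out of a (4, 256, 128) global tensor.
check_bulk_copy([(0, 1), (64, 64), (0, 64)], (4, 256, 128),
                [(0, 64), (0, 64)], (64, 64))
```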
    
    * [Refactor] Update barrier and clear operations in warp specialization examples
    
    - Replaced `mbarrier_wait_parity` and `mbarrier_arrive` with `barrier_wait` and `barrier_arrive` for improved clarity and consistency in synchronization.
    - Adjusted the order of `clear` operations for local fragments in `example_warp_specialize_gemm_copy_1_gemm_0` to enhance parallel execution efficiency.
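
    `barrier_wait` and `barrier_arrive` presumably still lower to the PTX `mbarrier` arrive/wait-on-parity protocol that `mbarrier_wait_parity` and `mbarrier_arrive` exposed directly (an assumption). The hypothetical single-threaded model below, not TileLang's API, sketches those semantics: every `expected_arrivals` arrivals complete a phase and flip the parity bit, and a wait on parity `p` succeeds once the phase with parity `p` has completed.

```python
class MBarrierModel:
    """Hypothetical single-threaded model of an mbarrier-style phase barrier."""

    def __init__(self, expected_arrivals: int):
        self.expected = expected_arrivals
        self.pending = expected_arrivals
        self.completed_phases = 0

    def arrive(self) -> None:
        # Each arrival decrements the pending count; the last one completes
        # the current phase, flips the parity bit, and re-arms the barrier.
        self.pending -= 1
        if self.pending == 0:
            self.completed_phases += 1
            self.pending = self.expected

    def try_wait_parity(self, parity: int) -> bool:
        # The phase with this parity has completed once the parity of the
        # current (in-progress) phase differs from it.
        return (self.completed_phases & 1) != parity


bar = MBarrierModel(expected_arrivals=2)
print(bar.try_wait_parity(0))  # False: no arrivals yet
bar.arrive()
bar.arrive()
print(bar.try_wait_parity(0))  # True: phase 0 has completed
```

    In a pipelined loop the consumer typically toggles the parity it waits on each iteration (for example `k % 2`), so a single phase bit is enough as long as waiters never fall more than one phase behind.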
    
    * [Enhancement] Implement partial thread synchronization and improve shared memory allocation handling
    
    - Added CUDA support for barrier synchronization over a subset of threads (partial thread barriers), allowing more flexible thread management.
    - Enhanced the `MergeSharedMemoryAllocations` function to accept alignment bytes, improving memory allocation efficiency based on target requirements.
    - Updated the `Lower` methods in `Copy` and `Fill` classes to include conditional predicates for thread execution, ensuring better control over thread behavior.
    - Refactored the `print` function to include warp group and warp IDs for more detailed debugging output.
    - Improved the handling of dynamic shared memory allocations in the `LowerAndLegalize` function to align with target-specific requirements.
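
    Passing alignment bytes to `MergeSharedMemoryAllocations` means each buffer's offset in the merged dynamic shared memory region is rounded up to the target's alignment. The sketch below shows only that arithmetic with simple linear packing; the real pass may also reuse memory between buffers with disjoint lifetimes, and `merge_allocations` as well as the 128- and 16-byte alignments are illustrative assumptions.

```python
def align_up(value: int, alignment: int) -> int:
    """Round value up to the next multiple of a power-of-two alignment."""
    if alignment <= 0 or alignment & (alignment - 1):
        raise ValueError("alignment must be a positive power of two")
    return (value + alignment - 1) & ~(alignment - 1)


def merge_allocations(sizes: dict, align_bytes: int) -> tuple:
    """Pack buffers (name -> size in bytes) into one region in insertion order.

    Returns (offsets, total_bytes); every offset and the total are aligned.
    """
    offsets = {}
    cursor = 0
    for name, size in sizes.items():
        cursor = align_up(cursor, align_bytes)
        offsets[name] = cursor
        cursor += size
    return offsets, align_up(cursor, align_bytes)


offsets, total = merge_allocations({"A_shared": 100, "B_shared": 64, "barrier": 8}, 128)
print(offsets, total)  # {'A_shared': 0, 'B_shared': 128, 'barrier': 256} 384
```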
    
    * [Enhancement] Add support for disabling TMA in Copy operations
    
    - Introduced a new `disable_tma` parameter in the `Copy` class to control whether a copy may be lowered to the Tensor Memory Accelerator (TMA).
    - Updated the `Lower` method to conditionally execute bulk copy operations based on the `disable_tma` flag.
    - Enhanced the `copy` function to accept the `disable_tma` argument, allowing for more flexible memory copy operations.
    - Improved handling of `coalesced_width` to ensure it defaults to -1 when not provided, enhancing robustness in memory operations.
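
    The dispatch described above can be sketched as follows. This is a hypothetical model, not the actual `Copy::Lower` code: `is_bulk_copy_candidate` stands for whatever conditions make a copy eligible for TMA, and reading `-1` as "width not specified, let lowering decide" is an assumption.

```python
from typing import Optional, Tuple


def choose_copy_path(is_bulk_copy_candidate: bool,
                     disable_tma: bool = False,
                     coalesced_width: Optional[int] = None) -> Tuple[str, int]:
    """Return (lowering path, coalesced width) for a copy (illustrative only)."""
    # Normalize a missing width to the -1 sentinel instead of propagating None.
    width = -1 if coalesced_width is None else coalesced_width
    # The TMA bulk copy is used only when the copy qualifies and TMA is enabled.
    if is_bulk_copy_candidate and not disable_tma:
        return "bulk_copy", width
    # Otherwise fall back to an ordinary per-thread copy.
    return "thread_copy", width
```

    Normalizing `None` to `-1` up front means downstream code only ever handles integers.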
    
    * [Refactor] Clean up whitespace and formatting in multiple files
    
    - Removed unnecessary blank lines and adjusted line breaks for improved code readability in `example_mla_decode.py`, `example_warp_specialize_gemm_copy_gemm_0_1.py`, `phase.py`, and `copy.py`.
    - Ensured consistent formatting across functions to enhance maintainability and clarity of the codebase.
    
    * [Enhancement] Refactor flash attention implementation for improved performance and configurability
    
    - Split the shared memory allocations for query and key-value pairs to optimize memory usage.
    - Introduced command-line arguments for batch size, number of heads, and dimensions, enhancing flexibility in running the example.
    - Updated kernel execution parameters to improve thread management and synchronization.
    - Enhanced the overall structure of the flash attention function for better readability and maintainability.
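
    A minimal sketch of such a command-line interface with `argparse`. The flag names and defaults are assumptions, not the example's actual options. `attention_flops` shows the usual FLOP count for non-causal attention: per head, `Q @ K^T` and `P @ V` each cost `2 * seq_len * seq_len * dim` FLOPs, which is handy for converting measured latency into TFLOPS.

```python
import argparse


def parse_args(argv=None) -> argparse.Namespace:
    """Parse problem-size flags (names and defaults are illustrative)."""
    parser = argparse.ArgumentParser(description="Flash attention example (sketch)")
    parser.add_argument("--batch", type=int, default=1, help="batch size")
    parser.add_argument("--heads", type=int, default=32, help="number of attention heads")
    parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
    parser.add_argument("--dim", type=int, default=128, help="head dimension")
    return parser.parse_args(argv)


def attention_flops(batch: int, heads: int, seq_len: int, dim: int) -> int:
    """FLOPs of non-causal attention: two matmuls of 2 * S * S * D each, per head."""
    return 4 * batch * heads * seq_len * seq_len * dim


# Hypothetical usage with an explicit argument list instead of sys.argv.
args = parse_args(["--batch", "8", "--heads", "16"])
print(attention_flops(args.batch, args.heads, args.seq_len, args.dim))
```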
    
    * fix
    
    * Update layout inference in ParallelOp to account for thread bounds; remove debug print in OptimizeForTarget
    
    * Refactor barrier handling and update example configurations
    
    - Replaced commented-out barrier creation with new barrier allocation in GEMM example.
    - Updated kernel configuration in warp specialization example to include async copy settings.
    - Enhanced barrier management in the phase optimization process to improve synchronization handling.
    - Introduced new barrier allocation function for better memory management in shared contexts.
    
    * Refactor barrier handling in LowerAndLegalize and OptimizeForTarget
    
    - Reintroduced barrier lowering in OptimizeForTarget to enhance synchronization.
    - Removed commented-out barrier lowering in LowerAndLegalize for cleaner code.
    - Added exit() call in OptimizeForTarget to halt execution after barrier lowering.
    
    * Enhance CMake configuration and clean up example scripts
    
    - Enabled compile command export in CMakeLists.txt for better build integration.
    - Removed unnecessary print statement in the warp specialization example.
    - Cleaned up commented-out code in GEMM example for improved readability.
    - Updated barrier handling in shared memory allocation transformations for better synchronization.
    
    * Refactor barrier handling in warp specialization examples
    
    - Replaced commented-out mbarrier code with new barrier allocation using T.alloc_barrier for improved synchronization.
    - Updated barrier wait and arrive calls to align with the new allocation method across multiple example scripts.
    - Enhanced code readability by removing unnecessary comments and ensuring consistent barrier management.
    
    * Update lower_shared_barrier.cc
    
    * Update phase.py
    
    * Update warp specialization example and Cython wrapper
    
    - Removed commented-out pass configuration options in the warp specialization example for clarity.
    - Added functionality to write the generated kernel source to a file named "kernel.cu".
    - Enhanced Cython wrapper to support boolean type conversion for improved type handling.
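
    Boolean conversion is easy to get wrong in Python because `bool` is a subclass of `int`, so an `int` check placed first would silently turn `True` into a 64-bit integer. The sketch below uses `ctypes` rather than Cython to stay self-contained; `to_c_scalar` and its type table are hypothetical and only illustrate the ordering issue, not the wrapper's actual code.

```python
import ctypes

# Hypothetical mapping from Python scalar types to C argument types.
_C_TYPES = {bool: ctypes.c_bool, int: ctypes.c_int64, float: ctypes.c_double}


def to_c_scalar(value):
    """Convert a Python scalar to a ctypes value for a kernel argument."""
    # bool must be tested before int: isinstance(True, int) is also True.
    for py_type in (bool, int, float):
        if isinstance(value, py_type):
            return _C_TYPES[py_type](value)
    raise TypeError(f"unsupported scalar type: {type(value).__name__}")


print(type(to_c_scalar(True)).__name__)  # c_bool
```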
    
    * Add storage synchronization call in shared barrier transformation
    
    - Introduced a new `Evaluate` statement that calls TVM's `tvm_storage_sync` intrinsic with "shared" as its argument, adding synchronization to the shared barrier handling process.
    
    * remove debug files
    
    * Remove kernel source output to file in warp specialization example
    
    * remove comments
    
    * Refactor tensor handling and update test execution in TileLang
    
    - Changed `Buffer` to `Tensor` in `customize.py` for better type consistency.
    - Updated `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to use `tir.BufferLoad` instead of `BufferLoad`.
    - Commented out the main testing function in `test_tilelang_language_reshape.py` and replaced it with a direct call to `run_reshape_smem` for streamlined testing.
    - Removed unnecessary NVCC compiler flags in `libgen.py` to reduce verbosity.
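
    `run_reshape_smem` exercises reshaping a buffer in shared memory. Assuming a contiguous row-major layout, a reshape only reinterprets indices: each element keeps its flat offset and is re-addressed in the new shape. The hypothetical helpers below compute that mapping, the kind of invariant such a test can check against a reference.

```python
from typing import Sequence, Tuple


def flat_index(idx: Sequence[int], shape: Sequence[int]) -> int:
    """Row-major flat offset of a multi-dimensional index."""
    offset = 0
    for i, n in zip(idx, shape):
        offset = offset * n + i
    return offset


def unflatten(offset: int, shape: Sequence[int]) -> Tuple[int, ...]:
    """Inverse of flat_index: recover the index from a row-major offset."""
    idx = []
    for n in reversed(shape):
        idx.append(offset % n)
        offset //= n
    return tuple(reversed(idx))


def reshape_index(idx: Sequence[int], src_shape: Sequence[int],
                  dst_shape: Sequence[int]) -> Tuple[int, ...]:
    """Where element idx of a src_shape buffer lives after reshaping to dst_shape."""
    return unflatten(flat_index(idx, src_shape), dst_shape)


print(reshape_index((1, 2), (4, 8), (2, 16)))  # (0, 10)
```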
    
    * Update test_tilelang_language_reshape.py