    [Refactor] Merge bulk copy into copy and improve layout inference for bulk copy (#746) · 5c11d245
    Lei Wang authored
    * [Refactor] Merge bulk copy into copy and refactor layout inference for bulk copy
    
    * Deleted the `bulk_copy` operator implementation and its header file as it is no longer needed.
    * Introduced a new function `cuTensorMapType()` to return the data type for CUDA tensor mapping.
    * Updated related files to reflect these changes, ensuring that the codebase remains clean and maintainable.
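    `cuTensorMapType()` is described here only as returning the data type for CUDA tensor mapping. As a rough illustration, assuming it maps an element dtype onto the driver's `CUtensorMapDataType` enumerators, the lookup amounts to a table like the one below. The Python function name `cu_tensor_map_type` and the set of covered dtypes are assumptions for this sketch. The real helper is C++ and may cover more or different types.

```python
# Hypothetical Python mirror of a dtype -> CUtensorMapDataType lookup.
# Enumerator names follow the CUDA driver API; coverage is deliberately partial.
_CU_TENSOR_MAP_TYPES = {
    "uint8": "CU_TENSOR_MAP_DATA_TYPE_UINT8",
    "uint16": "CU_TENSOR_MAP_DATA_TYPE_UINT16",
    "uint32": "CU_TENSOR_MAP_DATA_TYPE_UINT32",
    "uint64": "CU_TENSOR_MAP_DATA_TYPE_UINT64",
    "int32": "CU_TENSOR_MAP_DATA_TYPE_INT32",
    "int64": "CU_TENSOR_MAP_DATA_TYPE_INT64",
    "float16": "CU_TENSOR_MAP_DATA_TYPE_FLOAT16",
    "bfloat16": "CU_TENSOR_MAP_DATA_TYPE_BFLOAT16",
    "float32": "CU_TENSOR_MAP_DATA_TYPE_FLOAT32",
    "float64": "CU_TENSOR_MAP_DATA_TYPE_FLOAT64",
}


def cu_tensor_map_type(dtype: str) -> str:
    """Return the CUtensorMapDataType enumerator name for a dtype string."""
    try:
        return _CU_TENSOR_MAP_TYPES[dtype]
    except KeyError:
        # Unknown dtypes are rejected rather than silently mapped.
        raise ValueError(f"dtype {dtype!r} is not handled in this sketch") from None
```

    Centralizing the mapping in one function keeps every TMA descriptor builder consistent about which element types it accepts.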
    
    * lint fix
    
    * Fix typos in intrinsic names and remove an unused print statement in `block_sparse_attn_tilelang.py`. Updated references from `ptx_ldmatirx` to `ptx_ldmatrix` across multiple files for consistency.
    
    * remove bulk copy
    
    * Refactor copy and atomic add operations to support TMA lower configuration
    
    - Updated `GetCopyInst` to accept a `disable_tma_lower` parameter, allowing for conditional usage of TMA in bulk load/store operations.
    - Modified `Lower` method in `Copy` to incorporate the new TMA configuration.
    - Refactored `AtomicAdd::Lower` to streamline layout inference and vectorization logic.
    - Removed unused `disable_tma_lower` field from `LowerArgs` structure for clarity.
    - Enhanced atomic add vectorization by replacing the buggy implementation with a more robust loop vectorization approach.
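    The `disable_tma_lower` switch in `GetCopyInst` can be pictured as a gate in front of the TMA paths: when it is set, a copy that would otherwise become a bulk load/store is lowered through a non-TMA path instead. The sketch below is a Python model of that decision, not the C++ implementation. The enum `CopyInst`, the function `get_copy_inst`, and its boolean inputs are hypothetical stand-ins for the real `Check*` predicates.

```python
from enum import Enum, auto


class CopyInst(Enum):
    """Hypothetical labels for the lowering paths a copy can take."""
    NORMAL = auto()      # plain (possibly vectorized) element-wise copy
    LDSM = auto()        # ldmatrix: shared -> fragment
    STSM = auto()        # stmatrix: fragment -> shared
    BULK_LOAD = auto()   # TMA: global -> shared
    BULK_STORE = auto()  # TMA: shared -> global


def get_copy_inst(*, disable_tma_lower: bool, is_bulk_load: bool,
                  is_bulk_store: bool, is_ldsm: bool = False,
                  is_stsm: bool = False) -> CopyInst:
    # TMA paths are considered only when TMA lowering is not disabled.
    if not disable_tma_lower:
        if is_bulk_load:
            return CopyInst.BULK_LOAD
        if is_bulk_store:
            return CopyInst.BULK_STORE
    # Matrix load/store instructions are tried before the generic path.
    if is_ldsm:
        return CopyInst.LDSM
    if is_stsm:
        return CopyInst.STSM
    # Everything else, including a disabled TMA copy, falls back here.
    return CopyInst.NORMAL
```

    Passing the flag into the selector, rather than storing it on `LowerArgs`, keeps the choice local to the one place that decides between TMA and non-TMA lowering.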
    
    * Enhance TMA bulk copy logic in `LowerBulkCopy` method
    
    - Added a condition to set `desc.swizzle` to `CU_TENSOR_MAP_SWIZZLE_NONE` when `shared_layout` matches `linear_layout`, so that an unswizzled shared buffer gets a TMA descriptor without swizzling.
    - Updated warning log to provide more detailed information about fallback scenarios, including source and destination buffer names and shapes, enhancing debugging capabilities.
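    The swizzle decision in `LowerBulkCopy` can be sketched as: a linear shared layout maps to `CU_TENSOR_MAP_SWIZZLE_NONE`, a recognized swizzled layout maps to one of the driver's swizzle modes, and anything else falls back with a warning that names both buffers and the shape. The layout model (a `(kind, rows, cols)` tuple), the kind names, and `choose_swizzle` are assumptions made for illustration; the real code compares TileLang layout objects.

```python
import logging
from typing import Optional, Tuple

# Hypothetical layout model: (kind, rows, cols). "linear" means row-major
# with no swizzle; the *_bank kinds stand in for swizzled shared layouts.
Layout = Tuple[str, int, int]

_SWIZZLE_BY_KIND = {
    "quarter_bank": "CU_TENSOR_MAP_SWIZZLE_32B",
    "half_bank": "CU_TENSOR_MAP_SWIZZLE_64B",
    "full_bank": "CU_TENSOR_MAP_SWIZZLE_128B",
}


def choose_swizzle(shared_layout: Layout, src_name: str, dst_name: str,
                   shape: Tuple[int, int]) -> Optional[str]:
    linear_layout = ("linear", *shape)
    # An unswizzled shared buffer needs no swizzle in the TMA descriptor.
    if shared_layout == linear_layout:
        return "CU_TENSOR_MAP_SWIZZLE_NONE"
    swizzle = _SWIZZLE_BY_KIND.get(shared_layout[0])
    if swizzle is None:
        # Unsupported layout: report which copy falls back to a non-TMA path.
        logging.warning("TMA swizzle unsupported for %s -> %s (shape %s); "
                        "falling back to a normal copy", src_name, dst_name, shape)
    return swizzle
```

    Returning `None` for the fallback case lets the caller choose the non-TMA lowering while the log line already carries the buffer names needed for debugging.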
    
    * lint fix
    
    * Remove fallback logging for non-swizzled global layout in `LowerBulkCopy` method to streamline the bulk copy logic. This change enhances code clarity by eliminating unnecessary warning messages related to inner box dimensions.
    
    * Enhance reshape kernel compilation in `run_reshape` and `run_reshape_smem_1d_2_2d` functions
    
    - Updated the `tl.compile` method to include `pass_configs` that disable TMA lower and warp specialization, addressing shared memory layout transformation limitations.
    - Added TODO comments to indicate the need for further improvements in shared memory handling.
    
    * Update `native_sparse_attention` function to include TMA configuration options
    
    - Added `pass_configs` to the JIT decorator to disable TMA lower and warp specialization, addressing potential issues with shared memory layout transformations.
    - Updated comments to clarify modifications in tensor shapes for inference, specifically setting `q` sequence length to 1.
    
    * Refactor JIT decorator formatting in `native_sparse_attention` function
    
    - Improved readability by reformatting the JIT decorator parameters for `native_sparse_attention`, ensuring consistent style across the codebase.
    - No functional changes were made; this update focuses on code clarity and maintainability.
    
    * Enhance thread management and logging in TileLang compilation
    
    - Added a method to check if printing is enabled during compilation, improving control over logging behavior.
    - Updated the JIT kernel class to utilize the new method for logging compilation status, ensuring consistent and clear output.
    - Added comments to clarify the purpose of changes and improve code readability.
    
    * Add warp specialization scope and refactor register management in TileLang
    
    - Introduced a new constant `kWarpSpecializationScope` in `builtin.h` for better attribute management.
    - Removed the `SetMaxNRegCollector` class and its related logic from `warp_specialized_rewriter.cc`, streamlining the warp specialization process.
    - Added functions `annotate_producer_reg_dealloc` and `annotate_consumer_reg_alloc` in `builtin.py` to facilitate register management.
    - Implemented `AnnotateWarpGroupRegAlloc` in `__init__.py` to inject register allocation calls into warp-specialized functions, enhancing the overall register handling in the compilation process.
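    The idea behind `annotate_producer_reg_dealloc` / `annotate_consumer_reg_alloc` is that the producer (data-movement) warp group gives registers back while the consumer (math) warp group claims more. A minimal model of an `AnnotateWarpGroupRegAlloc`-style pass, assuming the calls are injected at the start of each role's branch, is shown below. The `WarpSpecializedBody` container, the string form of the calls, and the pass name `annotate_warp_group_reg_alloc` are hypothetical; the real pass rewrites TIR.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class WarpSpecializedBody:
    # Hypothetical IR stand-in: a list of statement strings per role.
    producer: List[str] = field(default_factory=list)
    consumer: List[str] = field(default_factory=list)


def annotate_warp_group_reg_alloc(body: WarpSpecializedBody,
                                  producer_nreg: int,
                                  consumer_nreg: int) -> WarpSpecializedBody:
    """Return a new body with register hints prepended to each branch."""
    return WarpSpecializedBody(
        # Producers only issue copies, so they release registers.
        producer=[f"set_max_nreg({producer_nreg}, dec)"] + body.producer,
        # Consumers run the compute, so they request more registers.
        consumer=[f"set_max_nreg({consumer_nreg}, inc)"] + body.consumer,
    )
```

    The register counts are arbitrary inputs here. Building a new body instead of mutating the input mirrors how IR passes typically return a rewritten function.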
    
    * Refactor test for InjectSetMaxNReg pass in TileLang
    
    - Improved readability by restructuring conditional checks and assertions in the test cases.
    - Simplified the logic that collects `set_max_nreg` calls.
    - Ensured consistent formatting and spacing throughout the test functions for better maintainability.
    
    * Enhance bulk copy and store checks in `Copy` class
    
    - Updated scope validation for source and destination tensors in `CheckBulkLoad` and `CheckBulkStore` methods to include both `shared.dyn` and `shared` as valid options.
    - Modified `CheckLDSMCopy` and `CheckSTSMCopy` methods to accommodate the new scope validation, ensuring compatibility with shared memory configurations.
    - Improved logging in `LowerBulkCopy` to provide clearer warnings regarding unsupported swizzle layouts, including source and destination names for better debugging.
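    Accepting both `shared.dyn` and `shared` means each check tests membership in a set of shared scopes instead of comparing against a single string. The predicates below are a simplified Python sketch of that scope logic only. The real `CheckBulkLoad`/`CheckBulkStore`/`CheckLDSMCopy`/`CheckSTSMCopy` also inspect the target, dtypes and layouts; the `supports_tma` flag and the `local.fragment` destination/source scope used here are assumptions for illustration.

```python
# Both static and dynamic shared memory count as "shared" for these checks.
SHARED_SCOPES = ("shared.dyn", "shared")


def check_bulk_load(src_scope: str, dst_scope: str, *, supports_tma: bool) -> bool:
    # TMA load: global -> shared.
    return supports_tma and src_scope == "global" and dst_scope in SHARED_SCOPES


def check_bulk_store(src_scope: str, dst_scope: str, *, supports_tma: bool) -> bool:
    # TMA store: shared -> global.
    return supports_tma and src_scope in SHARED_SCOPES and dst_scope == "global"


def check_ldsm_copy(src_scope: str, dst_scope: str) -> bool:
    # ldmatrix: shared -> register fragment.
    return src_scope in SHARED_SCOPES and dst_scope == "local.fragment"


def check_stsm_copy(src_scope: str, dst_scope: str) -> bool:
    # stmatrix: register fragment -> shared.
    return src_scope == "local.fragment" and dst_scope in SHARED_SCOPES
```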
    
    * lint fix