    [Refactor] Refactor barrier management (#744) · cb37bfef
    Lei Wang authored
    * Introduce Barrier
    
    * Enhance CUDA kernel with new barrier management and post-processing support
    
    - Added a new CUDA kernel implementation in `example_mla_decode.py` for improved performance with shared memory barriers.
    - Refactored barrier handling in `codegen_cuda.cc` and `codegen_hip.cc` to use a more flexible mbarrier structure.
    - Renamed the misspelled intrinsic `ptx_stmatirx` to `ptx_stmatrix` across multiple files for consistency.
    - Introduced additional print statements for debugging in the lowering phase of the TileLang engine.
    - Enhanced the overall structure and readability of the codebase.
    
    * Remove unused barrier handling code in CUDA and HIP code generators to streamline the implementation. This change enhances code clarity and reduces complexity in the barrier management logic.
    
    * Enhance barrier management in TileLang
    
    - Introduced a new intrinsic `allocate_barrier` for dynamic barrier allocation in the TileLang framework.
    - Updated CUDA code generation to support the new barrier structure, allowing for improved synchronization in shared memory.
    - Refactored existing barrier handling logic to accommodate the new intrinsic and streamline code.
    - Added print statements for debugging purposes in various examples and the lowering phase of the TileLang engine.
    - Removed deprecated memory scope handling code to enhance clarity and maintainability.
    
    * lint fix
    
    * lint fix
    
    * Remove `allocate_barrier` intrinsic and related code from TileLang to streamline barrier management. This includes updates to CUDA code generation and the removal of associated Python wrappers, enhancing code clarity and maintainability.
    
    * Refactor logging in JITKernel to improve kernel compilation tracking
    
    - Removed unused import of `torch.backends` in the example file.
    - Introduced logging for kernel compilation in `JITKernel`, replacing print statements with structured logging for better traceability and debugging.
    - Added an assertion to ensure the presence of the `global_symbol` attribute in the kernel function.
    
    * Refactor dequantization tests and update barrier function
    
    - Removed the test for `example_dequant_gemm_bf16_fp4_hopper_serial` to streamline the testing suite.
    - Updated the `mbarrier_cp_async_arrive` function to support both pointer and non-pointer types, enhancing flexibility in barrier management.
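    One common way to let a single barrier helper accept either a barrier object or a pointer to one is compile-time dispatch on the argument type. The host-only C++17 sketch below shows that pattern. It assumes the barrier is a 64-bit word, as PTX mbarriers are. The names `resolve_barrier` and `mbarrier_cp_async_arrive_sketch` are hypothetical and are not TileLang's API. On a GPU, the real helper would pass the resolved address to inline PTX instead of incrementing a counter.

    ```cpp
    #include <cstdint>
    #include <type_traits>

    // Hypothetical helper: resolve a barrier argument to a uint64_t*, whether the
    // caller passed the barrier object itself or a pointer to it (C++17).
    template <typename BarrierType>
    std::uint64_t *resolve_barrier(BarrierType &bar) {
      if constexpr (std::is_pointer_v<BarrierType>) {
        // Pointer argument (e.g. into an array of barriers): use it directly.
        return reinterpret_cast<std::uint64_t *>(bar);
      } else {
        // Barrier object argument: take its address.
        return reinterpret_cast<std::uint64_t *>(&bar);
      }
    }

    // Hypothetical host-side stand-in for the device arrive. A device version
    // would convert the address to a shared-memory address and issue
    // cp.async.mbarrier.arrive; here we only increment the 64-bit word so the
    // dispatch can be observed on the host.
    template <typename BarrierType>
    void mbarrier_cp_async_arrive_sketch(BarrierType &bar) {
      std::uint64_t *word = resolve_barrier(bar);
      *word += 1;
    }
    ```

    For example, with `std::uint64_t bars[2]`, both `mbarrier_cp_async_arrive_sketch(bars[0])` and `mbarrier_cp_async_arrive_sketch(p)` (where `std::uint64_t *p = &bars[1];`) reach the intended barrier. Because the branch is resolved at compile time, neither call pays for a runtime type check.
    
    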
    
    * Update CI configuration to increase pytest parallelism from 4 to 8 workers for faster test execution.
    
    * Fix typos in rasterization parameters and update import path for cached module
    
    - Corrected the spelling of `enable_rasteration` to `enable_rasterization` in the matmul function and its usage.
    - Updated the import statement for the `cached` module to reflect the new path in the cache submodule.
    - Added `StridedTensor` import in the language module for enhanced tensor functionality.
    
    * Update ci.yml