    [Dev][Doc] Add DeepSeek MLA Decode Example with Documentation and Performance Benchmarks (#134) · cd94aca1
    Yu Cheng authored
    * [Dev] Add RetNet Linear Attention example
    
    * [Dev] Add WgmmaSync rewriter for pipelined WGMMA operations and add MHA WGMMA pipelined example (FA3-like scheduling)
    
    This commit introduces a new transformation pass `RewriteWgmmaSync` to optimize warpgroup matrix multiply-accumulate (WGMMA) operations in the TileLang compiler:
    
    - Implemented `WgmmaSyncRewriter` in `src/transform/wgmma_sync_rewriter.cc`
    - Added pass registration for `RewriteWgmmaSync`
    - Updated `tilelang/engine/phase.py` to include the new transformation pass
    - Updated `tilelang/transform/__init__.py` to expose the new pass
    
    The rewriter manages synchronization and dependencies between WGMMA operations, which improves pipeline efficiency for complex matrix multiplication kernels.
    
    * [Bugfix] Fix bug in ThreadTagChecker for warp specialization
    
    Improve thread tag validation in warp specialized rewriter to prevent unintended transformations:
    - Add more precise checks for threadIdx.y and threadIdx.z
    - Validate thread extent to ensure only single-extent thread bindings are allowed
    - Prevent warp specialization for multi-extent thread bindings in y and z dimensions
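
The rule described in the bullets above can be sketched in a few lines. This is an illustrative Python model only, not the actual C++ `ThreadTagChecker`; the `ThreadBinding` class and `can_warp_specialize` function are hypothetical names introduced here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ThreadBinding:
    """Hypothetical stand-in for a thread binding: a tag plus its extent."""
    tag: str     # e.g. "threadIdx.x", "threadIdx.y", "blockIdx.x"
    extent: int  # number of threads (or blocks) bound along this axis


def can_warp_specialize(bindings):
    """Return True only if every threadIdx.y / threadIdx.z binding has extent 1."""
    for binding in bindings:
        if binding.tag in ("threadIdx.y", "threadIdx.z") and binding.extent != 1:
            # A multi-extent y or z binding disables warp specialization.
            return False
    return True
```

Under this model, a kernel bound as `threadIdx.x=128, threadIdx.y=1` passes the check, while `threadIdx.x=128, threadIdx.y=2` is rejected and left untransformed.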
    
    * lint
    
    * [CI] Add TMA descriptor attribute to transformed module in test case
    
    * [Dev] Refactor DeepSeek MLA Decode Example with Non-Split and Split Flash Attention Implementations
    
    - Add new `flash_attn` macro for non-split flash attention implementation
    - Add swizzled layout for tile in shared memory
    - Use threadblock swizzle to improve L2 cache hit rate
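
To illustrate the general idea behind threadblock swizzling, here is a minimal Python sketch of one common remapping: blocks are walked in narrow column panels instead of full rows, so blocks that run at about the same time touch fewer distinct tiles. The function name `swizzle_block` and its `panel` parameter are assumptions made for this example; the exact mapping TileLang applies may differ.

```python
def swizzle_block(linear_id, grid_m, grid_n, panel):
    """Map a launch-order block id to a (row, col) tile, one column panel at a time.

    The grid has grid_m rows and grid_n columns of tiles. Each panel is
    `panel` columns wide; the last panel may be narrower.
    """
    blocks_per_panel = panel * grid_m
    panel_idx = linear_id // blocks_per_panel
    within = linear_id % blocks_per_panel
    col_start = panel_idx * panel
    width = min(panel, grid_n - col_start)  # handle a ragged last panel
    row = within // width
    col = col_start + within % width
    return row, col
```

With `grid_m=3, grid_n=5, panel=2`, the first four block ids map to tiles (0, 0), (0, 1), (1, 0), (1, 1), staying within two columns rather than sweeping across the whole row.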
    
    * [Dev] Add DeepSeek MLA Decode Example with Documentation and Performance Benchmarks
    
    - Add detailed README.md explaining MLA (Multi-Head Latent Attention) implementation
    - Include performance benchmark images for batch sizes 64 and 128
    - Add layout visualization images for QK and PV operations
    - Implement torch reference implementations in torch_refs.py
    - Update example_mla_decode.py with command-line argument support and flexible configuration
    - Add performance benchmarking and comparison with other implementations
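
A reference implementation mainly serves as a correctness oracle for the split and non-split kernels. The sketch below is not the contents of `torch_refs.py`; it uses only the Python standard library and hypothetical function names (`attend`, `attend_split`) to show why a split-KV decode can be checked against plain softmax attention: each chunk returns its output together with its log-sum-exp, and the partial outputs are merged using those weights.

```python
import math


def attend(q, keys, values):
    """Plain softmax attention for one query; returns (output, log-sum-exp)."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(dim)]
    return out, m + math.log(total)


def attend_split(q, keys, values, num_split):
    """Attend over contiguous KV chunks, then merge partial outputs by their LSE."""
    chunk = math.ceil(len(keys) / num_split)
    parts = [attend(q, keys[i:i + chunk], values[i:i + chunk])
             for i in range(0, len(keys), chunk)]
    lse_max = max(lse for _, lse in parts)
    # Each chunk's weight is its softmax normalizer, rescaled by the largest one.
    scales = [math.exp(lse - lse_max) for _, lse in parts]
    denom = sum(scales)
    dim = len(values[0])
    return [sum(s * out[d] for s, (out, _) in zip(scales, parts)) / denom
            for d in range(dim)]
```

Comparing `attend_split(q, keys, values, n)` with `attend(q, keys, values)[0]` for several values of `n` should agree up to floating-point rounding, which is the kind of check a reference implementation enables.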
example_mla_decode.py 14.7 KB