    [Dev][Benchmark] Add MLA paged decoding example and benchmark script (#158) · be9abf18
    Yu Cheng authored
    * [Dev] Adjust computation logic to avoid precision loss when casting `acc_s` from float to float16 (see the sketch after this list)
    
    - Remove redundant `acc_s_0` fragment in flash attention kernel
    - Simplify memory copy and reduction operations
    - Reorder memory copy and scaling steps for improved performance
    - Add Hopper-specific synchronization method in CUDA reduce template
    - Update reduce operation to use architecture-specific synchronization
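
    A minimal sketch (NumPy, illustrative names) of the precision fix above: the score fragment `acc_s` stays in float32 through scaling and exponentiation, and is cast to float16 only once, when the probabilities are handed off for the P·V matmul. The computation order here is an assumption, not the kernel's actual code.

    ```python
    import numpy as np

    def rescale_and_cast(acc_s_f32: np.ndarray, row_max: np.ndarray, scale: float) -> np.ndarray:
        """exp((acc_s - row_max) * scale) computed in fp32; cast to fp16 last."""
        # Keeping every intermediate in float32 avoids the precision loss that
        # an early float -> float16 cast of acc_s would introduce.
        p = np.exp((acc_s_f32 - row_max[:, None]) * scale)
        return p.astype(np.float16)  # the single cast happens at the very end
    ```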
    
    * [Dev] Add DeepSeek MLA Decoding (Paged+Varlen) kernel and Performance Benchmark Script
    
    - Implement comprehensive MLA (Multi-Head Latent Attention) decoding benchmark script
    - Add support for multiple implementations: Torch, TileLang, FlashMLA, FlashInfer, and Triton
    - Create flexible configuration for benchmarking different batch sizes, sequence lengths, and head configurations
    - Implement performance comparison and CSV output for detailed analysis
    - Add command-line argument support for targeted benchmarking and comparison (see the driver sketch after this list)
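
    A hedged sketch of the benchmark driver's overall shape. The flag names (`--target`, `--batch`, `--seqlen`, `--out`) and the `make_impl` factory are hypothetical stand-ins for whatever benchmark_mla.py actually exposes; only the structure (config sweep, per-implementation timing, CSV output) follows the bullets above.

    ```python
    import argparse
    import csv
    import time

    def bench(fn, warmup: int = 10, iters: int = 100) -> float:
        """Average wall-clock time per call, in milliseconds."""
        for _ in range(warmup):
            fn()
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        return (time.perf_counter() - start) / iters * 1e3

    def make_impl(target: str, batch: int, seqlen: int):
        # Placeholder: the real script would construct the chosen kernel here
        # (Torch reference, TileLang, FlashMLA, FlashInfer, or Triton).
        def fn():
            pass
        return fn

    def main():
        parser = argparse.ArgumentParser(description="MLA decoding benchmark")
        parser.add_argument("--target", default="tilelang",
                            choices=["torch", "tilelang", "flashmla", "flashinfer", "triton"])
        parser.add_argument("--batch", type=int, nargs="+", default=[1, 32, 64])
        parser.add_argument("--seqlen", type=int, nargs="+", default=[1024, 4096])
        parser.add_argument("--out", default="mla_bench.csv")
        args = parser.parse_args()

        with open(args.out, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["target", "batch", "seqlen", "ms"])
            for b in args.batch:
                for s in args.seqlen:
                    writer.writerow([args.target, b, s, bench(make_impl(args.target, b, s))])

    if __name__ == "__main__":
        main()
    ```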
    
    * [Dev] Refactor MLA Paged Decoding Kernel with Improved Block Handling and Precision
    
    - Replace `d` parameter with `dv` to clarify value dimension in MLA decoding
    - Enhance block distribution logic for split KV processing (see the sketch after this list)
    - Improve handling of remaining blocks in split KV computation
    - Add initialization of `lse_max_local` to prevent potential precision issues
    - Optimize block start and range calculations for more accurate sequence processing
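
    A hedged sketch of the split-KV block distribution described above: KV blocks are divided across the splits, with the remainder assigned to the leading splits so every block is covered exactly once. `num_blocks` and `num_splits` are illustrative names, and the `-inf` initialization mirrors the `lse_max_local` fix as an assumption.

    ```python
    import math

    # Assumption: lse_max_local starts at -inf so the first real value
    # always wins the running max when partial results are merged.
    LSE_MAX_INIT = -math.inf

    def split_kv_ranges(num_blocks: int, num_splits: int):
        """Return (block_start, block_count) ranges, one per split."""
        base, rem = divmod(num_blocks, num_splits)
        ranges, start = [], 0
        for s in range(num_splits):
            count = base + (1 if s < rem else 0)  # remaining blocks go to leading splits
            ranges.append((start, count))
            start += count
        return ranges

    # e.g. 10 blocks over 4 splits -> the first two splits take the remainder
    assert split_kv_ranges(10, 4) == [(0, 3), (3, 3), (6, 2), (8, 2)]
    ```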
    
    * lint
Added file: benchmark_mla.py (19.7 KB)