06 Mar, 2025 (1 commit)
      [Dev][Benchmark] Add MLA paged decoding example and benchmark script (#158) · be9abf18
      Yu Cheng authored
      * [Dev] Adjust computation logic to avoid precision loss when casting `acc_s` from float to float16
      
      - Remove redundant `acc_s_0` fragment in flash attention kernel
      - Simplify memory copy and reduction operations
      - Reorder memory copy and scaling steps for improved performance
      - Add Hopper-specific synchronization method in CUDA reduce template
      - Update reduce operation to use architecture-specific synchronization
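The reordering above can be illustrated with a standalone NumPy sketch (hypothetical; not the actual TileLang kernel, and the function names are made up for illustration): keeping the `exp`/scaling arithmetic in float32 and casting to float16 only once, at the end, preserves precision that a cast-first ordering loses.

```python
import numpy as np

def softmax_p_cast_last(acc_s_f32, row_max, scale):
    # Compute exp((s - max) * scale) entirely in float32,
    # casting down to float16 only once, for the fp16 GEMM input.
    p = np.exp((acc_s_f32.astype(np.float32) - row_max) * scale)
    return p.astype(np.float16)

def softmax_p_cast_first(acc_s_f32, row_max, scale):
    # Naive ordering: cast the raw scores to float16 first,
    # losing low-order bits before the scaling/exp is applied.
    s16 = acc_s_f32.astype(np.float16)
    return np.exp((s16.astype(np.float32) - row_max) * scale).astype(np.float16)
```

With scores that differ only in low-order bits, the cast-last ordering stays measurably closer to a float64 reference.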
      
      * [Dev] Add DeepSeek MLA Decoding (Paged+Varlen) kernel and Performance Benchmark Script
      
      - Implement comprehensive MLA (Multi-Head Latent Attention) decoding benchmark script
      - Add support for multiple implementations: Torch, TileLang, FlashMLA, FlashInfer, and Triton
      - Create flexible configuration for benchmarking different batch sizes, sequence lengths, and head configurations
      - Implement performance comparison and CSV output for detailed performance analysis
      - Add command-line argument support for targeted benchmarking and comparison
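A minimal harness in the spirit of the bullets above might look like the following sketch (the names `bench` and `run_benchmarks` are hypothetical, and the real script times kernels with CUDA events rather than host wall-clock):

```python
import csv
import time

def bench(fn, warmup=3, iters=10):
    """Average wall-clock latency of a callable in milliseconds.

    Host-side stand-in for device-event timing; warmup runs are
    discarded so one-time setup cost is not measured."""
    for _ in range(warmup):
        fn()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) / iters * 1e3

def run_benchmarks(impls, configs, out_csv="mla_decode_bench.csv"):
    """Run every implementation over every config and dump a CSV.

    impls:   {name: make_fn} where make_fn(config) returns a callable
             (e.g. a compiled Torch / TileLang / Triton kernel launch).
    configs: list of dicts such as
             {"batch": 64, "seqlen_kv": 4096, "heads": 128}."""
    rows = []
    for cfg in configs:
        row = dict(cfg)
        for name, make_fn in impls.items():
            row[name + "_ms"] = bench(make_fn(cfg))
        rows.append(row)
    with open(out_csv, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)
    return rows
```

Each backend (Torch, TileLang, FlashMLA, FlashInfer, Triton) would be registered in `impls`, and command-line flags would select which names and configs to run.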
      
      * [Dev] Refactor MLA Paged Decoding Kernel with Improved Block Handling and Precision
      
      - Replace `d` parameter with `dv` to clarify value dimension in MLA decoding
      - Enhance block distribution logic for split KV processing
      - Improve handling of remaining blocks in split KV computation
      - Add initialization of `lse_max_local` to prevent potential precision issues
      - Optimize block start and range calculations for more accurate sequence processing
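The split-KV ideas above can be sketched in NumPy (hypothetical helper names; the actual kernel does this on-device in TileLang): distribute KV blocks across splits with explicit remainder handling, then merge the per-split partial results with log-sum-exp weights, initializing the running max to `-inf` so it cannot start from garbage.

```python
import numpy as np

def split_kv_ranges(num_blocks, num_splits):
    """Assign [start, end) block ranges to each split.

    The first `rem` splits take one extra block, so the remainder
    is spread evenly instead of piling onto the last split."""
    base, rem = divmod(num_blocks, num_splits)
    ranges, start = [], 0
    for s in range(num_splits):
        n = base + (1 if s < rem else 0)
        ranges.append((start, start + n))
        start += n
    return ranges

def merge_splits(partial_out, partial_lse):
    """Combine per-split partial outputs with log-sum-exp weights.

    partial_out: [splits, heads, dv], partial_lse: [splits, heads].
    lse_max is explicitly initialized to -inf so an uninitialized or
    empty split cannot contaminate the reduction."""
    lse_max = np.full(partial_lse.shape[1:], -np.inf)
    for lse in partial_lse:
        lse_max = np.maximum(lse_max, lse)
    w = np.exp(partial_lse - lse_max)                      # [splits, heads]
    denom = w.sum(axis=0)                                  # [heads]
    out = (w[..., None] * partial_out).sum(axis=0)
    return out / denom[..., None]                          # [heads, dv]
```

For a single split the merge reduces to the identity, and for equal LSE values it averages the partial outputs, which is the expected flash-decoding behavior.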
      
      * lint