    [Benchmark] Add benchmark scripts for block sparse attention (#114) · f2f67571
    Lei Wang authored
    * Add DeepSeek MLA decode example with Flash Attention implementation
    
    * Add GEMM SplitK and StreamK example implementations
    
    This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques:
    - `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang
    - `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang
    
    The two examples demonstrate different work-partitioning strategies for parallel matrix multiplication, and each is tested against a PyTorch reference implementation.
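    For readers unfamiliar with the two decompositions, the following is a minimal pure-Python sketch of the ideas only, not the TileLang kernels themselves. `gemm_splitk` divides the K (reduction) dimension into chunks whose partial products are summed into C, where a GPU kernel would use atomic adds. `streamk_partition` splits the flattened (tile, k-iteration) space evenly across workers, so one worker's range may start or end inside an output tile. All function names here are hypothetical and chosen for illustration.

```python
def matmul_ref(a, b):
    """Plain reference GEMM, C = A @ B, on lists of lists."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


def gemm_splitk(a, b, split_k):
    """Split-K sketch: partition K into `split_k` chunks.

    Each chunk yields an independent partial C (one GPU block per chunk in a
    real kernel); the partials are accumulated into C, which stands in for
    the atomic add a GPU kernel would perform.
    """
    m, k, n = len(a), len(b), len(b[0])
    chunk = (k + split_k - 1) // split_k  # ceil-divide K across the splits
    c = [[0.0] * n for _ in range(m)]
    for s in range(split_k):
        k_lo, k_hi = s * chunk, min((s + 1) * chunk, k)
        for i in range(m):
            for j in range(n):
                partial = sum(a[i][p] * b[p][j] for p in range(k_lo, k_hi))
                c[i][j] += partial  # sequential stand-in for an atomic add
    return c


def streamk_partition(num_tiles, iters_per_tile, num_workers):
    """Stream-K sketch: balance MAC-loop iterations across workers.

    The (tile, k-iteration) space is flattened and split as evenly as
    possible, so a worker's share may cover only part of a tile; such tiles
    need a cross-worker fix-up reduction. Returns, per worker, a list of
    (tile, k_iter_begin, k_iter_end) ranges.
    """
    total = num_tiles * iters_per_tile
    base, extra = divmod(total, num_workers)
    plans, start = [], 0
    for w in range(num_workers):
        end = start + base + (1 if w < extra else 0)
        ranges, it = [], start
        while it < end:
            tile = it // iters_per_tile
            tile_end = min((tile + 1) * iters_per_tile, end)
            ranges.append((tile, it % iters_per_tile,
                           tile_end - tile * iters_per_tile))
            it = tile_end
        plans.append(ranges)
        start = end
    return plans
```

    Roughly, Split-K fixes how many ways each tile's reduction is split, while Stream-K fixes how much work each worker gets; with 3 tiles of 4 iterations on 2 workers, tile 1 is shared between both workers and would need a fix-up step.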
    
    * Refactor GEMM SplitK and StreamK example implementations
    
    Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts:
    - Remove unused import (Profiler) in splitk example
    - Simplify line breaks and improve code readability
    - Standardize indentation and remove unnecessary whitespace
    - Simplify the atomic add and copy operations for clarity
    
    * Add block sparse attention benchmarks for multiple libraries
    
    This commit introduces block sparse attention benchmarks covering several libraries:
    - TileLang block sparse FMHA implementation
    - Triton block sparse FMHA implementation
    - PyTorch reference block sparse FMHA implementation
    - FlashAttention dense FMHA reference implementation
    
    The benchmarks include:
    - Configurable benchmark parameters (batch size, heads, sequence length, etc.)
    - Sparse mask generation using top-k and threshold methods
    - Performance measurement for different sparse attention configurations
    - Utility functions for mask generation and benchmarking
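    As an illustration of the mask-generation and reference paths, here is a pure-Python sketch assuming block importance scores are given as a `[num_q_blocks][num_k_blocks]` nested list. `topk_block_mask`, `threshold_block_mask`, and `block_sparse_attention_ref` are hypothetical names, not the benchmark scripts' API; the actual benchmarks operate on PyTorch, Triton, and TileLang tensors.

```python
import math


def topk_block_mask(scores, k):
    """Keep the k highest-scoring key blocks in each query-block row.

    `scores` is a hypothetical [num_q_blocks][num_k_blocks] list of block
    importance scores. Ties are broken in favor of the lower block index.
    """
    mask = []
    for row in scores:
        keep = sorted(range(len(row)), key=lambda j: (-row[j], j))[:k]
        mask.append([j in keep for j in range(len(row))])
    return mask


def threshold_block_mask(scores, threshold):
    """Keep every key block whose score strictly exceeds `threshold`."""
    return [[s > threshold for s in row] for row in scores]


def block_sparse_attention_ref(q, k, v, block_mask, block_size):
    """Single-head attention restricted to the blocks allowed by the mask.

    q, k, v are [seq_len][head_dim] lists. Token i may attend to token j only
    if block_mask[i // block_size][j // block_size] is True; disallowed
    scores are set to -inf before a numerically stable softmax.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        logits = []
        for j, kj in enumerate(k):
            if block_mask[i // block_size][j // block_size]:
                logits.append(scale * sum(x * y for x, y in zip(qi, kj)))
            else:
                logits.append(-math.inf)
        m = max(logits)  # assumes every row keeps at least one block
        weights = [math.exp(x - m) for x in logits]  # exp(-inf) == 0.0
        total = sum(weights)
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) / total
                    for c in range(len(v[0]))])
    return out
```

    With `scores = [[0.9, 0.1, 0.5], [0.2, 0.8, 0.3]]`, top-2 selection keeps blocks 0 and 2 in the first row and blocks 1 and 2 in the second, while a threshold of 0.4 drops block 2 from the second row. An all-True mask reduces the reference to ordinary dense attention, which is what makes it a convenient correctness baseline.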
    
    * Refactor block sparse attention benchmarks with code style improvements
    
    - Add Ruff linter ignore comments to benchmark files
    - Improve code formatting and line breaks
    - Remove unused imports
    - Standardize print statement formatting
    - Enhance code readability across multiple library benchmarks
    
    * Lint fix