    [Example] Implement NSA Decode tilelang examples (#168) · 69f35439
    Lei Wang authored
    * [Refactor] Update BitBLAS Benchmark with TileLang Carver Imports and Roller Hints Generation
    
    - Replace BitBLAS imports with TileLang Carver imports in benchmark_matmul.py
    - Modify roller hints generation using new TileLang Carver template and utility functions
    - Update get_roller_hints_from_func to handle None cases and improve return logic
    - Adjust DefaultPolicy to handle different codegen dictionary formats
    
    * [Refactor] Update Thread Binding and Import Statements in TileLang Kernels
    
    - Replace T.thread_binding() with T.get_thread_binding() across multiple kernel test files
    - Update import statements for MMA layout and macro generator in dequantize GEMM and FP8 examples
    - Move map_torch_type utility function to tilelang.utils.tensor
    - Remove unnecessary imports and improve code organization
    
    * Refactor Native Sparse Attention Example with Enhanced Triton Kernel
    
    - Update parallel_nsa_fwd_kernel to support more flexible sparse attention computation
    - Add support for block counts and offsets in the Triton kernel
    - Modify kernel grid and computation logic for improved performance
    - Update example script to use naive_nsa_simple reference implementation
    - Improve type hints and kernel configuration
    
    * Add Native Sparse Attention Examples with TileLang and Triton Implementations
    
    - Introduce new example scripts for native sparse attention:
      * example_tilelang_nsa_fwd.py: Forward pass implementation using TileLang
      * example_tilelang_nsa_decode.py: Decoding-specific sparse attention implementation
      * example_triton_nsa_fwd.py: Triton-based sparse attention forward pass
    - Update reference.py with naive implementations for sparse attention
    - Support different sparse attention scenarios including forward pass and inference
    - Add comprehensive testing and validation against reference implementations
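The validation idea in the list above can be illustrated with a minimal sketch. It is not the code in reference.py: the names `attend` and `naive_sparse_decode` are hypothetical, and it assumes a single query token, a single head, and a fixed block size. Decode-time sparse attention restricts the softmax to the selected key/value blocks, so selecting every block should reproduce dense attention exactly.

```python
import math


def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attend(q, keys, values, scale):
    # Dense single-query attention over the given keys and values.
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    probs = softmax(scores)
    dim_v = len(values[0])
    return [sum(p * v[d] for p, v in zip(probs, values)) for d in range(dim_v)]


def naive_sparse_decode(q, keys, values, block_indices, block_size):
    # Hypothetical reference: attend only to the KV blocks in block_indices.
    scale = len(q) ** -0.5
    positions = [
        t
        for b in block_indices
        for t in range(b * block_size, min((b + 1) * block_size, len(keys)))
    ]
    return attend(
        q,
        [keys[t] for t in positions],
        [values[t] for t in positions],
        scale,
    )
```

A fused kernel can then be checked against such a reference by comparing outputs within a small tolerance for several block selections, including the all-blocks case, which must match dense attention.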
    
    * lint fix
example_tilelang_nsa_decode.py 6.96 KB