1. 26 Feb, 2025 1 commit
    • Lei Wang's avatar
      [Example] Update GEMM FP8 Example (#123) · 13f4b5c6
      Lei Wang authored
      * Add DeepSeek MLA decode example with Flash Attention implementation
      
      * Add GEMM SplitK and StreamK example implementations
      
      This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques:
      - `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang
      - `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang
      
      Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing using PyTorch reference implementations.
      
      * Refactor GEMM SplitK and StreamK example implementations
      
      Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts:
      - Remove unused import (Profiler) in splitk example
      - Simplify line breaks and improve code readability
      - Standardize indentation and remove unnecessary whitespace
      - Optimize atomic add and copy operations for better clarity
      
      * Add block sparse attention benchmarks for multiple libraries
      
      This commit introduces comprehensive block sparse attention benchmarks for different libraries:
      - TileLang block sparse FMHA implementation
      - Triton block sparse FMHA implementation
      - PyTorch reference block sparse FMHA implementation
      - FlashAttention dense FMHA reference implementation
      
      The benchmarks include:
      - Configurable benchmark parameters (batch size, heads, sequence length, etc.)
      - Sparse mask generation using top-k and threshold methods
      - Performance measurement for different sparse attention configurations
      - Utility functions for mask generation and benchmarking
      
      * Refactor block sparse attention benchmarks with code style improvements
      
      - Add Ruff linter ignore comments to benchmark files
      - Improve code formatting and line breaks
      - Remove unused imports
      - Standardize print statement formatting
      - Enhance code readability across multiple library benchmarks
      
      * lint fix
      
      * Add CUDA atomic operations for BFLOAT16 and update function naming
      
      - Implement AtomicAdd functions for BFLOAT16 and BFLOAT16x2 in CUDA common header
      - Rename existing atomic add functions to use PascalCase (atomicAdd -> AtomicAdd)
      - Add a new __pack_nv_bfloat162 function for packing BFLOAT16 values
      - Update kernel and language customization to use new function names
      - Add return type annotations in profiler module
      
      * lint fix
      
      * Add example for Group Query Attention (GQA) forward pass using Flash Attention in TileLang
      
      This commit introduces a new example script `example_gqa_fwd_bshd.py` that demonstrates:
      - Group Query Attention (GQA) implementation
      - Flash Attention forward pass
      - Performance benchmarking
      - Configurable parameters for batch, heads, sequence length, and dimension
      - Autotuning support
      - Reference implementation comparison
      
      * Refactor IR lowering pipeline into modular phases
      
      This commit introduces a new module `phase.py` to modularize the IR lowering process by splitting the complex lowering pipeline into two distinct phases:
      - `LowerAndLegalize`: Handles initial IR legalization and transformation
      - `OptimizeForTarget`: Applies target-specific optimizations
      
      The changes simplify the lowering logic in multiple files by extracting the transformation steps into reusable functions, improving code readability and maintainability.
      
      * lintfix
      
      * nas kernel
      
      * Enhance Native Sparse Attention Examples with Code Improvements and Parameter Updates
      
      - Updated example_tilelang_nsa.py and example_triton_nsa.py with code formatting and style improvements
      - Increased default number of heads and selected blocks in TileLang NSA example
      - Added Ruff linter ignore comments to reference.py
      - Standardized function signatures and improved code readability across NSA implementations
      
      * Add utility math functions for integer operations
      
      - Implement `next_power_of_2()` to calculate the next power of 2 for an integer
      - Add `cdiv()` function for ceiling division of integers
      
      * Add utility math functions for integer operations
      
      - Implement `next_power_of_2()` to calculate the next power of 2 for an integer
      - Add `cdiv()` function for ceiling division of integers
      
      * Refactor DeepSeek MLA Decode Example with Enhanced Flash Attention Implementation
      
      - Update flash attention kernel to support positional embeddings (PE)
      - Modify reference implementation to handle PE and group query attention
      - Increase default batch size and adjust benchmarking parameters
      - Improve kernel performance and readability
      - Add einops and torch operations for more flexible tensor manipulation
      
      * Update README.md with corrected Flash MLA Decoding example path
      
      - Modify the example link for Flash MLA Decoding to point to the correct directory
      - Ensure accurate navigation to the DeepSeek MLA decoding example
      13f4b5c6
  2. 25 Feb, 2025 1 commit
  3. 23 Feb, 2025 1 commit
    • Yu Cheng's avatar
      [Dev] Add MLA and GQA decode examples (#109) · 40faabb1
      Yu Cheng authored
      * [CI][Test] Add test cases for tilelang transform MultiVersionBuffer and WarpSpecialized
      
      * Relax the mismatch ratio restrictions in the flash_linear_attention and mha tests
      
      * [Dev] Add mha backward example
      
      * [Dev] Add mla decode example
      
      * bug fix
      
      * Add triton impl
      
      * Add gqa decode example
      
      * [Dev] Add GQA decode example
      
      * lint
      
      * delete unused triton example
      
      * set default profiler to 'auto'
      40faabb1
  4. 25 Jan, 2025 2 commits
    • Yu Cheng's avatar
      [CI][Test] Add test cases for tilelang kernel FlashAttention (#54) · bedab1a0
      Yu Cheng authored
      * [Dev] Add FlashDecoding example
      
      * [CI][Test] Add test cases for tilelang kernel convolution
      
      * [CI][Test] Add test cases for tilelang kernel FlashAttention
      
      * Reduce the number of stages to ensure the shared memory allocation is valid
      
      * Temporarily remove the dim128 case
      
      * lint
      
      * update einops in requirements-dev.txt
      
      * update einops in requirements-test.txt
      
      * remove einops in requirements-dev.txt
      bedab1a0
    • Lei Wang's avatar
      [Doc] Remove unnecessary layout annotation (#49) · 47ecc791
      Lei Wang authored
      * [Doc] Update documentation structure and content: add overview section, revise project name, and change theme to Furo
      
      * [Feature] Add device-side debug printing functions and integrate into kernel interface
      
      * lint fix
      
      * remove debug print
      
      * implement test for debug
      
      * lint fix
      
      * add some comments
      
      * Enhance fragment design and assert fragment print
      
      * enhance debug print
      
      * add test for msg
      
      * lint fix
      
      * format
      
      * add flash decoding exmaples
      
      * remove comment
      
      * test simplified
      47ecc791