    [Experimental][Language] add `T.GEMM_SP` for sm90 sparse tensor core (#526) · be44758c
    botbw authored
    
    
    * [experimental] add a draft gemm_sp
    
    * [3rdparty] bump cutlass to v3.9.3
    
    * [lint] run format.sh
    
    * [chore] rebase
    
    * [chore] use abs path
    
    * [gemm_sp] add metadata layout
    
    * [ci] add more examples
    
    * [lint] run format.sh
    
    * [chore] polish
    
    * [chore] move gemm_sp to experimental
    
    * [chore] polish
    
    * [lint] run format.sh
    
    * [Enhancement] Improve bulk copy handling and update GEMM sparse tensor test
    
    * Added a warning log for unsupported non-swizzled global layouts in the bulk copy operation, falling back to a normal copy.
    * Refactored the GEMM sparse tensor test by removing unnecessary imports and simplifying kernel compilation.
    * Updated the test to call the `run_gemm_sp` helper directly (see the sketch below).
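    A minimal sketch of that direct-call test shape, assuming a `run_gemm_sp` helper that compiles a `T.gemm_sp` kernel and checks it against a dense reference; the helper name, signature, and parameter values here are illustrative, not the exact test API.

    ```python
    # Sketch only: `run_gemm_sp` stands in for the repository's test helper.
    def run_gemm_sp(M, N, K, dtype="float16", trans_A=False, trans_B=False):
        # Placeholder stub: the real helper builds the kernel, compresses A into
        # (values, metadata), runs on device, and compares against a dense matmul.
        pass


    def test_gemm_sp():
        # Call the helper directly instead of wiring up kernel compilation inline.
        run_gemm_sp(1024, 1024, 1024, dtype="float16")
        run_gemm_sp(1024, 1024, 1024, dtype="float8_e4m3", trans_B=True)
        run_gemm_sp(1024, 1024, 1024, dtype="int8")
    ```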
    
    * Implement Test
    
    * [Enhancement] Update GEMM SP and SM89 templates for improved functionality
    
    * Refactored GEMM SP computation to enhance the warp partitioning logic, ensuring compatibility with the Hopper architecture.
    * Updated layout inference to support new WGMMA conditions and improved error messaging for unsupported targets.
    * Modified SM89 templates to utilize new MMA atom structures, enhancing performance and compatibility with fp8 types.
    * Added conditional inclusion for GEMM SP header based on CUDA architecture version.
    
    * lint fix
    
    * [gemm_sp] support more layout and data types
    
    * Enhancement: sync T.gemm_sp's layout inference with T.gemm
    
    * Enhancement: support more block_k in compress util
    
    * [Enhancement] enable block_k=64
    
    * [Lint] run format.sh
    
    * [Enhancement] compressor supports more dtypes
    
    * Enhancement: enable block_K=32
    
    * [Lint] format.sh
    
    * [Bugfix] fix shape
    
    * Refactor: sync gemm
    
    * [Enhancement] enable transpose
    
    * [Enhancement] enable fp8_e4m3
    
    * [Enhancement] enable int8
    
    * [Lint] run format.sh
    
    * [Benchmark] add gemm_sp benchmark
    
    * [Example] fix hang with 256 threads
    
    * [CI] fix ci
    
    * [Chore] resolve gemini feedback
    
    * [Benchmark] increase search space
    
    * [Lint] format
    
    * [CI] skip sparse-tensor-core tests, as only sm90 is supported (skip guard sketched below)
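    A hedged sketch of the sm90-only gate, assuming the tests are skipped via a device-capability check; the exact decorator or utility used in CI may differ.

    ```python
    import pytest
    import torch

    # Skip on anything other than Hopper (compute capability 9.0), since the
    # sparse tensor core path behind T.gemm_sp is only generated for sm90.
    requires_sm90 = pytest.mark.skipif(
        not torch.cuda.is_available() or torch.cuda.get_device_capability() != (9, 0),
        reason="T.gemm_sp targets the sm90 sparse tensor core only",
    )
    ```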
    
    * [CI] pass local run
    
    * Update gemm_sm89.h
    
    * lint fix
    
    * lint fix
    
    * [Enhancement] Add support for sparse GEMM and initialize CUDA architecture flags
    
    - Introduced a new boolean flag `enable_sparse_gemm_` to control the inclusion of sparse GEMM functionality in CUDA code generation.
    - Updated the `Finish` method to conditionally include the sparse GEMM header based on the new flag.
    - Implemented logic in `VisitStmt_` to enable sparse GEMM when the corresponding external call is detected.
    - Added a function to initialize the `TORCH_CUDA_ARCH_LIST` environment variable based on the target compute version, improving compatibility with PyTorch extension builds (see the sketch below).
    - Refactored the initialization function into the appropriate module and ensured it is called in the sparse utilities module.
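    A hedged sketch of that initialization, assuming a small helper keyed off the target compute version; the function name and version formatting are illustrative. `TORCH_CUDA_ARCH_LIST` is the environment variable PyTorch's extension builder reads to pick target architectures, which matters when the compressor utility builds its CUDA code through PyTorch.

    ```python
    import os


    def init_torch_cuda_arch_list(compute_version: str) -> None:
        """Set TORCH_CUDA_ARCH_LIST from the target compute version, e.g. "9.0"."""
        # Respect an explicit user setting; only fill in a default when unset.
        if "TORCH_CUDA_ARCH_LIST" not in os.environ:
            os.environ["TORCH_CUDA_ARCH_LIST"] = compute_version
    ```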
    
    * Update test_compress_utils.py
    
    ---------
    Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
    Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>