Unverified Commit 10911e28 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)



* 3rdparty tvm bump

* bump tvm into v0.22.0

* lint fix

* rebase tvm

* Update submodule tvm to latest commit 3085bc4

* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang

* test fix

* add requirement

* atomic_fix

* atomic_fix

* phaseout py39

* optimize

* optimize

* lint fix

* do not clean cache

* do not clean cache

* [Minor] Minor update for Python versions and dependencies

* [Lint] fix lint for py39

* [Lint] fix lint for ROCm

* [Build][CI] Sync CI changes from upstream/sdist

* [Lint] fix lint for ROCm

* [Build][CI] Update `repair-wheel-command`

* [Minor] update abi3audit result format

* [Lint] fix lint for ROCm

* [BugFix] fix build

* [Lint] fix lint for ROCm

* [BugFix] set rpath for libtvm and libtvm_runtime

* [Deps] pin apache-tvm-ffi version

* [Build] set Python 3.9 Limited API for Cython target

* [Build] set Python 3.9 Limited API for Cython target

* [Deps] Restore Python 3.8 support

* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`

* [BugFix] use `;` as delimiter for RPATH on macOS

* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`

* [Build] support `sccache` if available

* [Build] add CIBW import test

* [Build][CI] enable ccache for CIBW on Linux

* [BugFix] set rpath for libtvm and libtvm_runtime

* Revert "[Build][CI] enable ccache for CIBW on Linux"

This reverts commit cd9ab57bb5ddd2572c60bcbbebde81480a658fd3.

* [CI] fix perfbench bot

* [BugFix] use Python 3.9 to build wheel

* [Minor] update perfbench bot envs

* [BugFix] fix CIBW environment on Linux

* [CI] skip import test on CentOS 7

* [CI] use Python urllib to download file instead of Wget

---------
Co-authored-by: default avatarXuehai Pan <XuehaiPan@pku.edu.cn>
parent c37621c5
--- ---
InheritParentConfig: true InheritParentConfig: true
ExtraArgs: ['-v'] ExtraArgs: []
FormatStyle: file FormatStyle: file
UseColor: true UseColor: true
WarningsAsErrors: '*' WarningsAsErrors: '*'
......
...@@ -22,10 +22,12 @@ env: ...@@ -22,10 +22,12 @@ env:
PYTHONDEVMODE: "1" PYTHONDEVMODE: "1"
PYTHONUNBUFFERED: "1" PYTHONUNBUFFERED: "1"
PYTHONPATH: "" # explicit cleanup PYTHONPATH: "" # explicit cleanup
PIP_USER: "" # explicit cleanup
COLUMNS: "100" COLUMNS: "100"
FORCE_COLOR: "1" FORCE_COLOR: "1"
CLICOLOR_FORCE: "1" CLICOLOR_FORCE: "1"
UV_INDEX_STRATEGY: "unsafe-best-match" UV_INDEX_STRATEGY: "unsafe-best-match"
UV_HTTP_TIMEOUT: "600"
XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated
PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated
UV_CACHE_DIR: "${{ github.workspace }}/.cache/uv" # to be updated UV_CACHE_DIR: "${{ github.workspace }}/.cache/uv" # to be updated
...@@ -44,7 +46,7 @@ jobs: ...@@ -44,7 +46,7 @@ jobs:
submodules: recursive submodules: recursive
- name: Setup Python 3.8 - name: Setup Python 3.8
id: setup-py38 id: setup-pylowest
uses: actions/setup-python@v6 uses: actions/setup-python@v6
with: with:
python-version: "3.8" # use lowest supported version for linting python-version: "3.8" # use lowest supported version for linting
...@@ -52,7 +54,7 @@ jobs: ...@@ -52,7 +54,7 @@ jobs:
- name: Check AST with Python 3.8 - name: Check AST with Python 3.8
run: | run: |
"${{ steps.setup-py38.outputs.python-path }}" -m compileall -q -f tilelang "${{ steps.setup-pylowest.outputs.python-path }}" -m compileall -q -f tilelang
- name: Setup Python 3.12 - name: Setup Python 3.12
uses: actions/setup-python@v6 uses: actions/setup-python@v6
......
...@@ -108,14 +108,11 @@ jobs: ...@@ -108,14 +108,11 @@ jobs:
- { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8" } - { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8" }
- { runner: macos-latest, toolkit: "Metal" } - { runner: macos-latest, toolkit: "Metal" }
python-version: python-version:
- "3.8" # Wheels are built with Python 3.8 Limited API, they should work with all Python >= 3.8.
# TVM is built with Python 3.8 Limited API, it should work with all Python >= 3.8. # Only build wheels against Python 3.8 Limited API to save CI resources.
# - "3.9" # FIXME: Here we use Python 3.9 because our dependency `apache-tvm-ffi` claims to support
# - "3.10" # Python 3.8 but it depends on a version of `ml-dtypes` that requires Python >= 3.9.
# - "3.11" - "3.9"
# - "3.12"
# - "3.13"
# - "3.14"
fail-fast: false fail-fast: false
timeout-minutes: 120 timeout-minutes: 120
runs-on: ${{ matrix.target.runner }} runs-on: ${{ matrix.target.runner }}
......
...@@ -12,6 +12,17 @@ concurrency: ...@@ -12,6 +12,17 @@ concurrency:
group: "${{ github.workflow }}-${{ github.ref }}" group: "${{ github.workflow }}-${{ github.ref }}"
cancel-in-progress: true # always cancel in-progress cancel-in-progress: true # always cancel in-progress
env:
PYTHONDEVMODE: "1"
PYTHONUNBUFFERED: "1"
PYTHONPATH: "" # explicit cleanup
PIP_USER: "" # explicit cleanup
COLUMNS: "100"
FORCE_COLOR: "1"
CLICOLOR_FORCE: "1"
XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated
PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated
jobs: jobs:
perfbench: perfbench:
name: Benchmark between PR and main name: Benchmark between PR and main
...@@ -31,7 +42,12 @@ jobs: ...@@ -31,7 +42,12 @@ jobs:
- name: Setup Python - name: Setup Python
uses: actions/setup-python@v6 uses: actions/setup-python@v6
with: with:
python-version: "3.9" python-version: "3.12"
update-environment: true
cache: pip
cache-dependency-path: |
pyproject.toml
requirements*.txt
- name: Install merged version - name: Install merged version
run: | run: |
......
Subproject commit 5bf17a34602931e7d7e01cbccf358a21fe972779 Subproject commit 0f1ebab7b66732f34b652ce807c9ff0748cd473c
...@@ -8,6 +8,11 @@ set(CMAKE_CXX_STANDARD 17) ...@@ -8,6 +8,11 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND "$ENV{CIBUILDWHEEL}")
# Warning came from tvm submodule
string(APPEND CMAKE_CXX_FLAGS " -Wno-dangling-reference")
endif()
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake) set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.gitmodules" AND EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.git") if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.gitmodules" AND EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.git")
...@@ -36,9 +41,18 @@ endif() ...@@ -36,9 +41,18 @@ endif()
find_program(CCACHE_PROGRAM ccache) find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM) if(CCACHE_PROGRAM)
message(STATUS "Using ccache: ${CCACHE_PROGRAM}")
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "C compiler launcher") set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "C compiler launcher")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher") set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher") set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher")
else()
find_program(SCCACHE_PROGRAM sccache)
if(SCCACHE_PROGRAM)
message(STATUS "Using sccache: ${SCCACHE_PROGRAM}")
set(CMAKE_C_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "C compiler launcher")
set(CMAKE_CXX_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher")
endif()
endif() endif()
# Configs # Configs
...@@ -68,8 +82,6 @@ file(GLOB TILE_LANG_SRCS ...@@ -68,8 +82,6 @@ file(GLOB TILE_LANG_SRCS
src/target/utils.cc src/target/utils.cc
src/target/codegen_cpp.cc src/target/codegen_cpp.cc
src/target/rt_mod_cpp.cc src/target/rt_mod_cpp.cc
# webgpu doesn't have system dependency
src/target/codegen_webgpu.cc
# intrin_rule doesn't have system dependency # intrin_rule doesn't have system dependency
src/target/intrin_rule*.cc src/target/intrin_rule*.cc
) )
...@@ -181,18 +193,18 @@ install(TARGETS tilelang_cython_wrapper ...@@ -181,18 +193,18 @@ install(TARGETS tilelang_cython_wrapper
# let libtilelang to search tvm/tvm_runtime in same dir # let libtilelang to search tvm/tvm_runtime in same dir
if(APPLE) if(APPLE)
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path") set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path") set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
else() set_target_properties(tvm PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN") set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN") elseif(UNIX)
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
endif() endif()
install(TARGETS tvm tvm_runtime tilelang_module tilelang LIBRARY DESTINATION tilelang/lib) install(
TARGETS tvm tvm_runtime tilelang_module tilelang
# Copy tvm cython ext for wheels LIBRARY DESTINATION tilelang/lib
# TODO: not necessary for editable builds )
if(TVM_BUILD_FROM_SOURCE)
add_dependencies(tilelang tvm_cython)
install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/python/tvm/ffi/core.abi3.so" DESTINATION tilelang/3rdparty/tvm/python/tvm/ffi/)
endif()
...@@ -11,8 +11,17 @@ endif() ...@@ -11,8 +11,17 @@ endif()
set(TVM_INCLUDES set(TVM_INCLUDES
${TVM_SOURCE}/include ${TVM_SOURCE}/include
${TVM_SOURCE}/ffi/include
${TVM_SOURCE}/src ${TVM_SOURCE}/src
${TVM_SOURCE}/3rdparty/dlpack/include ${TVM_SOURCE}/3rdparty/dlpack/include
${TVM_SOURCE}/3rdparty/dmlc-core/include ${TVM_SOURCE}/3rdparty/dmlc-core/include
) )
if(EXISTS ${TVM_SOURCE}/ffi/include)
list(APPEND TVM_INCLUDES ${TVM_SOURCE}/ffi/include)
elseif(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/include)
list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/include)
endif()
if(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include)
list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include)
endif()
...@@ -4,20 +4,23 @@ TileLang is a domain-specific language designed to simplify the process of writi ...@@ -4,20 +4,23 @@ TileLang is a domain-specific language designed to simplify the process of writi
## Table of Contents ## Table of Contents
1. [Getting Started](#getting-started) - [Table of Contents](#table-of-contents)
2. [Simple GEMM Example](#simple-gemm-example) - [Getting Started](#getting-started)
- [Code Walkthrough](#code-walkthrough) - [Prerequisites](#prerequisites)
- [Compiling and Profiling](#compiling-and-profiling) - [Installation](#installation)
3. [Advanced GEMM Features](#advanced-gemm-features) - [Simple GEMM Example](#simple-gemm-example)
- [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling) - [Code Walkthrough](#code-walkthrough)
- [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining) - [Compiling and Profiling](#compiling-and-profiling)
- [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality) - [Advanced GEMM Features](#advanced-gemm-features)
4. [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations) - [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling)
5. [Verifying Correctness](#verifying-correctness) - [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining)
6. [Fine-grained MMA Computations](#fine-grained-mma-computations) - [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality)
- [Example Workflow](#example-workflow) - [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
- [Summary](#summary) - [Verifying Correctness](#verifying-correctness)
7. [References](#references) - [Fine-grained MMA Computations](#fine-grained-mma-computations)
- [Example Workflow](#example-workflow)
- [Summary](#summary)
- [References](#references)
--- ---
...@@ -25,10 +28,10 @@ TileLang is a domain-specific language designed to simplify the process of writi ...@@ -25,10 +28,10 @@ TileLang is a domain-specific language designed to simplify the process of writi
### Prerequisites ### Prerequisites
- **Python 3.8+** - **Python 3.8+**
- **NVIDIA GPU** with a recent CUDA toolkit installed - **NVIDIA GPU** with a recent CUDA toolkit installed
- **PyTorch** (optional, for easy correctness verification) - **PyTorch** (optional, for easy correctness verification)
- **tilelang** - **tilelang**
- **bitblas** (optional; used for swizzle layout utilities in the advanced examples) - **bitblas** (optional; used for swizzle layout utilities in the advanced examples)
### Installation ### Installation
...@@ -87,26 +90,26 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo ...@@ -87,26 +90,26 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
### Code Walkthrough ### Code Walkthrough
1. **Define the Kernel Launch Configuration:** 1. **Define the Kernel Launch Configuration:**
```python ```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
``` ```
This creates a grid of blocks (ceildiv(N, block_N) in x-dimension, ceildiv(M, block_M) in y-dimension), each with 128 threads. This creates a grid of blocks (ceildiv(N, block_N) in x-dimension, ceildiv(M, block_M) in y-dimension), each with 128 threads.
2. **Shared Memory Allocation:** 2. **Shared Memory Allocation:**
```python ```python
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype)
``` ```
Tiles of \(A\) and \(B\) are loaded into these shared memory buffers for faster access. Tiles of \(A\) and \(B\) are loaded into these shared memory buffers for faster access.
3. **Local Fragment Accumulation:** 3. **Local Fragment Accumulation:**
```python ```python
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
``` ```
Partial results are stored in registers (or local memory) to reduce writes to global memory. Partial results are stored in registers (or local memory) to reduce writes to global memory.
4. **Pipelined Loading and GEMM:** 4. **Pipelined Loading and GEMM:**
```python ```python
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(...) T.copy(...)
...@@ -114,7 +117,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo ...@@ -114,7 +117,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
``` ```
Loads blocks of \(A\) and \(B\) in a pipelined fashion (up to 3 stages). This exploits overlap of data transfer and computation. Loads blocks of \(A\) and \(B\) in a pipelined fashion (up to 3 stages). This exploits overlap of data transfer and computation.
5. **Copy Out the Results:** 5. **Copy Out the Results:**
```python ```python
T.copy(C_local, C[by * block_M, bx * block_N]) T.copy(C_local, C[by * block_M, bx * block_N])
``` ```
...@@ -216,10 +219,10 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo ...@@ -216,10 +219,10 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
return main return main
``` ```
**Key Differences vs. Basic Example** **Key Differences vs. Basic Example**
1. **`T.annotate_layout(...)`**: Annotates how data should be organized in shared memory (swizzling). 1. **`T.annotate_layout(...)`**: Annotates how data should be organized in shared memory (swizzling).
2. **`T.use_swizzle(...)`**: Enables swizzle-based rasterization. 2. **`T.use_swizzle(...)`**: Enables swizzle-based rasterization.
3. **Parallel Copy Loop** with `T.Parallel(...)`: Distributes global-to-shared copy across all threads, potentially vectorizing load/store instructions. 3. **Parallel Copy Loop** with `T.Parallel(...)`: Distributes global-to-shared copy across all threads, potentially vectorizing load/store instructions.
--- ---
...@@ -247,7 +250,7 @@ print("Results match!") ...@@ -247,7 +250,7 @@ print("Results match!")
## Fine-grained MMA Computations ## Fine-grained MMA Computations
For advanced users who require full control over warp-level matrix multiplication operations, TileLang allows you to specify fine-grained MMA (Matrix Multiply-Accumulate) computations in a manner similar to writing raw CUDA. While higher-level abstractions like `T.gemm(...)` or automatic MMA emitters are sufficient for many use cases, specialized workloads (for example, dequantize gemm may require fine-grained layout transformation on shared to register stage) may benefit from explicitly controlling each MMA instruction, the data layout, and the synchronization points. For advanced users who require full control over warp-level matrix multiplication operations, TileLang allows you to specify fine-grained MMA (Matrix Multiply-Accumulate) computations in a manner similar to writing raw CUDA. While higher-level abstractions like `T.gemm(...)` or automatic MMA emitters are sufficient for many use cases, specialized workloads (for example, dequantize gemm may require fine-grained layout transformation on shared to register stage) may benefit from explicitly controlling each MMA instruction, the data layout, and the synchronization points.
### Example Workflow ### Example Workflow
...@@ -394,10 +397,10 @@ def tl_matmul( ...@@ -394,10 +397,10 @@ def tl_matmul(
] ]
``` ```
1. **Set Up Tile Sizes and Thread Bindings** 1. **Set Up Tile Sizes and Thread Bindings**
Just like in CUDA, you will typically start by defining how many warps or threads per block you want and how your matrix is subdivided. In TileLang, this is done via `T.Kernel(...)` and `T.thread_binding(...),` which ensure that the correct number of threads are active, and each thread is bound to a specific role (e.g., warp ID or lane ID). Just like in CUDA, you will typically start by defining how many warps or threads per block you want and how your matrix is subdivided. In TileLang, this is done via `T.Kernel(...)` and `T.thread_binding(...),` which ensure that the correct number of threads are active, and each thread is bound to a specific role (e.g., warp ID or lane ID).
2. **Allocate Warp-local Fragments** 2. **Allocate Warp-local Fragments**
Instead of using a single shared buffer for partial sums, you allocate local buffers (register fragments) to hold sub-blocks of matrices \(A\) and \(B\). In TileLang, this is done with something like: Instead of using a single shared buffer for partial sums, you allocate local buffers (register fragments) to hold sub-blocks of matrices \(A\) and \(B\). In TileLang, this is done with something like:
```python ```python
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
...@@ -406,7 +409,7 @@ def tl_matmul( ...@@ -406,7 +409,7 @@ def tl_matmul(
``` ```
Each of these `local` allocations represents a region of per-thread storage, which collectively forms the warp’s register tiles. Each of these `local` allocations represents a region of per-thread storage, which collectively forms the warp’s register tiles.
3. **Load Data via `ldmatrix`** 3. **Load Data via `ldmatrix`**
Fine-grained loading instructions allow you to specify exactly how data moves from shared memory to the warp-level fragments. In the example below, `mma_emitter.ldmatrix_a()` and `.ldmatrix_b()` are higher-level wrappers around warp-synchronous intrinsics. You can write your own load logic as well: Fine-grained loading instructions allow you to specify exactly how data moves from shared memory to the warp-level fragments. In the example below, `mma_emitter.ldmatrix_a()` and `.ldmatrix_b()` are higher-level wrappers around warp-synchronous intrinsics. You can write your own load logic as well:
```python ```python
for ki in T.serial(0, (block_K // micro_size_k)): for ki in T.serial(0, (block_K // micro_size_k)):
...@@ -418,7 +421,7 @@ def tl_matmul( ...@@ -418,7 +421,7 @@ def tl_matmul(
``` ```
Internally, these calls orchestrate how each thread in the warp issues the correct load instructions, performs address calculations, and stores the data into registers. Internally, these calls orchestrate how each thread in the warp issues the correct load instructions, performs address calculations, and stores the data into registers.
4. **Perform the MMA Instruction** 4. **Perform the MMA Instruction**
After loading sub-tiles (fragments), the warp executes the `mma` instruction. This operation is essentially: After loading sub-tiles (fragments), the warp executes the `mma` instruction. This operation is essentially:
\[ \[
C_{\text{local}} \;+=\; A_{\text{local}} \;\times\; B_{\text{local}} C_{\text{local}} \;+=\; A_{\text{local}} \;\times\; B_{\text{local}}
...@@ -429,7 +432,7 @@ def tl_matmul( ...@@ -429,7 +432,7 @@ def tl_matmul(
``` ```
Under the hood, this translates into Tensor Core instructions (e.g., `wmma.mma.sync` in PTX), which process multiple data elements per warp in parallel. Under the hood, this translates into Tensor Core instructions (e.g., `wmma.mma.sync` in PTX), which process multiple data elements per warp in parallel.
5. **Store Results via `stmatrix`** 5. **Store Results via `stmatrix`**
Finally, you write the results from the warp-level fragments back to shared memory or global memory. This step might happen multiple times in a loop or just once at the end. The code snippet: Finally, you write the results from the warp-level fragments back to shared memory or global memory. This step might happen multiple times in a loop or just once at the end. The code snippet:
```python ```python
mma_emitter.stmatrix(C_local, C_shared) mma_emitter.stmatrix(C_local, C_shared)
...@@ -444,6 +447,6 @@ By combining warp-synchronous intrinsics (`ldmatrix`, `mma`, `stmatrix`) with ma ...@@ -444,6 +447,6 @@ By combining warp-synchronous intrinsics (`ldmatrix`, `mma`, `stmatrix`) with ma
## References ## References
- [NVIDIA CUTLASS Library](https://github.com/NVIDIA/cutlass): A collection of high-performance CUDA C++ template abstractions for GEMM. - [NVIDIA CUTLASS Library](https://github.com/NVIDIA/cutlass): A collection of high-performance CUDA C++ template abstractions for GEMM.
- [NVIDIA CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html): Official documentation for CUDA. - [NVIDIA CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html): Official documentation for CUDA.
- [PyTorch Documentation](https://pytorch.org/docs): For verifying correctness via CPU or GPU-based matmul. - [PyTorch Documentation](https://pytorch.org/docs): For verifying correctness via CPU or GPU-based matmul.
...@@ -80,6 +80,9 @@ elif [[ "${#FILES[@]}" -gt 0 ]]; then ...@@ -80,6 +80,9 @@ elif [[ "${#FILES[@]}" -gt 0 ]]; then
echo "Checking specified files: ${FILES[*]}..." >&2 echo "Checking specified files: ${FILES[*]}..." >&2
fi fi
# Some systems set pip's default to --user, which breaks isolated virtualenvs.
export PIP_USER=0
# If pre-commit is not installed, install it. # If pre-commit is not installed, install it.
if ! python3 -m pre_commit --version &>/dev/null; then if ! python3 -m pre_commit --version &>/dev/null; then
python3 -m pip install pre-commit python3 -m pip install pre-commit
......
...@@ -8,21 +8,27 @@ maintainers = [{ name = "Lei Wang", email = "leiwang1999@outlook.com" }] ...@@ -8,21 +8,27 @@ maintainers = [{ name = "Lei Wang", email = "leiwang1999@outlook.com" }]
license = "MIT" license = "MIT"
keywords = ["BLAS", "CUDA", "HIP", "Code Generation", "TVM"] keywords = ["BLAS", "CUDA", "HIP", "Code Generation", "TVM"]
classifiers = [ classifiers = [
"Development Status :: 4 - Beta",
"Environment :: GPU", "Environment :: GPU",
"Operating System :: POSIX :: Linux", "Operating System :: POSIX :: Linux",
"Operating System :: OS Independent",
"Operating System :: MacOS", "Operating System :: MacOS",
"Programming Language :: C++",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: Implementation :: CPython",
"Intended Audience :: Developers", "Intended Audience :: Developers",
"Intended Audience :: Science/Research", "Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Artificial Intelligence",
] ]
dynamic = ["version"] dynamic = ["version"]
dependencies = [ dependencies = [
"apache-tvm-ffi~=0.1.0",
"cloudpickle", "cloudpickle",
"ml-dtypes", "ml-dtypes",
"numpy>=1.23.5", "numpy>=1.23.5",
...@@ -39,11 +45,7 @@ dependencies = [ ...@@ -39,11 +45,7 @@ dependencies = [
fp4 = ["ml-dtypes>=0.5.1"] fp4 = ["ml-dtypes>=0.5.1"]
[build-system] [build-system]
requires = [ requires = ["cython>=3.0.0", "scikit-build-core"]
"cython>=3.0.0",
"scikit-build-core",
"setuptools>=63",
]
build-backend = "scikit_build_core.build" build-backend = "scikit_build_core.build"
[tool.scikit-build] [tool.scikit-build]
...@@ -180,27 +182,37 @@ build-frontend = "build" ...@@ -180,27 +182,37 @@ build-frontend = "build"
environment = { PYTHONDEVMODE = "1", PYTHONUNBUFFERED = "1" } environment = { PYTHONDEVMODE = "1", PYTHONUNBUFFERED = "1" }
environment-pass = [ environment-pass = [
"CUDA_VERSION", "CUDA_VERSION",
"NO_VERSION_LABEL",
"NO_TOOLCHAIN_VERSION",
"NO_GIT_VERSION",
"COLUMNS", "COLUMNS",
"CMAKE_GENERATOR",
"CMAKE_BUILD_PARALLEL_LEVEL",
"FORCE_COLOR", "FORCE_COLOR",
"CLICOLOR_FORCE", "CLICOLOR_FORCE",
] ]
before-build = "env -0 | sort -z | tr '\\0' '\\n'" before-build = "env -0 | sort -z | tr '\\0' '\\n'"
windows.before-build = "set" windows.before-build = "set"
# Pin to glibc 2.17 for x86 and 2.28 for aarch64 for now test-command = [
manylinux-x86_64-image = "manylinux2014" "python -c 'import tilelang; print(tilelang.__version__)'",
manylinux-aarch64-image = "manylinux_2_28" ]
[tool.cibuildwheel.linux] [tool.cibuildwheel.linux]
environment = { PYTHONDEVMODE = "1", PYTHONUNBUFFERED = "1", PATH = "/usr/local/cuda/bin:$PATH" } environment.PYTHONDEVMODE = "1"
repair-wheel-command = [ environment.PYTHONUNBUFFERED = "1"
"auditwheel repair --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}", environment.PATH = "/usr/local/cuda/bin:$PATH"
"pipx run abi3audit --strict --report {wheel}", environment.LD_LIBRARY_PATH = "/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
] # Pin to glibc 2.17 for x86 and 2.28 for aarch64 for now
manylinux-x86_64-image = "manylinux2014" # CentOS 7
manylinux-aarch64-image = "manylinux_2_28" # AlmaLinux 8
# Install CUDA runtime and stub driver library # Install CUDA runtime and stub driver library
# manylinux_2_28 uses gcc 14, which needs CUDA 12.8 # manylinux_2_28 uses gcc 14, which needs CUDA 12.8
before-all = """ before-all = """
set -eux set -eux
cat /etc/*-release
uname -a
case "$(uname -m)" in case "$(uname -m)" in
"x86_64") "x86_64")
yum-config-manager --add-repo https://developer.download.nvidia.cn/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo yum-config-manager --add-repo https://developer.download.nvidia.cn/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo
...@@ -215,5 +227,22 @@ esac ...@@ -215,5 +227,22 @@ esac
cudaver="$(echo "${CUDA_VERSION:-"12.4"}" | cut -d '.' -f-2)" cudaver="$(echo "${CUDA_VERSION:-"12.4"}" | cut -d '.' -f-2)"
v="${cudaver//./-}" v="${cudaver//./-}"
yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-devel-${v}" yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-devel-${v}" nvidia-driver-cuda-libs
""" """
repair-wheel-command = [
"auditwheel -v repair --exclude libtvm_ffi.so --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}",
"pipx run abi3audit --verbose --strict {wheel}",
]
[tool.cibuildwheel.macos]
repair-wheel-command = [
"delocate-wheel --verbose --ignore-missing-dependencies --no-sanitize-rpaths --require-archs {delocate_archs} -w {dest_dir} -v {wheel}",
"pipx run abi3audit --verbose --strict {wheel}",
]
[[tool.cibuildwheel.overrides]]
select = "*linux*x86_64*"
# CentOS 7 is too old to run import test. Do wheel installation test only.
test-command = [
"echo 'Wheel is installed successfully'",
]
...@@ -18,10 +18,11 @@ cython ...@@ -18,10 +18,11 @@ cython
docutils docutils
dtlib dtlib
einops einops
flash-linear-attention==0.3.2
packaging>=21.0 packaging>=21.0
pytest-xdist>=2.2.1
pytest-durations pytest-durations
pytest-timeout pytest-timeout
pytest-xdist>=2.2.1
pytest>=6.2.4 pytest>=6.2.4
pyyaml pyyaml
requests requests
......
# Runtime requirements # Runtime requirements
apache-tvm-ffi~=0.1.0
cloudpickle cloudpickle
ml-dtypes ml-dtypes
numpy>=1.23.5 numpy>=1.23.5
...@@ -7,4 +8,3 @@ torch ...@@ -7,4 +8,3 @@ torch
torch>=2.7; platform_system == 'Darwin' torch>=2.7; platform_system == 'Darwin'
tqdm>=4.62.3 tqdm>=4.62.3
typing-extensions>=4.10.0 typing-extensions>=4.10.0
flash-linear-attention==0.3.2
\ No newline at end of file
...@@ -7,6 +7,9 @@ ...@@ -7,6 +7,9 @@
#include "./transform/common/attr.h" #include "./transform/common/attr.h"
#include "op/builtin.h" #include "op/builtin.h"
#include "tvm/ffi/any.h" #include "tvm/ffi/any.h"
#include <tvm/ffi/object.h>
#include "support/ffi_aliases.h"
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h> #include <tvm/ffi/reflection/registry.h>
#include <tvm/script/ir_builder/tir/ir.h> #include <tvm/script/ir_builder/tir/ir.h>
...@@ -37,7 +40,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) { ...@@ -37,7 +40,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
using namespace tvm::tir; using namespace tvm::tir;
Var var = Var(name, dom->dtype); Var var = Var(name, dom->dtype);
// Create a frame that represents a loop over the given domain. // Create a frame that represents a loop over the given domain.
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.push_back(var); n->vars.push_back(var);
n->doms.push_back(Range(0, dom)); n->doms.push_back(Range(0, dom));
n->f_make_for_loop = [](const Array<Var> &vars, const Array<Range> &doms, n->f_make_for_loop = [](const Array<Var> &vars, const Array<Range> &doms,
...@@ -52,7 +55,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) { ...@@ -52,7 +55,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
ForFrame ParallelFor(const Array<PrimExpr> &extents, ForFrame ParallelFor(const Array<PrimExpr> &extents,
const Map<String, ObjectRef> &annotations) { const Map<String, ObjectRef> &annotations) {
using namespace tvm::tir; using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.reserve(extents.size()); n->vars.reserve(extents.size());
n->doms.reserve(extents.size()); n->doms.reserve(extents.size());
for (const auto &extent : extents) { for (const auto &extent : extents) {
...@@ -82,7 +85,7 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages, ...@@ -82,7 +85,7 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
const Array<Array<PrimExpr>> &sync, const Array<Array<PrimExpr>> &sync,
const Array<Array<PrimExpr>> &groups) { const Array<Array<PrimExpr>> &groups) {
using namespace tvm::tir; using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
DataType dtype = stop.dtype(); DataType dtype = stop.dtype();
n->vars.push_back(Var("v", dtype)); n->vars.push_back(Var("v", dtype));
n->doms.push_back(Range(std::move(start), stop)); n->doms.push_back(Range(std::move(start), stop));
...@@ -113,7 +116,7 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size, ...@@ -113,7 +116,7 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
const PrimExpr &index, PrimExpr group_size) { const PrimExpr &index, PrimExpr group_size) {
using namespace tvm::tir; using namespace tvm::tir;
ICHECK(!domain.empty()); ICHECK(!domain.empty());
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.reserve(domain.size()); n->vars.reserve(domain.size());
n->doms.reserve(domain.size()); n->doms.reserve(domain.size());
PrimExpr domain_size = domain[0]; PrimExpr domain_size = domain[0];
...@@ -193,8 +196,8 @@ public: ...@@ -193,8 +196,8 @@ public:
"frames", &KernelLaunchFrameNode::frames); "frames", &KernelLaunchFrameNode::frames);
} }
static constexpr const char *_type_key = "tl.KernelLaunchFrame"; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.KernelLaunchFrame",
TVM_DECLARE_FINAL_OBJECT_INFO(KernelLaunchFrameNode, TIRFrameNode); KernelLaunchFrameNode, TIRFrameNode);
public: public:
TVM_DLL void EnterWithScope() final { TVM_DLL void EnterWithScope() final {
...@@ -218,14 +221,20 @@ public: ...@@ -218,14 +221,20 @@ public:
*/ */
class KernelLaunchFrame : public TIRFrame { class KernelLaunchFrame : public TIRFrame {
public: public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(KernelLaunchFrame, TIRFrame, explicit KernelLaunchFrame(ObjectPtr<KernelLaunchFrameNode> data)
KernelLaunchFrameNode); : TIRFrame(::tvm::ffi::UnsafeInit{}) {
ICHECK(data != nullptr);
data_ = std::move(data);
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(KernelLaunchFrame, TIRFrame,
KernelLaunchFrameNode);
}; };
KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size, KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
const Optional<Array<PrimExpr>> &block_size_opt, const Optional<Array<PrimExpr>> &block_size_opt,
const Map<String, ffi::Any> &attrs) { const Map<String, ffi::Any> &attrs) {
ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>(); ObjectPtr<KernelLaunchFrameNode> n =
tvm::ffi::make_object<KernelLaunchFrameNode>();
// If the kernel is a CPU kernel, we don't need to launch any threads. // If the kernel is a CPU kernel, we don't need to launch any threads.
bool is_cpu_kernel_frame = bool is_cpu_kernel_frame =
...@@ -289,16 +298,14 @@ KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size, ...@@ -289,16 +298,14 @@ KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
return KernelLaunchFrame(n); return KernelLaunchFrame(n);
} }
TVM_REGISTER_NODE_TYPE(KernelLaunchFrameNode); TVM_FFI_STATIC_INIT_BLOCK() {
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef() refl::GlobalDef()
.def("tl.Parallel", ParallelFor) .def("tl.Parallel", ParallelFor)
.def("tl.Pipelined", PipelinedFor) .def("tl.Pipelined", PipelinedFor)
.def("tl.Persistent", PersistentFor) .def("tl.Persistent", PersistentFor)
.def("tl.KernelLaunch", KernelLaunch); .def("tl.KernelLaunch", KernelLaunch);
}); }
class WarpSpecializeFrameNode : public TIRFrameNode { class WarpSpecializeFrameNode : public TIRFrameNode {
public: public:
...@@ -310,8 +317,8 @@ public: ...@@ -310,8 +317,8 @@ public:
"frames", &WarpSpecializeFrameNode::frames); "frames", &WarpSpecializeFrameNode::frames);
} }
static constexpr const char *_type_key = "tl.WarpSpecializeFrame"; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.WarpSpecializeFrame",
TVM_DECLARE_FINAL_OBJECT_INFO(WarpSpecializeFrameNode, TIRFrameNode); WarpSpecializeFrameNode, TIRFrameNode);
public: public:
TVM_DLL void EnterWithScope() final { TVM_DLL void EnterWithScope() final {
...@@ -330,15 +337,20 @@ public: ...@@ -330,15 +337,20 @@ public:
class WarpSpecializeFrame : public TIRFrame { class WarpSpecializeFrame : public TIRFrame {
public: public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WarpSpecializeFrame, explicit WarpSpecializeFrame(ObjectPtr<WarpSpecializeFrameNode> data)
TIRFrame, : TIRFrame(::tvm::ffi::UnsafeInit{}) {
WarpSpecializeFrameNode); ICHECK(data != nullptr);
data_ = std::move(data);
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(WarpSpecializeFrame, TIRFrame,
WarpSpecializeFrameNode);
}; };
WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids, WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
const PrimExpr &thread_idx, const PrimExpr &thread_idx,
int warp_group_size = 128) { int warp_group_size = 128) {
ObjectPtr<WarpSpecializeFrameNode> n = make_object<WarpSpecializeFrameNode>(); ObjectPtr<WarpSpecializeFrameNode> n =
tvm::ffi::make_object<WarpSpecializeFrameNode>();
PrimExpr condition; PrimExpr condition;
std::vector<int> warp_groups; std::vector<int> warp_groups;
warp_groups.reserve(warp_group_ids.size()); warp_groups.reserve(warp_group_ids.size());
...@@ -376,13 +388,12 @@ WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids, ...@@ -376,13 +388,12 @@ WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
return WarpSpecializeFrame(n); return WarpSpecializeFrame(n);
} }
TVM_REGISTER_NODE_TYPE(WarpSpecializeFrameNode); TVM_FFI_STATIC_INIT_BLOCK() {
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.WarpSpecialize", WarpSpecialize); refl::GlobalDef().def("tl.WarpSpecialize", WarpSpecialize);
KernelLaunchFrameNode::RegisterReflection(); KernelLaunchFrameNode::RegisterReflection();
WarpSpecializeFrameNode::RegisterReflection(); WarpSpecializeFrameNode::RegisterReflection();
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -64,13 +64,12 @@ Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) { ...@@ -64,13 +64,12 @@ Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
} }
forward_index = forward_index =
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); }); forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
auto n = make_object<LayoutNode>(input_size, forward_index);
data_ = std::move(n); data_ = std::move(n);
} }
Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) { Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
auto n = make_object<LayoutNode>(input_size, forward_index); auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
data_ = std::move(n); data_ = std::move(n);
} }
...@@ -130,7 +129,6 @@ Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr> &vars) const { ...@@ -130,7 +129,6 @@ Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr> &vars) const {
Array<PrimExpr> transformed = forward_index_.Map( Array<PrimExpr> transformed = forward_index_.Map(
[&](const PrimExpr &e) { return Substitute(e, vmap); }); [&](const PrimExpr &e) { return Substitute(e, vmap); });
// Concatenate with the remaining elements from vars // Concatenate with the remaining elements from vars
Array<PrimExpr> result; Array<PrimExpr> result;
for (size_t i = 0; i < vars.size() - InputDim(); i++) { for (size_t i = 0; i < vars.size() - InputDim(); i++) {
...@@ -212,7 +210,7 @@ Fragment FragmentNode::DeReplicate() const { ...@@ -212,7 +210,7 @@ Fragment FragmentNode::DeReplicate() const {
factor = arith::ZeroAwareGCD(*rep_size, *idx_size); factor = arith::ZeroAwareGCD(*rep_size, *idx_size);
} }
if (factor == 1) if (factor == 1)
return GetRef<Fragment>(this); return tvm::ffi::GetRef<Fragment>(this);
Map<Var, PrimExpr> vmap; Map<Var, PrimExpr> vmap;
vmap.Set(ReplicationPlaceholder(), ReplicationPlaceholder() * factor + vmap.Set(ReplicationPlaceholder(), ReplicationPlaceholder() * factor +
...@@ -224,7 +222,7 @@ Fragment FragmentNode::DeReplicate() const { ...@@ -224,7 +222,7 @@ Fragment FragmentNode::DeReplicate() const {
} }
Fragment FragmentNode::BindThreadRange(Range thread_range) const { Fragment FragmentNode::BindThreadRange(Range thread_range) const {
auto n = make_object<FragmentNode>(*this); auto n = tvm::ffi::make_object<FragmentNode>(*this);
n->thread_range_ = thread_range; n->thread_range_ = thread_range;
return Fragment(n); return Fragment(n);
} }
...@@ -336,8 +334,8 @@ Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index, ...@@ -336,8 +334,8 @@ Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); }); forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
forward_thread = Substitute(forward_thread, vmap); forward_thread = Substitute(forward_thread, vmap);
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread, auto n = tvm::ffi::make_object<FragmentNode>(input_size, forward_index,
replicate_size); forward_thread, replicate_size);
data_ = std::move(n); data_ = std::move(n);
} }
...@@ -348,8 +346,8 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index, ...@@ -348,8 +346,8 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
forward_thread = Substitute( forward_thread = Substitute(
forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}}); forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}});
} }
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread, auto n = tvm::ffi::make_object<FragmentNode>(input_size, forward_index,
replicate_size); forward_thread, replicate_size);
data_ = std::move(n); data_ = std::move(n);
} }
...@@ -442,21 +440,6 @@ std::string FragmentNode::DebugOutput() const { ...@@ -442,21 +440,6 @@ std::string FragmentNode::DebugOutput() const {
return ss.str(); return ss.str();
} }
bool LayoutNode::SEqualReduce(const LayoutNode *other,
SEqualReducer equal) const {
return equal(this->InputShape(), other->InputShape()) &&
equal(this->forward_index_, other->forward_index_);
}
bool FragmentNode::SEqualReduce(const FragmentNode *other,
SEqualReducer equal) const {
return equal(this->ReplicateExtent(), other->ReplicateExtent()) &&
equal(this->InputShape(), other->InputShape()) &&
equal(this->ThreadExtent(), other->ThreadExtent()) &&
equal(this->forward_index_, other->forward_index_) &&
equal(this->forward_thread_, other->forward_thread_);
}
bool LayoutNode::IsEqual(const LayoutNode *other, bool skip_index) const { bool LayoutNode::IsEqual(const LayoutNode *other, bool skip_index) const {
bool ret = StructuralEqual()(this->InputShape(), other->InputShape()); bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
ret &= StructuralEqual()(this->OutputShape(), other->OutputShape()); ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
...@@ -495,10 +478,7 @@ void FragmentNode::RegisterReflection() { ...@@ -495,10 +478,7 @@ void FragmentNode::RegisterReflection() {
.def_ro("replicate_size", &FragmentNode::replicate_size_); .def_ro("replicate_size", &FragmentNode::replicate_size_);
} }
TVM_REGISTER_NODE_TYPE(LayoutNode); TVM_FFI_STATIC_INIT_BLOCK() {
TVM_REGISTER_NODE_TYPE(FragmentNode);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef() refl::GlobalDef()
.def_packed("tl.Layout", .def_packed("tl.Layout",
...@@ -582,13 +562,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ ...@@ -582,13 +562,13 @@ TVM_FFI_STATIC_INIT_BLOCK({
.def("tl.make_linear_layout", [](int stride, int continuous) { .def("tl.make_linear_layout", [](int stride, int continuous) {
return makeGemmLayoutLinear(stride, continuous); return makeGemmLayoutLinear(stride, continuous);
}); });
}); }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
LayoutNode::RegisterReflection(); LayoutNode::RegisterReflection();
FragmentNode::RegisterReflection(); FragmentNode::RegisterReflection();
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -8,8 +8,11 @@ ...@@ -8,8 +8,11 @@
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h> #include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/object.h>
#include <utility> #include <utility>
#include "../support/ffi_aliases.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -44,11 +47,10 @@ public: ...@@ -44,11 +47,10 @@ public:
virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const; virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const;
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr const char *_type_key = "tl.Layout";
bool SEqualReduce(const LayoutNode *other, SEqualReducer equal) const;
static void RegisterReflection(); static void RegisterReflection();
TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object); TVM_FFI_DECLARE_OBJECT_INFO("tl.Layout", LayoutNode, Object);
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
protected: protected:
virtual Map<Var, Range> getVarMap() const; virtual Map<Var, Range> getVarMap() const;
...@@ -65,7 +67,7 @@ public: ...@@ -65,7 +67,7 @@ public:
TVM_DLL Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index); TVM_DLL Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index);
TVM_DLL Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index); TVM_DLL Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);
TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Layout, ObjectRef, LayoutNode);
}; };
class FragmentNode : public LayoutNode { class FragmentNode : public LayoutNode {
...@@ -109,9 +111,9 @@ public: ...@@ -109,9 +111,9 @@ public:
static void RegisterReflection(); static void RegisterReflection();
bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode);
static constexpr const char *_type_key = "tl.Fragment"; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
TVM_DECLARE_FINAL_OBJECT_INFO(FragmentNode, LayoutNode); kTVMFFISEqHashKindTreeNode;
protected: protected:
Map<Var, Range> getVarMap() const final; Map<Var, Range> getVarMap() const final;
...@@ -132,7 +134,7 @@ public: ...@@ -132,7 +134,7 @@ public:
PrimExpr forward_thread, PrimExpr replicate_size, PrimExpr forward_thread, PrimExpr replicate_size,
Optional<Var> replicate_var); Optional<Var> replicate_var);
TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fragment, Layout, FragmentNode);
}; };
Var InputPlaceholder(size_t idx); Var InputPlaceholder(size_t idx);
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "swizzle.h" #include "swizzle.h"
#include <tvm/node/node.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
...@@ -86,14 +87,16 @@ SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var, ...@@ -86,14 +87,16 @@ SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var,
forward_index = forward_index =
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); }); forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern); auto n = tvm::ffi::make_object<SwizzledLayoutNode>(input_size, forward_index,
pattern);
data_ = std::move(n); data_ = std::move(n);
} }
SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size, SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index, Array<PrimExpr> forward_index,
SwizzlePattern pattern) { SwizzlePattern pattern) {
auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern); auto n = tvm::ffi::make_object<SwizzledLayoutNode>(input_size, forward_index,
pattern);
data_ = std::move(n); data_ = std::move(n);
} }
...@@ -102,14 +105,5 @@ void SwizzledLayoutNode::RegisterReflection() { ...@@ -102,14 +105,5 @@ void SwizzledLayoutNode::RegisterReflection() {
refl::ObjectDef<SwizzledLayoutNode>(); refl::ObjectDef<SwizzledLayoutNode>();
} }
bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode *other,
SEqualReducer equal) const {
return equal(this->InputShape(), other->InputShape()) &&
equal(this->forward_index_, other->forward_index_) &&
pattern_ == other->pattern_;
}
TVM_REGISTER_NODE_TYPE(SwizzledLayoutNode);
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
\ No newline at end of file
...@@ -44,10 +44,9 @@ public: ...@@ -44,10 +44,9 @@ public:
Layout Inverse() const final; Layout Inverse() const final;
std::string DebugOutput() const final; std::string DebugOutput() const final;
bool IsEqual(const SwizzledLayoutNode *other, bool skip_index = false) const; bool IsEqual(const SwizzledLayoutNode *other, bool skip_index = false) const;
static constexpr const char *_type_key = "tl.SwizzledLayout";
bool SEqualReduce(const SwizzledLayoutNode *other, SEqualReducer equal) const;
static void RegisterReflection(); static void RegisterReflection();
TVM_DECLARE_FINAL_OBJECT_INFO(SwizzledLayoutNode, LayoutNode); TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.SwizzledLayout", SwizzledLayoutNode,
LayoutNode);
private: private:
SwizzlePattern pattern_; SwizzlePattern pattern_;
...@@ -62,11 +61,11 @@ public: ...@@ -62,11 +61,11 @@ public:
Array<PrimExpr> forward_index, SwizzlePattern pattern); Array<PrimExpr> forward_index, SwizzlePattern pattern);
TVM_DLL SwizzledLayout(Array<PrimExpr> input_size, TVM_DLL SwizzledLayout(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index, SwizzlePattern pattern); Array<PrimExpr> forward_index, SwizzlePattern pattern);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SwizzledLayout, Layout,
TVM_DEFINE_OBJECT_REF_METHODS(SwizzledLayout, Layout, SwizzledLayoutNode); SwizzledLayoutNode);
}; };
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
#endif // TVM_TL_LAYOUT_SWIZZLE_H_ #endif // TVM_TL_LAYOUT_SWIZZLE_H_
\ No newline at end of file
...@@ -189,7 +189,7 @@ public: ...@@ -189,7 +189,7 @@ public:
IterMark Mutate(const IterMark &mark) { IterMark Mutate(const IterMark &mark) {
if (auto *op = mark->source.as<IterSumExprNode>()) { if (auto *op = mark->source.as<IterSumExprNode>()) {
return IterMark(Mutate(GetRef<IterSumExpr>(op)), mark->extent); return IterMark(Mutate(tvm::ffi::GetRef<IterSumExpr>(op)), mark->extent);
} else { } else {
return mark; return mark;
} }
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <tvm/arith/iter_affine_map.h> #include <tvm/arith/iter_affine_map.h>
#include "../support/ffi_aliases.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
......
...@@ -42,7 +42,7 @@ using namespace tir; ...@@ -42,7 +42,7 @@ using namespace tir;
* - The constructed node is stored in this->data_. * - The constructed node is stored in this->data_.
*/ */
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) { AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>(); ObjectPtr<AtomicAddNode> node = tvm::ffi::make_object<AtomicAddNode>();
Array<Range> rgs[2]; Array<Range> rgs[2];
Buffer bf[2]; Buffer bf[2];
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
...@@ -78,7 +78,7 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) { ...@@ -78,7 +78,7 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator owning the cloned AtomicAddNode. * @return TileOperator A TileOperator owning the cloned AtomicAddNode.
*/ */
TileOperator AtomicAddNode::Clone() const { TileOperator AtomicAddNode::Clone() const {
auto op = make_object<AtomicAddNode>(*this); auto op = tvm::ffi::make_object<AtomicAddNode>(*this);
if (par_op_.defined()) { if (par_op_.defined()) {
op->par_op_ = Downcast<ParallelOp>(par_op_->Clone()); op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
} }
...@@ -549,7 +549,7 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd) ...@@ -549,7 +549,7 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ AtomicAddNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK() { AtomicAddNode::RegisterReflection(); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment