Unverified commit 10911e28, authored by Lei Wang and committed by GitHub
Browse files

[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)



* 3rdparty tvm bump

* bump tvm into v0.22.0

* lint fix

* rebase tvm

* Update submodule tvm to latest commit 3085bc4

* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang

* test fix

* add requirement

* atomic_fix

* atomic_fix

* phaseout py39

* optimize

* optimize

* lint fix

* do not clean cache

* do not clean cache

* [Minor] Minor update for Python versions and dependencies

* [Lint] fix lint for py39

* [Lint] fix lint for ROCm

* [Build][CI] Sync CI changes from upstream/sdist

* [Lint] fix lint for ROCm

* [Build][CI] Update `repair-wheel-command`

* [Minor] update abi3audit result format

* [Lint] fix lint for ROCm

* [BugFix] fix build

* [Lint] fix lint for ROCm

* [BugFix] set rpath for libtvm and libtvm_runtime

* [Deps] pin apache-tvm-ffi version

* [Build] set Python 3.9 Limited API for Cython target

* [Build] set Python 3.9 Limited API for Cython target

* [Deps] Restore Python 3.8 support

* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`

* [BugFix] use `;` as delimiter for RPATH on macOS

* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`

* [Build] support `sccache` if available

* [Build] add CIBW import test

* [Build][CI] enable ccache for CIBW on Linux

* [BugFix] set rpath for libtvm and libtvm_runtime

* Revert "[Build][CI] enable ccache for CIBW on Linux"

This reverts commit cd9ab57bb5ddd2572c60bcbbebde81480a658fd3.

* [CI] fix perfbench bot

* [BugFix] use Python 3.9 to build wheel

* [Minor] update perfbench bot envs

* [BugFix] fix CIBW environment on Linux

* [CI] skip import test on CentOS 7

* [CI] use Python urllib to download file instead of Wget

---------
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
parent c37621c5
---
InheritParentConfig: true
ExtraArgs: ['-v']
ExtraArgs: []
FormatStyle: file
UseColor: true
WarningsAsErrors: '*'
......
......@@ -22,10 +22,12 @@ env:
PYTHONDEVMODE: "1"
PYTHONUNBUFFERED: "1"
PYTHONPATH: "" # explicit cleanup
PIP_USER: "" # explicit cleanup
COLUMNS: "100"
FORCE_COLOR: "1"
CLICOLOR_FORCE: "1"
UV_INDEX_STRATEGY: "unsafe-best-match"
UV_HTTP_TIMEOUT: "600"
XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated
PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated
UV_CACHE_DIR: "${{ github.workspace }}/.cache/uv" # to be updated
......@@ -44,7 +46,7 @@ jobs:
submodules: recursive
- name: Setup Python 3.8
id: setup-py38
id: setup-pylowest
uses: actions/setup-python@v6
with:
python-version: "3.8" # use lowest supported version for linting
......@@ -52,7 +54,7 @@ jobs:
- name: Check AST with Python 3.8
run: |
"${{ steps.setup-py38.outputs.python-path }}" -m compileall -q -f tilelang
"${{ steps.setup-pylowest.outputs.python-path }}" -m compileall -q -f tilelang
- name: Setup Python 3.12
uses: actions/setup-python@v6
......
......@@ -108,14 +108,11 @@ jobs:
- { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8" }
- { runner: macos-latest, toolkit: "Metal" }
python-version:
- "3.8"
# TVM is built with Python 3.8 Limited API, it should work with all Python >= 3.8.
# - "3.9"
# - "3.10"
# - "3.11"
# - "3.12"
# - "3.13"
# - "3.14"
# Wheels are built with the Python 3.8 Limited API, so they should work with all Python >= 3.8.
# Only build wheels against Python 3.8 Limited API to save CI resources.
# FIXME: Here we use Python 3.9 because our dependency `apache-tvm-ffi` claims to support
# Python 3.8 but it depends on a version of `ml-dtypes` that requires Python >= 3.9.
- "3.9"
fail-fast: false
timeout-minutes: 120
runs-on: ${{ matrix.target.runner }}
......
......@@ -12,6 +12,17 @@ concurrency:
group: "${{ github.workflow }}-${{ github.ref }}"
cancel-in-progress: true # always cancel in-progress
env:
PYTHONDEVMODE: "1"
PYTHONUNBUFFERED: "1"
PYTHONPATH: "" # explicit cleanup
PIP_USER: "" # explicit cleanup
COLUMNS: "100"
FORCE_COLOR: "1"
CLICOLOR_FORCE: "1"
XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated
PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated
jobs:
perfbench:
name: Benchmark between PR and main
......@@ -31,7 +42,12 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
python-version: "3.9"
python-version: "3.12"
update-environment: true
cache: pip
cache-dependency-path: |
pyproject.toml
requirements*.txt
- name: Install merged version
run: |
......
Subproject commit 5bf17a34602931e7d7e01cbccf358a21fe972779
Subproject commit 0f1ebab7b66732f34b652ce807c9ff0748cd473c
......@@ -8,6 +8,11 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# Only when compiling with GCC and the CIBUILDWHEEL environment variable holds a
# true value (presumably set by cibuildwheel during wheel builds — confirm):
# append -Wno-dangling-reference to the C++ flags.
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND "$ENV{CIBUILDWHEEL}")
# The warning being silenced originates from the tvm submodule, not from tilelang sources.
string(APPEND CMAKE_CXX_FLAGS " -Wno-dangling-reference")
endif()
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.gitmodules" AND EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.git")
......@@ -36,9 +41,18 @@ endif()
# Pick a compiler-cache launcher for C, C++ and CUDA: ccache is preferred and
# sccache is used only when ccache cannot be found. If neither program is
# found, no launcher variable is set by this block.
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
message(STATUS "Using ccache: ${CCACHE_PROGRAM}")
# The launchers are stored as cache entries without FORCE, so a value that is
# already in the CMake cache (e.g. passed with -D) is not overwritten.
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "C compiler launcher")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher")
else()
# ccache was not found; fall back to sccache with the same three launcher variables.
find_program(SCCACHE_PROGRAM sccache)
if(SCCACHE_PROGRAM)
message(STATUS "Using sccache: ${SCCACHE_PROGRAM}")
set(CMAKE_C_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "C compiler launcher")
set(CMAKE_CXX_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher")
endif()
endif()
# Configs
......@@ -68,8 +82,6 @@ file(GLOB TILE_LANG_SRCS
src/target/utils.cc
src/target/codegen_cpp.cc
src/target/rt_mod_cpp.cc
# webgpu doesn't have system dependency
src/target/codegen_webgpu.cc
# intrin_rule doesn't have system dependency
src/target/intrin_rule*.cc
)
......@@ -181,18 +193,18 @@ install(TARGETS tilelang_cython_wrapper
# Let libtilelang search for tvm/tvm_runtime in the same directory
if(APPLE)
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path")
else()
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN")
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
elseif(UNIX)
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
endif()
install(TARGETS tvm tvm_runtime tilelang_module tilelang LIBRARY DESTINATION tilelang/lib)
# Copy tvm cython ext for wheels
# TODO: not necessary for editable builds
if(TVM_BUILD_FROM_SOURCE)
add_dependencies(tilelang tvm_cython)
install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/python/tvm/ffi/core.abi3.so" DESTINATION tilelang/3rdparty/tvm/python/tvm/ffi/)
endif()
install(
TARGETS tvm tvm_runtime tilelang_module tilelang
LIBRARY DESTINATION tilelang/lib
)
......@@ -11,8 +11,17 @@ endif()
set(TVM_INCLUDES
${TVM_SOURCE}/include
${TVM_SOURCE}/ffi/include
${TVM_SOURCE}/src
${TVM_SOURCE}/3rdparty/dlpack/include
${TVM_SOURCE}/3rdparty/dmlc-core/include
)
# Add the tvm-ffi headers to TVM_INCLUDES. Two locations are probed in order:
# ${TVM_SOURCE}/ffi/include first, then ${TVM_SOURCE}/3rdparty/tvm-ffi/include.
# Nothing is appended when neither directory exists.
# NOTE(review): presumably the location differs between TVM versions — confirm.
if(EXISTS ${TVM_SOURCE}/ffi/include)
list(APPEND TVM_INCLUDES ${TVM_SOURCE}/ffi/include)
elseif(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/include)
list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/include)
endif()
# Additionally add the dlpack headers nested under tvm-ffi's own 3rdparty
# directory, but only when that directory is present.
if(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include)
list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include)
endif()
......@@ -4,20 +4,23 @@ TileLang is a domain-specific language designed to simplify the process of writi
## Table of Contents
1. [Getting Started](#getting-started)
2. [Simple GEMM Example](#simple-gemm-example)
- [Code Walkthrough](#code-walkthrough)
- [Compiling and Profiling](#compiling-and-profiling)
3. [Advanced GEMM Features](#advanced-gemm-features)
- [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling)
- [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining)
- [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality)
4. [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
5. [Verifying Correctness](#verifying-correctness)
6. [Fine-grained MMA Computations](#fine-grained-mma-computations)
- [Example Workflow](#example-workflow)
- [Summary](#summary)
7. [References](#references)
- [Table of Contents](#table-of-contents)
- [Getting Started](#getting-started)
- [Prerequisites](#prerequisites)
- [Installation](#installation)
- [Simple GEMM Example](#simple-gemm-example)
- [Code Walkthrough](#code-walkthrough)
- [Compiling and Profiling](#compiling-and-profiling)
- [Advanced GEMM Features](#advanced-gemm-features)
- [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling)
- [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining)
- [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality)
- [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
- [Verifying Correctness](#verifying-correctness)
- [Fine-grained MMA Computations](#fine-grained-mma-computations)
- [Example Workflow](#example-workflow)
- [Summary](#summary)
- [References](#references)
---
......@@ -25,10 +28,10 @@ TileLang is a domain-specific language designed to simplify the process of writi
### Prerequisites
- **Python 3.8+**
- **NVIDIA GPU** with a recent CUDA toolkit installed
- **Python 3.8+**
- **NVIDIA GPU** with a recent CUDA toolkit installed
- **PyTorch** (optional, for easy correctness verification)
- **tilelang**
- **tilelang**
- **bitblas** (optional; used for swizzle layout utilities in the advanced examples)
### Installation
......@@ -87,26 +90,26 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
### Code Walkthrough
1. **Define the Kernel Launch Configuration:**
1. **Define the Kernel Launch Configuration:**
```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
```
This creates a grid of blocks (ceildiv(N, block_N) in x-dimension, ceildiv(M, block_M) in y-dimension), each with 128 threads.
2. **Shared Memory Allocation:**
2. **Shared Memory Allocation:**
```python
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
```
Tiles of \(A\) and \(B\) are loaded into these shared memory buffers for faster access.
3. **Local Fragment Accumulation:**
3. **Local Fragment Accumulation:**
```python
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
```
Partial results are stored in registers (or local memory) to reduce writes to global memory.
4. **Pipelined Loading and GEMM:**
4. **Pipelined Loading and GEMM:**
```python
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(...)
......@@ -114,7 +117,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
```
Loads blocks of \(A\) and \(B\) in a pipelined fashion (up to 3 stages). This exploits overlap of data transfer and computation.
5. **Copy Out the Results:**
5. **Copy Out the Results:**
```python
T.copy(C_local, C[by * block_M, bx * block_N])
```
......@@ -216,10 +219,10 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
return main
```
**Key Differences vs. Basic Example**
1. **`T.annotate_layout(...)`**: Annotates how data should be organized in shared memory (swizzling).
2. **`T.use_swizzle(...)`**: Enables swizzle-based rasterization.
3. **Parallel Copy Loop** with `T.Parallel(...)`: Distributes global-to-shared copy across all threads, potentially vectorizing load/store instructions.
**Key Differences vs. Basic Example**
1. **`T.annotate_layout(...)`**: Annotates how data should be organized in shared memory (swizzling).
2. **`T.use_swizzle(...)`**: Enables swizzle-based rasterization.
3. **Parallel Copy Loop** with `T.Parallel(...)`: Distributes global-to-shared copy across all threads, potentially vectorizing load/store instructions.
---
......@@ -247,7 +250,7 @@ print("Results match!")
## Fine-grained MMA Computations
For advanced users who require full control over warp-level matrix multiplication operations, TileLang allows you to specify fine-grained MMA (Matrix Multiply-Accumulate) computations in a manner similar to writing raw CUDA. While higher-level abstractions like `T.gemm(...)` or automatic MMA emitters are sufficient for many use cases, specialized workloads (for example, dequantize gemm may require fine-grained layout transformation on shared to register stage) may benefit from explicitly controlling each MMA instruction, the data layout, and the synchronization points.
For advanced users who require full control over warp-level matrix multiplication operations, TileLang allows you to specify fine-grained MMA (Matrix Multiply-Accumulate) computations in a manner similar to writing raw CUDA. While higher-level abstractions like `T.gemm(...)` or automatic MMA emitters are sufficient for many use cases, specialized workloads (for example, dequantize gemm may require fine-grained layout transformation on shared to register stage) may benefit from explicitly controlling each MMA instruction, the data layout, and the synchronization points.
### Example Workflow
......@@ -394,10 +397,10 @@ def tl_matmul(
]
```
1. **Set Up Tile Sizes and Thread Bindings**
1. **Set Up Tile Sizes and Thread Bindings**
Just like in CUDA, you will typically start by defining how many warps or threads per block you want and how your matrix is subdivided. In TileLang, this is done via `T.Kernel(...)` and `T.thread_binding(...)`, which ensure that the correct number of threads is active and that each thread is bound to a specific role (e.g., warp ID or lane ID).
2. **Allocate Warp-local Fragments**
2. **Allocate Warp-local Fragments**
Instead of using a single shared buffer for partial sums, you allocate local buffers (register fragments) to hold sub-blocks of matrices \(A\) and \(B\). In TileLang, this is done with something like:
```python
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
......@@ -406,7 +409,7 @@ def tl_matmul(
```
Each of these `local` allocations represents a region of per-thread storage, which collectively forms the warp’s register tiles.
3. **Load Data via `ldmatrix`**
3. **Load Data via `ldmatrix`**
Fine-grained loading instructions allow you to specify exactly how data moves from shared memory to the warp-level fragments. In the example below, `mma_emitter.ldmatrix_a()` and `.ldmatrix_b()` are higher-level wrappers around warp-synchronous intrinsics. You can write your own load logic as well:
```python
for ki in T.serial(0, (block_K // micro_size_k)):
......@@ -418,7 +421,7 @@ def tl_matmul(
```
Internally, these calls orchestrate how each thread in the warp issues the correct load instructions, performs address calculations, and stores the data into registers.
4. **Perform the MMA Instruction**
4. **Perform the MMA Instruction**
After loading sub-tiles (fragments), the warp executes the `mma` instruction. This operation is essentially:
\[
C_{\text{local}} \;+=\; A_{\text{local}} \;\times\; B_{\text{local}}
......@@ -429,7 +432,7 @@ def tl_matmul(
```
Under the hood, this translates into Tensor Core instructions (e.g., `wmma.mma.sync` in PTX), which process multiple data elements per warp in parallel.
5. **Store Results via `stmatrix`**
5. **Store Results via `stmatrix`**
Finally, you write the results from the warp-level fragments back to shared memory or global memory. This step might happen multiple times in a loop or just once at the end. The code snippet:
```python
mma_emitter.stmatrix(C_local, C_shared)
......@@ -444,6 +447,6 @@ By combining warp-synchronous intrinsics (`ldmatrix`, `mma`, `stmatrix`) with ma
## References
- [NVIDIA CUTLASS Library](https://github.com/NVIDIA/cutlass): A collection of high-performance CUDA C++ template abstractions for GEMM.
- [NVIDIA CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html): Official documentation for CUDA.
- [NVIDIA CUTLASS Library](https://github.com/NVIDIA/cutlass): A collection of high-performance CUDA C++ template abstractions for GEMM.
- [NVIDIA CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html): Official documentation for CUDA.
- [PyTorch Documentation](https://pytorch.org/docs): For verifying correctness via CPU or GPU-based matmul.
......@@ -80,6 +80,9 @@ elif [[ "${#FILES[@]}" -gt 0 ]]; then
echo "Checking specified files: ${FILES[*]}..." >&2
fi
# Some systems set pip's default to --user, which breaks isolated virtualenvs.
export PIP_USER=0
# If pre-commit is not installed, install it.
if ! python3 -m pre_commit --version &>/dev/null; then
python3 -m pip install pre-commit
......
......@@ -8,21 +8,27 @@ maintainers = [{ name = "Lei Wang", email = "leiwang1999@outlook.com" }]
license = "MIT"
keywords = ["BLAS", "CUDA", "HIP", "Code Generation", "TVM"]
classifiers = [
"Development Status :: 4 - Beta",
"Environment :: GPU",
"Operating System :: POSIX :: Linux",
"Operating System :: OS Independent",
"Operating System :: MacOS",
"Programming Language :: C++",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: Implementation :: CPython",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dynamic = ["version"]
dependencies = [
"apache-tvm-ffi~=0.1.0",
"cloudpickle",
"ml-dtypes",
"numpy>=1.23.5",
......@@ -39,11 +45,7 @@ dependencies = [
fp4 = ["ml-dtypes>=0.5.1"]
[build-system]
requires = [
"cython>=3.0.0",
"scikit-build-core",
"setuptools>=63",
]
requires = ["cython>=3.0.0", "scikit-build-core"]
build-backend = "scikit_build_core.build"
[tool.scikit-build]
......@@ -180,27 +182,37 @@ build-frontend = "build"
environment = { PYTHONDEVMODE = "1", PYTHONUNBUFFERED = "1" }
environment-pass = [
"CUDA_VERSION",
"NO_VERSION_LABEL",
"NO_TOOLCHAIN_VERSION",
"NO_GIT_VERSION",
"COLUMNS",
"CMAKE_GENERATOR",
"CMAKE_BUILD_PARALLEL_LEVEL",
"FORCE_COLOR",
"CLICOLOR_FORCE",
]
before-build = "env -0 | sort -z | tr '\\0' '\\n'"
windows.before-build = "set"
# Pin to glibc 2.17 for x86 and 2.28 for aarch64 for now
manylinux-x86_64-image = "manylinux2014"
manylinux-aarch64-image = "manylinux_2_28"
test-command = [
"python -c 'import tilelang; print(tilelang.__version__)'",
]
[tool.cibuildwheel.linux]
environment = { PYTHONDEVMODE = "1", PYTHONUNBUFFERED = "1", PATH = "/usr/local/cuda/bin:$PATH" }
repair-wheel-command = [
"auditwheel repair --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}",
"pipx run abi3audit --strict --report {wheel}",
]
environment.PYTHONDEVMODE = "1"
environment.PYTHONUNBUFFERED = "1"
environment.PATH = "/usr/local/cuda/bin:$PATH"
environment.LD_LIBRARY_PATH = "/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
# Pin to glibc 2.17 for x86 and 2.28 for aarch64 for now
manylinux-x86_64-image = "manylinux2014" # CentOS 7
manylinux-aarch64-image = "manylinux_2_28" # AlmaLinux 8
# Install CUDA runtime and stub driver library
# manylinux_2_28 uses gcc 14, which needs CUDA 12.8
before-all = """
set -eux
cat /etc/*-release
uname -a
case "$(uname -m)" in
"x86_64")
yum-config-manager --add-repo https://developer.download.nvidia.cn/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo
......@@ -215,5 +227,22 @@ esac
cudaver="$(echo "${CUDA_VERSION:-"12.4"}" | cut -d '.' -f-2)"
v="${cudaver//./-}"
yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-devel-${v}"
yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-devel-${v}" nvidia-driver-cuda-libs
"""
repair-wheel-command = [
"auditwheel -v repair --exclude libtvm_ffi.so --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}",
"pipx run abi3audit --verbose --strict {wheel}",
]
[tool.cibuildwheel.macos]
repair-wheel-command = [
"delocate-wheel --verbose --ignore-missing-dependencies --no-sanitize-rpaths --require-archs {delocate_archs} -w {dest_dir} -v {wheel}",
"pipx run abi3audit --verbose --strict {wheel}",
]
[[tool.cibuildwheel.overrides]]
select = "*linux*x86_64*"
# CentOS 7 is too old to run the import test, so only verify that the wheel installs.
test-command = [
"echo 'Wheel is installed successfully'",
]
......@@ -18,10 +18,11 @@ cython
docutils
dtlib
einops
flash-linear-attention==0.3.2
packaging>=21.0
pytest-xdist>=2.2.1
pytest-durations
pytest-timeout
pytest-xdist>=2.2.1
pytest>=6.2.4
pyyaml
requests
......
# Runtime requirements
apache-tvm-ffi~=0.1.0
cloudpickle
ml-dtypes
numpy>=1.23.5
......@@ -7,4 +8,3 @@ torch
torch>=2.7; platform_system == 'Darwin'
tqdm>=4.62.3
typing-extensions>=4.10.0
flash-linear-attention==0.3.2
\ No newline at end of file
......@@ -7,6 +7,9 @@
#include "./transform/common/attr.h"
#include "op/builtin.h"
#include "tvm/ffi/any.h"
#include <tvm/ffi/object.h>
#include "support/ffi_aliases.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/script/ir_builder/tir/ir.h>
......@@ -37,7 +40,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
using namespace tvm::tir;
Var var = Var(name, dom->dtype);
// Create a frame that represents a loop over the given domain.
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.push_back(var);
n->doms.push_back(Range(0, dom));
n->f_make_for_loop = [](const Array<Var> &vars, const Array<Range> &doms,
......@@ -52,7 +55,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
ForFrame ParallelFor(const Array<PrimExpr> &extents,
const Map<String, ObjectRef> &annotations) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.reserve(extents.size());
n->doms.reserve(extents.size());
for (const auto &extent : extents) {
......@@ -82,7 +85,7 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
const Array<Array<PrimExpr>> &sync,
const Array<Array<PrimExpr>> &groups) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
DataType dtype = stop.dtype();
n->vars.push_back(Var("v", dtype));
n->doms.push_back(Range(std::move(start), stop));
......@@ -113,7 +116,7 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
const PrimExpr &index, PrimExpr group_size) {
using namespace tvm::tir;
ICHECK(!domain.empty());
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.reserve(domain.size());
n->doms.reserve(domain.size());
PrimExpr domain_size = domain[0];
......@@ -193,8 +196,8 @@ public:
"frames", &KernelLaunchFrameNode::frames);
}
static constexpr const char *_type_key = "tl.KernelLaunchFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(KernelLaunchFrameNode, TIRFrameNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.KernelLaunchFrame",
KernelLaunchFrameNode, TIRFrameNode);
public:
TVM_DLL void EnterWithScope() final {
......@@ -218,14 +221,20 @@ public:
*/
class KernelLaunchFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(KernelLaunchFrame, TIRFrame,
KernelLaunchFrameNode);
explicit KernelLaunchFrame(ObjectPtr<KernelLaunchFrameNode> data)
: TIRFrame(::tvm::ffi::UnsafeInit{}) {
ICHECK(data != nullptr);
data_ = std::move(data);
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(KernelLaunchFrame, TIRFrame,
KernelLaunchFrameNode);
};
KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
const Optional<Array<PrimExpr>> &block_size_opt,
const Map<String, ffi::Any> &attrs) {
ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>();
ObjectPtr<KernelLaunchFrameNode> n =
tvm::ffi::make_object<KernelLaunchFrameNode>();
// If the kernel is a CPU kernel, we don't need to launch any threads.
bool is_cpu_kernel_frame =
......@@ -289,16 +298,14 @@ KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
return KernelLaunchFrame(n);
}
TVM_REGISTER_NODE_TYPE(KernelLaunchFrameNode);
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("tl.Parallel", ParallelFor)
.def("tl.Pipelined", PipelinedFor)
.def("tl.Persistent", PersistentFor)
.def("tl.KernelLaunch", KernelLaunch);
});
}
class WarpSpecializeFrameNode : public TIRFrameNode {
public:
......@@ -310,8 +317,8 @@ public:
"frames", &WarpSpecializeFrameNode::frames);
}
static constexpr const char *_type_key = "tl.WarpSpecializeFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(WarpSpecializeFrameNode, TIRFrameNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.WarpSpecializeFrame",
WarpSpecializeFrameNode, TIRFrameNode);
public:
TVM_DLL void EnterWithScope() final {
......@@ -330,15 +337,20 @@ public:
class WarpSpecializeFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WarpSpecializeFrame,
TIRFrame,
WarpSpecializeFrameNode);
explicit WarpSpecializeFrame(ObjectPtr<WarpSpecializeFrameNode> data)
: TIRFrame(::tvm::ffi::UnsafeInit{}) {
ICHECK(data != nullptr);
data_ = std::move(data);
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(WarpSpecializeFrame, TIRFrame,
WarpSpecializeFrameNode);
};
WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
const PrimExpr &thread_idx,
int warp_group_size = 128) {
ObjectPtr<WarpSpecializeFrameNode> n = make_object<WarpSpecializeFrameNode>();
ObjectPtr<WarpSpecializeFrameNode> n =
tvm::ffi::make_object<WarpSpecializeFrameNode>();
PrimExpr condition;
std::vector<int> warp_groups;
warp_groups.reserve(warp_group_ids.size());
......@@ -376,13 +388,12 @@ WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
return WarpSpecializeFrame(n);
}
TVM_REGISTER_NODE_TYPE(WarpSpecializeFrameNode);
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.WarpSpecialize", WarpSpecialize);
KernelLaunchFrameNode::RegisterReflection();
WarpSpecializeFrameNode::RegisterReflection();
});
}
} // namespace tl
} // namespace tvm
......@@ -64,13 +64,12 @@ Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
}
forward_index =
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
auto n = make_object<LayoutNode>(input_size, forward_index);
auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
data_ = std::move(n);
}
Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
auto n = make_object<LayoutNode>(input_size, forward_index);
auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
data_ = std::move(n);
}
......@@ -130,7 +129,6 @@ Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr> &vars) const {
Array<PrimExpr> transformed = forward_index_.Map(
[&](const PrimExpr &e) { return Substitute(e, vmap); });
// Concatenate with the remaining elements from vars
Array<PrimExpr> result;
for (size_t i = 0; i < vars.size() - InputDim(); i++) {
......@@ -212,7 +210,7 @@ Fragment FragmentNode::DeReplicate() const {
factor = arith::ZeroAwareGCD(*rep_size, *idx_size);
}
if (factor == 1)
return GetRef<Fragment>(this);
return tvm::ffi::GetRef<Fragment>(this);
Map<Var, PrimExpr> vmap;
vmap.Set(ReplicationPlaceholder(), ReplicationPlaceholder() * factor +
......@@ -224,7 +222,7 @@ Fragment FragmentNode::DeReplicate() const {
}
Fragment FragmentNode::BindThreadRange(Range thread_range) const {
auto n = make_object<FragmentNode>(*this);
auto n = tvm::ffi::make_object<FragmentNode>(*this);
n->thread_range_ = thread_range;
return Fragment(n);
}
......@@ -336,8 +334,8 @@ Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
forward_thread = Substitute(forward_thread, vmap);
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
replicate_size);
auto n = tvm::ffi::make_object<FragmentNode>(input_size, forward_index,
forward_thread, replicate_size);
data_ = std::move(n);
}
......@@ -348,8 +346,8 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
forward_thread = Substitute(
forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}});
}
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
replicate_size);
auto n = tvm::ffi::make_object<FragmentNode>(input_size, forward_index,
forward_thread, replicate_size);
data_ = std::move(n);
}
......@@ -442,21 +440,6 @@ std::string FragmentNode::DebugOutput() const {
return ss.str();
}
bool LayoutNode::SEqualReduce(const LayoutNode *other,
SEqualReducer equal) const {
return equal(this->InputShape(), other->InputShape()) &&
equal(this->forward_index_, other->forward_index_);
}
bool FragmentNode::SEqualReduce(const FragmentNode *other,
SEqualReducer equal) const {
return equal(this->ReplicateExtent(), other->ReplicateExtent()) &&
equal(this->InputShape(), other->InputShape()) &&
equal(this->ThreadExtent(), other->ThreadExtent()) &&
equal(this->forward_index_, other->forward_index_) &&
equal(this->forward_thread_, other->forward_thread_);
}
bool LayoutNode::IsEqual(const LayoutNode *other, bool skip_index) const {
bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
......@@ -495,10 +478,7 @@ void FragmentNode::RegisterReflection() {
.def_ro("replicate_size", &FragmentNode::replicate_size_);
}
TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(FragmentNode);
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def_packed("tl.Layout",
......@@ -582,13 +562,13 @@ TVM_FFI_STATIC_INIT_BLOCK({
.def("tl.make_linear_layout", [](int stride, int continuous) {
return makeGemmLayoutLinear(stride, continuous);
});
});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
LayoutNode::RegisterReflection();
FragmentNode::RegisterReflection();
});
}
} // namespace tl
} // namespace tvm
......@@ -8,8 +8,11 @@
#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/object.h>
#include <utility>
#include "../support/ffi_aliases.h"
namespace tvm {
namespace tl {
......@@ -44,11 +47,10 @@ public:
virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const;
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr const char *_type_key = "tl.Layout";
bool SEqualReduce(const LayoutNode *other, SEqualReducer equal) const;
static void RegisterReflection();
TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object);
TVM_FFI_DECLARE_OBJECT_INFO("tl.Layout", LayoutNode, Object);
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
protected:
virtual Map<Var, Range> getVarMap() const;
......@@ -65,7 +67,7 @@ public:
TVM_DLL Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index);
TVM_DLL Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);
TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Layout, ObjectRef, LayoutNode);
};
class FragmentNode : public LayoutNode {
......@@ -109,9 +111,9 @@ public:
static void RegisterReflection();
bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const;
static constexpr const char *_type_key = "tl.Fragment";
TVM_DECLARE_FINAL_OBJECT_INFO(FragmentNode, LayoutNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode);
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
protected:
Map<Var, Range> getVarMap() const final;
......@@ -132,7 +134,7 @@ public:
PrimExpr forward_thread, PrimExpr replicate_size,
Optional<Var> replicate_var);
TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fragment, Layout, FragmentNode);
};
Var InputPlaceholder(size_t idx);
......
......@@ -6,6 +6,7 @@
#include "swizzle.h"
#include <tvm/node/node.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
......@@ -86,14 +87,16 @@ SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var,
forward_index =
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
auto n = tvm::ffi::make_object<SwizzledLayoutNode>(input_size, forward_index,
pattern);
data_ = std::move(n);
}
/**
 * Builds a SwizzledLayout from an input shape, forward index expressions and
 * a swizzle pattern: a SwizzledLayoutNode is created from the three arguments
 * and moved into data_.
 *
 * NOTE(review): two `auto n = ...` declarations appear back to back below
 * (unqualified make_object and tvm::ffi::make_object) — presumably the
 * before/after lines of a diff; confirm and keep only one.
 *
 * @param input_size Input shape forwarded to the node constructor.
 * @param forward_index Forward index expressions forwarded to the node.
 * @param pattern Swizzle pattern forwarded to the node.
 */
SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index,
SwizzlePattern pattern) {
auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
auto n = tvm::ffi::make_object<SwizzledLayoutNode>(input_size, forward_index,
pattern);
data_ = std::move(n);
}
......@@ -102,14 +105,5 @@ void SwizzledLayoutNode::RegisterReflection() {
refl::ObjectDef<SwizzledLayoutNode>();
}
/**
 * Structural-equality hook for SwizzledLayoutNode.
 *
 * Input shapes and forward index expressions go through the reducer; the
 * swizzle pattern is compared directly with operator==. The first mismatch
 * ends the comparison.
 *
 * @param other Swizzled layout node to compare this node against.
 * @param equal Reducer invoked on the shape and index fields.
 * @return true only if shapes, indices and patterns all match.
 */
bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode *other,
                                      SEqualReducer equal) const {
  if (!equal(this->InputShape(), other->InputShape())) {
    return false;
  }
  if (!equal(this->forward_index_, other->forward_index_)) {
    return false;
  }
  // The pattern is not routed through the reducer; it is compared last.
  if (!(pattern_ == other->pattern_)) {
    return false;
  }
  return true;
}
TVM_REGISTER_NODE_TYPE(SwizzledLayoutNode);
} // namespace tl
} // namespace tvm
\ No newline at end of file
} // namespace tvm
......@@ -44,10 +44,9 @@ public:
Layout Inverse() const final;
std::string DebugOutput() const final;
bool IsEqual(const SwizzledLayoutNode *other, bool skip_index = false) const;
static constexpr const char *_type_key = "tl.SwizzledLayout";
bool SEqualReduce(const SwizzledLayoutNode *other, SEqualReducer equal) const;
static void RegisterReflection();
TVM_DECLARE_FINAL_OBJECT_INFO(SwizzledLayoutNode, LayoutNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.SwizzledLayout", SwizzledLayoutNode,
LayoutNode);
private:
SwizzlePattern pattern_;
......@@ -62,11 +61,11 @@ public:
Array<PrimExpr> forward_index, SwizzlePattern pattern);
TVM_DLL SwizzledLayout(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index, SwizzlePattern pattern);
TVM_DEFINE_OBJECT_REF_METHODS(SwizzledLayout, Layout, SwizzledLayoutNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SwizzledLayout, Layout,
SwizzledLayoutNode);
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_LAYOUT_SWIZZLE_H_
\ No newline at end of file
#endif // TVM_TL_LAYOUT_SWIZZLE_H_
......@@ -189,7 +189,7 @@ public:
IterMark Mutate(const IterMark &mark) {
if (auto *op = mark->source.as<IterSumExprNode>()) {
return IterMark(Mutate(GetRef<IterSumExpr>(op)), mark->extent);
return IterMark(Mutate(tvm::ffi::GetRef<IterSumExpr>(op)), mark->extent);
} else {
return mark;
}
......
......@@ -9,6 +9,8 @@
#include <tvm/arith/iter_affine_map.h>
#include "../support/ffi_aliases.h"
namespace tvm {
namespace tl {
......
......@@ -42,7 +42,7 @@ using namespace tir;
* - The constructed node is stored in this->data_.
*/
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
ObjectPtr<AtomicAddNode> node = tvm::ffi::make_object<AtomicAddNode>();
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
......@@ -78,7 +78,7 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator owning the cloned AtomicAddNode.
*/
TileOperator AtomicAddNode::Clone() const {
auto op = make_object<AtomicAddNode>(*this);
auto op = tvm::ffi::make_object<AtomicAddNode>(*this);
if (par_op_.defined()) {
op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
}
......@@ -549,7 +549,7 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
// Init block that calls AtomicAddNode::RegisterReflection().
// NOTE(review): the same registration appears twice, once per macro spelling —
// presumably the before/after lines of a diff; confirm and keep only the form
// the pinned tvm-ffi expects.
TVM_FFI_STATIC_INIT_BLOCK({ AtomicAddNode::RegisterReflection(); });
TVM_FFI_STATIC_INIT_BLOCK() { AtomicAddNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment