Unverified Commit 10911e28 authored by Lei Wang, committed by GitHub

[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)



* 3rdparty tvm bump

* bump tvm into v0.22.0

* lint fix

* rebase tvm

* Update submodule tvm to latest commit 3085bc4

* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang

* test fix

* add requirement

* atomic_fix

* atomic_fix

* phaseout py39

* optimize

* optimize

* lint fix

* do not clean cache

* do not clean cache

* [Minor] Minor update for Python versions and dependencies

* [Lint] fix lint for py39

* [Lint] fix lint for ROCm

* [Build][CI] Sync CI changes from upstream/sdist

* [Lint] fix lint for ROCm

* [Build][CI] Update `repair-wheel-command`

* [Minor] update abi3audit result format

* [Lint] fix lint for ROCm

* [BugFix] fix build

* [Lint] fix lint for ROCm

* [BugFix] set rpath for libtvm and libtvm_runtime

* [Deps] pin apache-tvm-ffi version

* [Build] set Python 3.9 Limited API for Cython target

* [Build] set Python 3.9 Limited API for Cython target

* [Deps] Restore Python 3.8 support

* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`

* [BugFix] use `;` as delimiter for RPATH on macOS

* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`

* [Build] support `sccache` if available

* [Build] add CIBW import test

* [Build][CI] enable ccache for CIBW on Linux

* [BugFix] set rpath for libtvm and libtvm_runtime

* Revert "[Build][CI] enable ccache for CIBW on Linux"

This reverts commit cd9ab57bb5ddd2572c60bcbbebde81480a658fd3.

* [CI] fix perfbench bot

* [BugFix] use Python 3.9 to build wheel

* [Minor] update perfbench bot envs

* [BugFix] fix CIBW environment on Linux

* [CI] skip import test on CentOS 7

* [CI] use Python urllib to download file instead of Wget

---------
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
parent c37621c5
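
Most of the C++ churn in the diff below follows one mechanical pattern from the tvm-ffi migration: `_type_key` plus `TVM_DECLARE_*_OBJECT_INFO` collapse into `TVM_FFI_DECLARE_OBJECT_INFO[_FINAL]`, `TVM_REGISTER_NODE_TYPE` disappears in favor of reflection registration, `make_object`/`GetRef` gain explicit `tvm::ffi::` qualification, and the `TVM_FFI_STATIC_INIT_BLOCK({...});` macro-argument form becomes the function-style `TVM_FFI_STATIC_INIT_BLOCK() {...}`. A minimal sketch of the new shape — the macros are the ones used in this diff, but `DemoNode` and the registered name are illustrative only, not types from this repository:

```cpp
#include <tvm/ffi/object.h>
#include <tvm/ffi/reflection/registry.h>

namespace tvm {
namespace tl {

// Hypothetical node type, for illustration only.
class DemoNode : public tvm::ffi::Object {
public:
  int value{0};

  // Replaces: static constexpr const char *_type_key = "tl.Demo";
  //           TVM_DECLARE_FINAL_OBJECT_INFO(DemoNode, Object);
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Demo", DemoNode, tvm::ffi::Object);
};

// make_object is now spelled with its full tvm::ffi:: qualification.
int DemoValue() {
  tvm::ffi::ObjectPtr<DemoNode> n = tvm::ffi::make_object<DemoNode>();
  return n->value;
}

// Replaces TVM_REGISTER_NODE_TYPE(DemoNode) and the old
// TVM_FFI_STATIC_INIT_BLOCK({ ... }); form.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.DemoValue", DemoValue);
}

} // namespace tl
} // namespace tvm
```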
 ---
 InheritParentConfig: true
-ExtraArgs: ['-v']
+ExtraArgs: []
 FormatStyle: file
 UseColor: true
 WarningsAsErrors: '*'
......
@@ -22,10 +22,12 @@ env:
   PYTHONDEVMODE: "1"
   PYTHONUNBUFFERED: "1"
   PYTHONPATH: "" # explicit cleanup
+  PIP_USER: "" # explicit cleanup
   COLUMNS: "100"
   FORCE_COLOR: "1"
   CLICOLOR_FORCE: "1"
   UV_INDEX_STRATEGY: "unsafe-best-match"
+  UV_HTTP_TIMEOUT: "600"
   XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated
   PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated
   UV_CACHE_DIR: "${{ github.workspace }}/.cache/uv" # to be updated
@@ -44,7 +46,7 @@ jobs:
           submodules: recursive
       - name: Setup Python 3.8
-        id: setup-py38
+        id: setup-pylowest
        uses: actions/setup-python@v6
        with:
          python-version: "3.8" # use lowest supported version for linting
@@ -52,7 +54,7 @@ jobs:
      - name: Check AST with Python 3.8
        run: |
-          "${{ steps.setup-py38.outputs.python-path }}" -m compileall -q -f tilelang
+          "${{ steps.setup-pylowest.outputs.python-path }}" -m compileall -q -f tilelang
      - name: Setup Python 3.12
        uses: actions/setup-python@v6
......
@@ -108,14 +108,11 @@ jobs:
          - { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8" }
          - { runner: macos-latest, toolkit: "Metal" }
        python-version:
-          - "3.8" # Wheels are built with Python 3.8 Limited API, they should work with all Python >= 3.8.
-          # TVM is built with Python 3.8 Limited API, it should work with all Python >= 3.8.
-          # - "3.9"
-          # - "3.10"
-          # - "3.11"
-          # - "3.12"
-          # - "3.13"
-          # - "3.14"
+          # Only build wheels against Python 3.8 Limited API to save CI resources.
+          # FIXME: Here we use Python 3.9 because our dependency `apache-tvm-ffi` claims to support
+          # Python 3.8 but it depends on a version of `ml-dtypes` that requires Python >= 3.9.
+          - "3.9"
      fail-fast: false
      timeout-minutes: 120
      runs-on: ${{ matrix.target.runner }}
......
@@ -12,6 +12,17 @@ concurrency:
  group: "${{ github.workflow }}-${{ github.ref }}"
  cancel-in-progress: true # always cancel in-progress

+env:
+  PYTHONDEVMODE: "1"
+  PYTHONUNBUFFERED: "1"
+  PYTHONPATH: "" # explicit cleanup
+  PIP_USER: "" # explicit cleanup
+  COLUMNS: "100"
+  FORCE_COLOR: "1"
+  CLICOLOR_FORCE: "1"
+  XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated
+  PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated
+
 jobs:
  perfbench:
    name: Benchmark between PR and main
@@ -31,7 +42,12 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v6
        with:
-          python-version: "3.9"
+          python-version: "3.12"
+          update-environment: true
+          cache: pip
+          cache-dependency-path: |
+            pyproject.toml
+            requirements*.txt
      - name: Install merged version
        run: |
......
-Subproject commit 5bf17a34602931e7d7e01cbccf358a21fe972779
+Subproject commit 0f1ebab7b66732f34b652ce807c9ff0748cd473c
@@ -8,6 +8,11 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

+if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND "$ENV{CIBUILDWHEEL}")
+  # Warning came from tvm submodule
+  string(APPEND CMAKE_CXX_FLAGS " -Wno-dangling-reference")
+endif()
+
 set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
 if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.gitmodules" AND EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.git")
@@ -36,9 +41,18 @@ endif()
 find_program(CCACHE_PROGRAM ccache)
 if(CCACHE_PROGRAM)
+  message(STATUS "Using ccache: ${CCACHE_PROGRAM}")
   set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "C compiler launcher")
   set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher")
   set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher")
+else()
+  find_program(SCCACHE_PROGRAM sccache)
+  if(SCCACHE_PROGRAM)
+    message(STATUS "Using sccache: ${SCCACHE_PROGRAM}")
+    set(CMAKE_C_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "C compiler launcher")
+    set(CMAKE_CXX_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher")
+    set(CMAKE_CUDA_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher")
+  endif()
 endif()

 # Configs
@@ -68,8 +82,6 @@ file(GLOB TILE_LANG_SRCS
   src/target/utils.cc
   src/target/codegen_cpp.cc
   src/target/rt_mod_cpp.cc
-  # webgpu doesn't have system dependency
-  src/target/codegen_webgpu.cc
   # intrin_rule doesn't have system dependency
   src/target/intrin_rule*.cc
 )
@@ -181,18 +193,18 @@ install(TARGETS tilelang_cython_wrapper
 # let libtilelang to search tvm/tvm_runtime in same dir
 if(APPLE)
-  set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path")
-  set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path")
-else()
-  set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN")
-  set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN")
+  set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
+  set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
+  set_target_properties(tvm PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
+  set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
+elseif(UNIX)
+  set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
+  set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
+  set_target_properties(tvm PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
+  set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
 endif()

-install(TARGETS tvm tvm_runtime tilelang_module tilelang LIBRARY DESTINATION tilelang/lib)
-
-# Copy tvm cython ext for wheels
-# TODO: not necessary for editable builds
-if(TVM_BUILD_FROM_SOURCE)
-  add_dependencies(tilelang tvm_cython)
-  install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/python/tvm/ffi/core.abi3.so" DESTINATION tilelang/3rdparty/tvm/python/tvm/ffi/)
-endif()
+install(
+  TARGETS tvm tvm_runtime tilelang_module tilelang
+  LIBRARY DESTINATION tilelang/lib
+)
@@ -11,8 +11,17 @@ endif()
 set(TVM_INCLUDES
   ${TVM_SOURCE}/include
-  ${TVM_SOURCE}/ffi/include
   ${TVM_SOURCE}/src
   ${TVM_SOURCE}/3rdparty/dlpack/include
   ${TVM_SOURCE}/3rdparty/dmlc-core/include
 )
+if(EXISTS ${TVM_SOURCE}/ffi/include)
+  list(APPEND TVM_INCLUDES ${TVM_SOURCE}/ffi/include)
+elseif(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/include)
+  list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/include)
+endif()
+if(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include)
+  list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include)
+endif()
@@ -4,20 +4,23 @@ TileLang is a domain-specific language designed to simplify the process of writi
 ## Table of Contents

-1. [Getting Started](#getting-started)
-2. [Simple GEMM Example](#simple-gemm-example)
+- [Table of Contents](#table-of-contents)
+- [Getting Started](#getting-started)
+  - [Prerequisites](#prerequisites)
+  - [Installation](#installation)
+- [Simple GEMM Example](#simple-gemm-example)
   - [Code Walkthrough](#code-walkthrough)
   - [Compiling and Profiling](#compiling-and-profiling)
-3. [Advanced GEMM Features](#advanced-gemm-features)
+- [Advanced GEMM Features](#advanced-gemm-features)
   - [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling)
   - [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining)
   - [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality)
-4. [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
-5. [Verifying Correctness](#verifying-correctness)
-6. [Fine-grained MMA Computations](#fine-grained-mma-computations)
+- [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
+- [Verifying Correctness](#verifying-correctness)
+- [Fine-grained MMA Computations](#fine-grained-mma-computations)
   - [Example Workflow](#example-workflow)
   - [Summary](#summary)
-7. [References](#references)
+- [References](#references)

 ---
......
@@ -80,6 +80,9 @@ elif [[ "${#FILES[@]}" -gt 0 ]]; then
   echo "Checking specified files: ${FILES[*]}..." >&2
 fi

+# Some systems set pip's default to --user, which breaks isolated virtualenvs.
+export PIP_USER=0
+
 # If pre-commit is not installed, install it.
 if ! python3 -m pre_commit --version &>/dev/null; then
   python3 -m pip install pre-commit
......
@@ -8,21 +8,27 @@ maintainers = [{ name = "Lei Wang", email = "leiwang1999@outlook.com" }]
 license = "MIT"
 keywords = ["BLAS", "CUDA", "HIP", "Code Generation", "TVM"]
 classifiers = [
+    "Development Status :: 4 - Beta",
     "Environment :: GPU",
     "Operating System :: POSIX :: Linux",
+    "Operating System :: OS Independent",
     "Operating System :: MacOS",
+    "Programming Language :: C++",
+    "Programming Language :: Python :: 3",
     "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+    "Programming Language :: Python :: 3.14",
+    "Programming Language :: Python :: Implementation :: CPython",
     "Intended Audience :: Developers",
     "Intended Audience :: Science/Research",
     "Topic :: Scientific/Engineering :: Artificial Intelligence",
 ]
 dynamic = ["version"]
 dependencies = [
+    "apache-tvm-ffi~=0.1.0",
     "cloudpickle",
     "ml-dtypes",
     "numpy>=1.23.5",
@@ -39,11 +45,7 @@ dependencies = [
 fp4 = ["ml-dtypes>=0.5.1"]

 [build-system]
-requires = [
-    "cython>=3.0.0",
-    "scikit-build-core",
-    "setuptools>=63",
-]
+requires = ["cython>=3.0.0", "scikit-build-core"]
 build-backend = "scikit_build_core.build"

 [tool.scikit-build]
[tool.scikit-build] [tool.scikit-build]
@@ -180,27 +182,37 @@ build-frontend = "build"
 environment = { PYTHONDEVMODE = "1", PYTHONUNBUFFERED = "1" }
 environment-pass = [
     "CUDA_VERSION",
+    "NO_VERSION_LABEL",
+    "NO_TOOLCHAIN_VERSION",
+    "NO_GIT_VERSION",
     "COLUMNS",
+    "CMAKE_GENERATOR",
+    "CMAKE_BUILD_PARALLEL_LEVEL",
     "FORCE_COLOR",
     "CLICOLOR_FORCE",
 ]
 before-build = "env -0 | sort -z | tr '\\0' '\\n'"
 windows.before-build = "set"
-# Pin to glibc 2.17 for x86 and 2.28 for aarch64 for now
-manylinux-x86_64-image = "manylinux2014"
-manylinux-aarch64-image = "manylinux_2_28"
+test-command = [
+    "python -c 'import tilelang; print(tilelang.__version__)'",
+]

 [tool.cibuildwheel.linux]
-environment = { PYTHONDEVMODE = "1", PYTHONUNBUFFERED = "1", PATH = "/usr/local/cuda/bin:$PATH" }
-repair-wheel-command = [
-    "auditwheel repair --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}",
-    "pipx run abi3audit --strict --report {wheel}",
-]
+environment.PYTHONDEVMODE = "1"
+environment.PYTHONUNBUFFERED = "1"
+environment.PATH = "/usr/local/cuda/bin:$PATH"
+environment.LD_LIBRARY_PATH = "/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
+# Pin to glibc 2.17 for x86 and 2.28 for aarch64 for now
+manylinux-x86_64-image = "manylinux2014" # CentOS 7
+manylinux-aarch64-image = "manylinux_2_28" # AlmaLinux 8
 # Install CUDA runtime and stub driver library
 # manylinux_2_28 uses gcc 14, which needs CUDA 12.8
 before-all = """
 set -eux
+cat /etc/*-release
+uname -a
 case "$(uname -m)" in
     "x86_64")
         yum-config-manager --add-repo https://developer.download.nvidia.cn/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo
@@ -215,5 +227,22 @@ esac
 cudaver="$(echo "${CUDA_VERSION:-"12.4"}" | cut -d '.' -f-2)"
 v="${cudaver//./-}"
-yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-devel-${v}"
+yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-devel-${v}" nvidia-driver-cuda-libs
 """
+repair-wheel-command = [
+    "auditwheel -v repair --exclude libtvm_ffi.so --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}",
+    "pipx run abi3audit --verbose --strict {wheel}",
+]
+
+[tool.cibuildwheel.macos]
+repair-wheel-command = [
+    "delocate-wheel --verbose --ignore-missing-dependencies --no-sanitize-rpaths --require-archs {delocate_archs} -w {dest_dir} -v {wheel}",
+    "pipx run abi3audit --verbose --strict {wheel}",
+]
+
+[[tool.cibuildwheel.overrides]]
+select = "*linux*x86_64*"
+# CentOS 7 is too old to run import test. Do wheel installation test only.
+test-command = [
+    "echo 'Wheel is installed successfully'",
+]
@@ -18,10 +18,11 @@ cython
 docutils
 dtlib
 einops
+flash-linear-attention==0.3.2
 packaging>=21.0
-pytest-xdist>=2.2.1
 pytest-durations
 pytest-timeout
+pytest-xdist>=2.2.1
 pytest>=6.2.4
 pyyaml
 requests
......
 # Runtime requirements
+apache-tvm-ffi~=0.1.0
 cloudpickle
 ml-dtypes
 numpy>=1.23.5
@@ -7,4 +8,3 @@ torch
 torch>=2.7; platform_system == 'Darwin'
 tqdm>=4.62.3
 typing-extensions>=4.10.0
-flash-linear-attention==0.3.2
\ No newline at end of file
@@ -7,6 +7,9 @@
 #include "./transform/common/attr.h"
 #include "op/builtin.h"
 #include "tvm/ffi/any.h"
+#include <tvm/ffi/object.h>
+
+#include "support/ffi_aliases.h"
 #include <tvm/arith/analyzer.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/script/ir_builder/tir/ir.h>
@@ -37,7 +40,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
   using namespace tvm::tir;
   Var var = Var(name, dom->dtype);
   // Create a frame that represents a loop over the given domain.
-  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
+  ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
   n->vars.push_back(var);
   n->doms.push_back(Range(0, dom));
   n->f_make_for_loop = [](const Array<Var> &vars, const Array<Range> &doms,
@@ -52,7 +55,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
 ForFrame ParallelFor(const Array<PrimExpr> &extents,
                      const Map<String, ObjectRef> &annotations) {
   using namespace tvm::tir;
-  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
+  ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
   n->vars.reserve(extents.size());
   n->doms.reserve(extents.size());
   for (const auto &extent : extents) {
@@ -82,7 +85,7 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
                       const Array<Array<PrimExpr>> &sync,
                       const Array<Array<PrimExpr>> &groups) {
   using namespace tvm::tir;
-  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
+  ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
   DataType dtype = stop.dtype();
   n->vars.push_back(Var("v", dtype));
   n->doms.push_back(Range(std::move(start), stop));
@@ -113,7 +116,7 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
                        const PrimExpr &index, PrimExpr group_size) {
   using namespace tvm::tir;
   ICHECK(!domain.empty());
-  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
+  ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
   n->vars.reserve(domain.size());
   n->doms.reserve(domain.size());
   PrimExpr domain_size = domain[0];
@@ -193,8 +196,8 @@ public:
         "frames", &KernelLaunchFrameNode::frames);
   }

-  static constexpr const char *_type_key = "tl.KernelLaunchFrame";
-  TVM_DECLARE_FINAL_OBJECT_INFO(KernelLaunchFrameNode, TIRFrameNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.KernelLaunchFrame",
+                                    KernelLaunchFrameNode, TIRFrameNode);

 public:
   TVM_DLL void EnterWithScope() final {
@@ -218,14 +221,20 @@ public:
  */
 class KernelLaunchFrame : public TIRFrame {
 public:
-  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(KernelLaunchFrame, TIRFrame,
+  explicit KernelLaunchFrame(ObjectPtr<KernelLaunchFrameNode> data)
+      : TIRFrame(::tvm::ffi::UnsafeInit{}) {
+    ICHECK(data != nullptr);
+    data_ = std::move(data);
+  }
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(KernelLaunchFrame, TIRFrame,
                                                 KernelLaunchFrameNode);
 };

 KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
                                const Optional<Array<PrimExpr>> &block_size_opt,
                                const Map<String, ffi::Any> &attrs) {
-  ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>();
+  ObjectPtr<KernelLaunchFrameNode> n =
+      tvm::ffi::make_object<KernelLaunchFrameNode>();

   // If the kernel is a CPU kernel, we don't need to launch any threads.
   bool is_cpu_kernel_frame =
@@ -289,16 +298,14 @@ KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
   return KernelLaunchFrame(n);
 }

-TVM_REGISTER_NODE_TYPE(KernelLaunchFrameNode);
-
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
       .def("tl.Parallel", ParallelFor)
       .def("tl.Pipelined", PipelinedFor)
       .def("tl.Persistent", PersistentFor)
       .def("tl.KernelLaunch", KernelLaunch);
-});
+}

 class WarpSpecializeFrameNode : public TIRFrameNode {
 public:
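
For context on what the block above buys us: globals registered through `refl::GlobalDef()` become reachable by name through the tvm-ffi function registry. A hedged sketch of a lookup, assuming tvm-ffi's `Function::GetGlobal` API (this caller is not part of the diff):

```cpp
#include <tvm/ffi/function.h>

// Hypothetical caller; assumes "tl.Parallel" was registered by the
// TVM_FFI_STATIC_INIT_BLOCK above.
void CallRegisteredParallel() {
  auto fn = tvm::ffi::Function::GetGlobal("tl.Parallel");
  if (fn.has_value()) {
    // (*fn)(extents, annotations);  // invoke with the C++ signature's args
  }
}
```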
@@ -310,8 +317,8 @@ public:
         "frames", &WarpSpecializeFrameNode::frames);
   }

-  static constexpr const char *_type_key = "tl.WarpSpecializeFrame";
-  TVM_DECLARE_FINAL_OBJECT_INFO(WarpSpecializeFrameNode, TIRFrameNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.WarpSpecializeFrame",
+                                    WarpSpecializeFrameNode, TIRFrameNode);

 public:
   TVM_DLL void EnterWithScope() final {
@@ -330,15 +337,20 @@ public:
 class WarpSpecializeFrame : public TIRFrame {
 public:
-  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WarpSpecializeFrame,
-                                                    TIRFrame,
+  explicit WarpSpecializeFrame(ObjectPtr<WarpSpecializeFrameNode> data)
+      : TIRFrame(::tvm::ffi::UnsafeInit{}) {
+    ICHECK(data != nullptr);
+    data_ = std::move(data);
+  }
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(WarpSpecializeFrame, TIRFrame,
                                                 WarpSpecializeFrameNode);
 };

 WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
                                    const PrimExpr &thread_idx,
                                    int warp_group_size = 128) {
-  ObjectPtr<WarpSpecializeFrameNode> n = make_object<WarpSpecializeFrameNode>();
+  ObjectPtr<WarpSpecializeFrameNode> n =
+      tvm::ffi::make_object<WarpSpecializeFrameNode>();
   PrimExpr condition;
   std::vector<int> warp_groups;
   warp_groups.reserve(warp_group_ids.size());
@@ -376,13 +388,12 @@ WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
   return WarpSpecializeFrame(n);
 }

-TVM_REGISTER_NODE_TYPE(WarpSpecializeFrameNode);
-
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.WarpSpecialize", WarpSpecialize);
   KernelLaunchFrameNode::RegisterReflection();
   WarpSpecializeFrameNode::RegisterReflection();
-});
+}

 } // namespace tl
 } // namespace tvm
@@ -64,13 +64,12 @@ Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
   }
   forward_index =
       forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
-
-  auto n = make_object<LayoutNode>(input_size, forward_index);
+  auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
   data_ = std::move(n);
 }

 Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
-  auto n = make_object<LayoutNode>(input_size, forward_index);
+  auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
   data_ = std::move(n);
 }
@@ -130,7 +129,6 @@ Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr> &vars) const {
   Array<PrimExpr> transformed = forward_index_.Map(
       [&](const PrimExpr &e) { return Substitute(e, vmap); });
-  // Concatenate with the remaining elements from vars
   Array<PrimExpr> result;
   for (size_t i = 0; i < vars.size() - InputDim(); i++) {
@@ -212,7 +210,7 @@ Fragment FragmentNode::DeReplicate() const {
     factor = arith::ZeroAwareGCD(*rep_size, *idx_size);
   }
   if (factor == 1)
-    return GetRef<Fragment>(this);
+    return tvm::ffi::GetRef<Fragment>(this);

   Map<Var, PrimExpr> vmap;
   vmap.Set(ReplicationPlaceholder(), ReplicationPlaceholder() * factor +
@@ -224,7 +222,7 @@ Fragment FragmentNode::DeReplicate() const {
 }

 Fragment FragmentNode::BindThreadRange(Range thread_range) const {
-  auto n = make_object<FragmentNode>(*this);
+  auto n = tvm::ffi::make_object<FragmentNode>(*this);
   n->thread_range_ = thread_range;
   return Fragment(n);
 }
@@ -336,8 +334,8 @@ Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
       forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
   forward_thread = Substitute(forward_thread, vmap);

-  auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
-                                     replicate_size);
+  auto n = tvm::ffi::make_object<FragmentNode>(input_size, forward_index,
+                                               forward_thread, replicate_size);
   data_ = std::move(n);
 }

@@ -348,8 +346,8 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
     forward_thread = Substitute(
         forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}});
   }
-  auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
-                                     replicate_size);
+  auto n = tvm::ffi::make_object<FragmentNode>(input_size, forward_index,
+                                               forward_thread, replicate_size);
   data_ = std::move(n);
 }
@@ -442,21 +440,6 @@ std::string FragmentNode::DebugOutput() const {
   return ss.str();
 }

-bool LayoutNode::SEqualReduce(const LayoutNode *other,
-                              SEqualReducer equal) const {
-  return equal(this->InputShape(), other->InputShape()) &&
-         equal(this->forward_index_, other->forward_index_);
-}
-
-bool FragmentNode::SEqualReduce(const FragmentNode *other,
-                                SEqualReducer equal) const {
-  return equal(this->ReplicateExtent(), other->ReplicateExtent()) &&
-         equal(this->InputShape(), other->InputShape()) &&
-         equal(this->ThreadExtent(), other->ThreadExtent()) &&
-         equal(this->forward_index_, other->forward_index_) &&
-         equal(this->forward_thread_, other->forward_thread_);
-}
-
 bool LayoutNode::IsEqual(const LayoutNode *other, bool skip_index) const {
   bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
   ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
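
The hand-written `SEqualReduce` overloads removed above are subsumed by the reflection metadata: with `_type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode` (see the header change further down) and the fields registered in `RegisterReflection()`, structural equality is derived automatically. A minimal sketch of the equivalent check, assuming TVM's `StructuralEqual` utility keeps its usual header location (the helper itself is hypothetical):

```cpp
#include <tvm/node/structural_equal.h>

// Hypothetical helper: compares two layouts via the reflection-registered
// fields (input shape and forward_index), no custom SEqualReduce needed.
bool SameLayout(const tvm::tl::Layout &a, const tvm::tl::Layout &b) {
  return tvm::StructuralEqual()(a, b);
}
```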
@@ -495,10 +478,7 @@ void FragmentNode::RegisterReflection() {
       .def_ro("replicate_size", &FragmentNode::replicate_size_);
 }

-TVM_REGISTER_NODE_TYPE(LayoutNode);
-TVM_REGISTER_NODE_TYPE(FragmentNode);
-
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
       .def_packed("tl.Layout",
@@ -582,13 +562,13 @@ TVM_FFI_STATIC_INIT_BLOCK({
       .def("tl.make_linear_layout", [](int stride, int continuous) {
         return makeGemmLayoutLinear(stride, continuous);
       });
-});
+}

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   LayoutNode::RegisterReflection();
   FragmentNode::RegisterReflection();
-});
+}

 } // namespace tl
 } // namespace tvm
@@ -8,8 +8,11 @@
 #include <tvm/arith/analyzer.h>
 #include <tvm/arith/iter_affine_map.h>
+#include <tvm/ffi/object.h>

 #include <utility>

+#include "../support/ffi_aliases.h"
+
 namespace tvm {
 namespace tl {
@@ -44,11 +47,10 @@ public:
   virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const;

-  static constexpr bool _type_has_method_sequal_reduce = true;
-  static constexpr const char *_type_key = "tl.Layout";
-  bool SEqualReduce(const LayoutNode *other, SEqualReducer equal) const;
   static void RegisterReflection();
-  TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object);
+  TVM_FFI_DECLARE_OBJECT_INFO("tl.Layout", LayoutNode, Object);
+  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
+      kTVMFFISEqHashKindTreeNode;

 protected:
   virtual Map<Var, Range> getVarMap() const;
@@ -65,7 +67,7 @@ public:
   TVM_DLL Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index);
   TVM_DLL Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);

-  TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Layout, ObjectRef, LayoutNode);
 };

 class FragmentNode : public LayoutNode {
@@ -109,9 +111,9 @@ public:
   static void RegisterReflection();

-  bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const;
-  static constexpr const char *_type_key = "tl.Fragment";
-  TVM_DECLARE_FINAL_OBJECT_INFO(FragmentNode, LayoutNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode);
+  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
+      kTVMFFISEqHashKindTreeNode;

 protected:
   Map<Var, Range> getVarMap() const final;
@@ -132,7 +134,7 @@ public:
                   PrimExpr forward_thread, PrimExpr replicate_size,
                   Optional<Var> replicate_var);

-  TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fragment, Layout, FragmentNode);
 };

 Var InputPlaceholder(size_t idx);
......
@@ -6,6 +6,7 @@
 #include "swizzle.h"

+#include <tvm/node/node.h>
 #include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
@@ -86,14 +87,16 @@ SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var,
   forward_index =
       forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
-  auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
+  auto n = tvm::ffi::make_object<SwizzledLayoutNode>(input_size, forward_index,
+                                                     pattern);
   data_ = std::move(n);
 }

 SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size,
                                Array<PrimExpr> forward_index,
                                SwizzlePattern pattern) {
-  auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
+  auto n = tvm::ffi::make_object<SwizzledLayoutNode>(input_size, forward_index,
+                                                     pattern);
   data_ = std::move(n);
 }
@@ -102,14 +105,5 @@ void SwizzledLayoutNode::RegisterReflection() {
   refl::ObjectDef<SwizzledLayoutNode>();
 }

-bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode *other,
-                                      SEqualReducer equal) const {
-  return equal(this->InputShape(), other->InputShape()) &&
-         equal(this->forward_index_, other->forward_index_) &&
-         pattern_ == other->pattern_;
-}
-
-TVM_REGISTER_NODE_TYPE(SwizzledLayoutNode);
-
 } // namespace tl
 } // namespace tvm
@@ -44,10 +44,9 @@ public:
   Layout Inverse() const final;
   std::string DebugOutput() const final;
   bool IsEqual(const SwizzledLayoutNode *other, bool skip_index = false) const;
-  static constexpr const char *_type_key = "tl.SwizzledLayout";
-  bool SEqualReduce(const SwizzledLayoutNode *other, SEqualReducer equal) const;
   static void RegisterReflection();
-  TVM_DECLARE_FINAL_OBJECT_INFO(SwizzledLayoutNode, LayoutNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.SwizzledLayout", SwizzledLayoutNode,
+                                    LayoutNode);

 private:
   SwizzlePattern pattern_;
@@ -62,8 +61,8 @@ public:
                          Array<PrimExpr> forward_index, SwizzlePattern pattern);
   TVM_DLL SwizzledLayout(Array<PrimExpr> input_size,
                          Array<PrimExpr> forward_index, SwizzlePattern pattern);
-
-  TVM_DEFINE_OBJECT_REF_METHODS(SwizzledLayout, Layout, SwizzledLayoutNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SwizzledLayout, Layout,
+                                             SwizzledLayoutNode);
 };

 } // namespace tl
......
@@ -189,7 +189,7 @@ public:
   IterMark Mutate(const IterMark &mark) {
     if (auto *op = mark->source.as<IterSumExprNode>()) {
-      return IterMark(Mutate(GetRef<IterSumExpr>(op)), mark->extent);
+      return IterMark(Mutate(tvm::ffi::GetRef<IterSumExpr>(op)), mark->extent);
     } else {
       return mark;
     }
......
@@ -9,6 +9,8 @@
 #include <tvm/arith/iter_affine_map.h>

+#include "../support/ffi_aliases.h"
+
 namespace tvm {
 namespace tl {
......
@@ -42,7 +42,7 @@ using namespace tir;
  * - The constructed node is stored in this->data_.
  */
 AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
-  ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
+  ObjectPtr<AtomicAddNode> node = tvm::ffi::make_object<AtomicAddNode>();
   Array<Range> rgs[2];
   Buffer bf[2];
   for (int i = 0; i < 2; i++) {
@@ -78,7 +78,7 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
  * @return TileOperator A TileOperator owning the cloned AtomicAddNode.
  */
 TileOperator AtomicAddNode::Clone() const {
-  auto op = make_object<AtomicAddNode>(*this);
+  auto op = tvm::ffi::make_object<AtomicAddNode>(*this);
   if (par_op_.defined()) {
     op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
   }
@@ -549,7 +549,7 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

-TVM_FFI_STATIC_INIT_BLOCK({ AtomicAddNode::RegisterReflection(); });
+TVM_FFI_STATIC_INIT_BLOCK() { AtomicAddNode::RegisterReflection(); }

 } // namespace tl
 } // namespace tvm
\ No newline at end of file