Unverified Commit 10911e28 authored by Lei Wang, committed by GitHub

[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)



* 3rdparty tvm bump

* bump tvm to v0.22.0

* lint fix

* rebase tvm

* Update submodule tvm to latest commit 3085bc4

* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang

* test fix

* add requirement

* atomic_fix

* atomic_fix

* phase out py39

* optimize

* optimize

* lint fix

* do not clean cache

* do not clean cache

* [Minor] Minor update for Python versions and dependencies

* [Lint] fix lint for py39

* [Lint] fix lint for ROCm

* [Build][CI] Sync CI changes from upstream/sdist

* [Lint] fix lint for ROCm

* [Build][CI] Update `repair-wheel-command`

* [Minor] update abi3audit result format

* [Lint] fix lint for ROCm

* [BugFix] fix build

* [Lint] fix lint for ROCm

* [BugFix] set rpath for libtvm and libtvm_runtime

* [Deps] pin apache-tvm-ffi version

* [Build] set Python 3.9 Limited API for Cython target

* [Build] set Python 3.9 Limited API for Cython target

* [Deps] Restore Python 3.8 support

* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`

* [BugFix] use `;` as delimiter for RPATH on macOS

* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`

* [Build] support `sccache` if available

* [Build] add CIBW import test

* [Build][CI] enable ccache for CIBW on Linux

* [BugFix] set rpath for libtvm and libtvm_runtime

* Revert "[Build][CI] enable ccache for CIBW on Linux"

This reverts commit cd9ab57bb5ddd2572c60bcbbebde81480a658fd3.

* [CI] fix perfbench bot

* [BugFix] use Python 3.9 to build wheel

* [Minor] update perfbench bot envs

* [BugFix] fix CIBW environment on Linux

* [CI] skip import test on CentOS 7

* [CI] use Python urllib to download file instead of Wget

---------
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
parent c37621c5
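
The C++ churn below is dominated by one mechanical migration: legacy TVM object macros are swapped for their tvm-ffi counterparts, the static-init registration block becomes a function-style definition, and helpers such as `make_object`/`GetRef` are reached through the `tvm::ffi` namespace. A condensed before/after sketch with an illustrative `MyNode`/`MyRef` pair (not a type from this diff; the macro spellings follow the hunks below):

```cpp
#include <tvm/ffi/object.h>
#include <tvm/ffi/reflection/registry.h>

namespace demo {

class MyNode : public tvm::ffi::Object {
public:
  int value{0};
  // Before: static constexpr const char *_type_key = "demo.My";
  //         TVM_DECLARE_FINAL_OBJECT_INFO(MyNode, Object);
  // After: the type key is folded into the macro itself.
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("demo.My", MyNode, tvm::ffi::Object);
};

class MyRef : public tvm::ffi::ObjectRef {
public:
  // Before: TVM_DEFINE_OBJECT_REF_METHODS(MyRef, ObjectRef, MyNode);
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MyRef, tvm::ffi::ObjectRef, MyNode);
};

// Before: TVM_REGISTER_NODE_TYPE(MyNode); plus TVM_FFI_STATIC_INIT_BLOCK({ ... });
// After: one function-style init block registers reflection and globals,
// and make_object is namespace-qualified.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("demo.MakeMy", []() {
    auto n = tvm::ffi::make_object<MyNode>();
    n->value = 42;
    return MyRef(n); // assumes the macro-provided ObjectPtr constructor
  });
}

} // namespace demo
```

The per-file hunks apply exactly this set of renames, plus an explicit `UnsafeInit` constructor wherever a frame ref becomes non-nullable.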
---
InheritParentConfig: true
ExtraArgs: ['-v']
ExtraArgs: []
FormatStyle: file
UseColor: true
WarningsAsErrors: '*'
......
......@@ -22,10 +22,12 @@ env:
PYTHONDEVMODE: "1"
PYTHONUNBUFFERED: "1"
PYTHONPATH: "" # explicit cleanup
PIP_USER: "" # explicit cleanup
COLUMNS: "100"
FORCE_COLOR: "1"
CLICOLOR_FORCE: "1"
UV_INDEX_STRATEGY: "unsafe-best-match"
UV_HTTP_TIMEOUT: "600"
XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated
PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated
UV_CACHE_DIR: "${{ github.workspace }}/.cache/uv" # to be updated
......@@ -44,7 +46,7 @@ jobs:
submodules: recursive
- name: Setup Python 3.8
id: setup-py38
id: setup-pylowest
uses: actions/setup-python@v6
with:
python-version: "3.8" # use lowest supported version for linting
......@@ -52,7 +54,7 @@ jobs:
- name: Check AST with Python 3.8
run: |
"${{ steps.setup-py38.outputs.python-path }}" -m compileall -q -f tilelang
"${{ steps.setup-pylowest.outputs.python-path }}" -m compileall -q -f tilelang
- name: Setup Python 3.12
uses: actions/setup-python@v6
......
......@@ -108,14 +108,11 @@ jobs:
- { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8" }
- { runner: macos-latest, toolkit: "Metal" }
python-version:
- "3.8"
# TVM is built with Python 3.8 Limited API, it should work with all Python >= 3.8.
# - "3.9"
# - "3.10"
# - "3.11"
# - "3.12"
# - "3.13"
# - "3.14"
# Wheels are built with Python 3.8 Limited API, they should work with all Python >= 3.8.
# Only build wheels against Python 3.8 Limited API to save CI resources.
# FIXME: Here we use Python 3.9 because our dependency `apache-tvm-ffi` claims to support
# Python 3.8 but it depends on a version of `ml-dtypes` that requires Python >= 3.9.
- "3.9"
fail-fast: false
timeout-minutes: 120
runs-on: ${{ matrix.target.runner }}
......
......@@ -12,6 +12,17 @@ concurrency:
group: "${{ github.workflow }}-${{ github.ref }}"
cancel-in-progress: true # always cancel in-progress
env:
PYTHONDEVMODE: "1"
PYTHONUNBUFFERED: "1"
PYTHONPATH: "" # explicit cleanup
PIP_USER: "" # explicit cleanup
COLUMNS: "100"
FORCE_COLOR: "1"
CLICOLOR_FORCE: "1"
XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated
PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated
jobs:
perfbench:
name: Benchmark between PR and main
......@@ -31,7 +42,12 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
python-version: "3.9"
python-version: "3.12"
update-environment: true
cache: pip
cache-dependency-path: |
pyproject.toml
requirements*.txt
- name: Install merged version
run: |
......
Subproject commit 5bf17a34602931e7d7e01cbccf358a21fe972779
Subproject commit 0f1ebab7b66732f34b652ce807c9ff0748cd473c
......@@ -8,6 +8,11 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND "$ENV{CIBUILDWHEEL}")
# Warning came from tvm submodule
string(APPEND CMAKE_CXX_FLAGS " -Wno-dangling-reference")
endif()
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.gitmodules" AND EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.git")
......@@ -36,9 +41,18 @@ endif()
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
message(STATUS "Using ccache: ${CCACHE_PROGRAM}")
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "C compiler launcher")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher")
else()
find_program(SCCACHE_PROGRAM sccache)
if(SCCACHE_PROGRAM)
message(STATUS "Using sccache: ${SCCACHE_PROGRAM}")
set(CMAKE_C_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "C compiler launcher")
set(CMAKE_CXX_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher")
endif()
endif()
# Configs
......@@ -68,8 +82,6 @@ file(GLOB TILE_LANG_SRCS
src/target/utils.cc
src/target/codegen_cpp.cc
src/target/rt_mod_cpp.cc
# webgpu doesn't have system dependency
src/target/codegen_webgpu.cc
# intrin_rule doesn't have system dependency
src/target/intrin_rule*.cc
)
......@@ -181,18 +193,18 @@ install(TARGETS tilelang_cython_wrapper
# let libtilelang search for tvm/tvm_runtime in the same dir
if(APPLE)
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path")
else()
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN")
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
elseif(UNIX)
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
endif()
install(TARGETS tvm tvm_runtime tilelang_module tilelang LIBRARY DESTINATION tilelang/lib)
# Copy tvm cython ext for wheels
# TODO: not necessary for editable builds
if(TVM_BUILD_FROM_SOURCE)
add_dependencies(tilelang tvm_cython)
install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/python/tvm/ffi/core.abi3.so" DESTINATION tilelang/3rdparty/tvm/python/tvm/ffi/)
endif()
install(
TARGETS tvm tvm_runtime tilelang_module tilelang
LIBRARY DESTINATION tilelang/lib
)
......@@ -11,8 +11,17 @@ endif()
set(TVM_INCLUDES
${TVM_SOURCE}/include
${TVM_SOURCE}/ffi/include
${TVM_SOURCE}/src
${TVM_SOURCE}/3rdparty/dlpack/include
${TVM_SOURCE}/3rdparty/dmlc-core/include
)
if(EXISTS ${TVM_SOURCE}/ffi/include)
list(APPEND TVM_INCLUDES ${TVM_SOURCE}/ffi/include)
elseif(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/include)
list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/include)
endif()
if(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include)
list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include)
endif()
......@@ -4,20 +4,23 @@ TileLang is a domain-specific language designed to simplify the process of writi
## Table of Contents
1. [Getting Started](#getting-started)
2. [Simple GEMM Example](#simple-gemm-example)
- [Table of Contents](#table-of-contents)
- [Getting Started](#getting-started)
- [Prerequisites](#prerequisites)
- [Installation](#installation)
- [Simple GEMM Example](#simple-gemm-example)
- [Code Walkthrough](#code-walkthrough)
- [Compiling and Profiling](#compiling-and-profiling)
3. [Advanced GEMM Features](#advanced-gemm-features)
- [Advanced GEMM Features](#advanced-gemm-features)
- [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling)
- [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining)
- [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality)
4. [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
5. [Verifying Correctness](#verifying-correctness)
6. [Fine-grained MMA Computations](#fine-grained-mma-computations)
- [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
- [Verifying Correctness](#verifying-correctness)
- [Fine-grained MMA Computations](#fine-grained-mma-computations)
- [Example Workflow](#example-workflow)
- [Summary](#summary)
7. [References](#references)
- [References](#references)
---
......
......@@ -80,6 +80,9 @@ elif [[ "${#FILES[@]}" -gt 0 ]]; then
echo "Checking specified files: ${FILES[*]}..." >&2
fi
# Some systems set pip's default to --user, which breaks isolated virtualenvs.
export PIP_USER=0
# If pre-commit is not installed, install it.
if ! python3 -m pre_commit --version &>/dev/null; then
python3 -m pip install pre-commit
......
......@@ -8,21 +8,27 @@ maintainers = [{ name = "Lei Wang", email = "leiwang1999@outlook.com" }]
license = "MIT"
keywords = ["BLAS", "CUDA", "HIP", "Code Generation", "TVM"]
classifiers = [
"Development Status :: 4 - Beta",
"Environment :: GPU",
"Operating System :: POSIX :: Linux",
"Operating System :: OS Independent",
"Operating System :: MacOS",
"Programming Language :: C++",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: Implementation :: CPython",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dynamic = ["version"]
dependencies = [
"apache-tvm-ffi~=0.1.0",
"cloudpickle",
"ml-dtypes",
"numpy>=1.23.5",
......@@ -39,11 +45,7 @@ dependencies = [
fp4 = ["ml-dtypes>=0.5.1"]
[build-system]
requires = [
"cython>=3.0.0",
"scikit-build-core",
"setuptools>=63",
]
requires = ["cython>=3.0.0", "scikit-build-core"]
build-backend = "scikit_build_core.build"
[tool.scikit-build]
......@@ -180,27 +182,37 @@ build-frontend = "build"
environment = { PYTHONDEVMODE = "1", PYTHONUNBUFFERED = "1" }
environment-pass = [
"CUDA_VERSION",
"NO_VERSION_LABEL",
"NO_TOOLCHAIN_VERSION",
"NO_GIT_VERSION",
"COLUMNS",
"CMAKE_GENERATOR",
"CMAKE_BUILD_PARALLEL_LEVEL",
"FORCE_COLOR",
"CLICOLOR_FORCE",
]
before-build = "env -0 | sort -z | tr '\\0' '\\n'"
windows.before-build = "set"
# Pin to glibc 2.17 for x86 and 2.28 for aarch64 for now
manylinux-x86_64-image = "manylinux2014"
manylinux-aarch64-image = "manylinux_2_28"
test-command = [
"python -c 'import tilelang; print(tilelang.__version__)'",
]
[tool.cibuildwheel.linux]
environment = { PYTHONDEVMODE = "1", PYTHONUNBUFFERED = "1", PATH = "/usr/local/cuda/bin:$PATH" }
repair-wheel-command = [
"auditwheel repair --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}",
"pipx run abi3audit --strict --report {wheel}",
]
environment.PYTHONDEVMODE = "1"
environment.PYTHONUNBUFFERED = "1"
environment.PATH = "/usr/local/cuda/bin:$PATH"
environment.LD_LIBRARY_PATH = "/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
# Pin to glibc 2.17 for x86 and 2.28 for aarch64 for now
manylinux-x86_64-image = "manylinux2014" # CentOS 7
manylinux-aarch64-image = "manylinux_2_28" # AlmaLinux 8
# Install CUDA runtime and stub driver library
# manylinux_2_28 uses gcc 14, which needs CUDA 12.8
before-all = """
set -eux
cat /etc/*-release
uname -a
case "$(uname -m)" in
"x86_64")
yum-config-manager --add-repo https://developer.download.nvidia.cn/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo
......@@ -215,5 +227,22 @@ esac
cudaver="$(echo "${CUDA_VERSION:-"12.4"}" | cut -d '.' -f-2)"
v="${cudaver//./-}"
yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-devel-${v}"
yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-devel-${v}" nvidia-driver-cuda-libs
"""
repair-wheel-command = [
"auditwheel -v repair --exclude libtvm_ffi.so --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}",
"pipx run abi3audit --verbose --strict {wheel}",
]
[tool.cibuildwheel.macos]
repair-wheel-command = [
"delocate-wheel --verbose --ignore-missing-dependencies --no-sanitize-rpaths --require-archs {delocate_archs} -w {dest_dir} -v {wheel}",
"pipx run abi3audit --verbose --strict {wheel}",
]
[[tool.cibuildwheel.overrides]]
select = "*linux*x86_64*"
# CentOS 7 is too old to run import test. Do wheel installation test only.
test-command = [
"echo 'Wheel is installed successfully'",
]
......@@ -18,10 +18,11 @@ cython
docutils
dtlib
einops
flash-linear-attention==0.3.2
packaging>=21.0
pytest-xdist>=2.2.1
pytest-durations
pytest-timeout
pytest-xdist>=2.2.1
pytest>=6.2.4
pyyaml
requests
......
# Runtime requirements
apache-tvm-ffi~=0.1.0
cloudpickle
ml-dtypes
numpy>=1.23.5
......@@ -7,4 +8,3 @@ torch
torch>=2.7; platform_system == 'Darwin'
tqdm>=4.62.3
typing-extensions>=4.10.0
flash-linear-attention==0.3.2
\ No newline at end of file
......@@ -7,6 +7,9 @@
#include "./transform/common/attr.h"
#include "op/builtin.h"
#include "tvm/ffi/any.h"
#include <tvm/ffi/object.h>
#include "support/ffi_aliases.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/script/ir_builder/tir/ir.h>
......@@ -37,7 +40,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
using namespace tvm::tir;
Var var = Var(name, dom->dtype);
// Create a frame that represents a loop over the given domain.
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.push_back(var);
n->doms.push_back(Range(0, dom));
n->f_make_for_loop = [](const Array<Var> &vars, const Array<Range> &doms,
......@@ -52,7 +55,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
ForFrame ParallelFor(const Array<PrimExpr> &extents,
const Map<String, ObjectRef> &annotations) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.reserve(extents.size());
n->doms.reserve(extents.size());
for (const auto &extent : extents) {
......@@ -82,7 +85,7 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
const Array<Array<PrimExpr>> &sync,
const Array<Array<PrimExpr>> &groups) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
DataType dtype = stop.dtype();
n->vars.push_back(Var("v", dtype));
n->doms.push_back(Range(std::move(start), stop));
......@@ -113,7 +116,7 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
const PrimExpr &index, PrimExpr group_size) {
using namespace tvm::tir;
ICHECK(!domain.empty());
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.reserve(domain.size());
n->doms.reserve(domain.size());
PrimExpr domain_size = domain[0];
......@@ -193,8 +196,8 @@ public:
"frames", &KernelLaunchFrameNode::frames);
}
static constexpr const char *_type_key = "tl.KernelLaunchFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(KernelLaunchFrameNode, TIRFrameNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.KernelLaunchFrame",
KernelLaunchFrameNode, TIRFrameNode);
public:
TVM_DLL void EnterWithScope() final {
......@@ -218,14 +221,20 @@ public:
*/
class KernelLaunchFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(KernelLaunchFrame, TIRFrame,
explicit KernelLaunchFrame(ObjectPtr<KernelLaunchFrameNode> data)
: TIRFrame(::tvm::ffi::UnsafeInit{}) {
ICHECK(data != nullptr);
data_ = std::move(data);
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(KernelLaunchFrame, TIRFrame,
KernelLaunchFrameNode);
};
KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
const Optional<Array<PrimExpr>> &block_size_opt,
const Map<String, ffi::Any> &attrs) {
ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>();
ObjectPtr<KernelLaunchFrameNode> n =
tvm::ffi::make_object<KernelLaunchFrameNode>();
// If the kernel is a CPU kernel, we don't need to launch any threads.
bool is_cpu_kernel_frame =
......@@ -289,16 +298,14 @@ KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
return KernelLaunchFrame(n);
}
TVM_REGISTER_NODE_TYPE(KernelLaunchFrameNode);
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("tl.Parallel", ParallelFor)
.def("tl.Pipelined", PipelinedFor)
.def("tl.Persistent", PersistentFor)
.def("tl.KernelLaunch", KernelLaunch);
});
}
class WarpSpecializeFrameNode : public TIRFrameNode {
public:
......@@ -310,8 +317,8 @@ public:
"frames", &WarpSpecializeFrameNode::frames);
}
static constexpr const char *_type_key = "tl.WarpSpecializeFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(WarpSpecializeFrameNode, TIRFrameNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.WarpSpecializeFrame",
WarpSpecializeFrameNode, TIRFrameNode);
public:
TVM_DLL void EnterWithScope() final {
......@@ -330,15 +337,20 @@ public:
class WarpSpecializeFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WarpSpecializeFrame,
TIRFrame,
explicit WarpSpecializeFrame(ObjectPtr<WarpSpecializeFrameNode> data)
: TIRFrame(::tvm::ffi::UnsafeInit{}) {
ICHECK(data != nullptr);
data_ = std::move(data);
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(WarpSpecializeFrame, TIRFrame,
WarpSpecializeFrameNode);
};
WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
const PrimExpr &thread_idx,
int warp_group_size = 128) {
ObjectPtr<WarpSpecializeFrameNode> n = make_object<WarpSpecializeFrameNode>();
ObjectPtr<WarpSpecializeFrameNode> n =
tvm::ffi::make_object<WarpSpecializeFrameNode>();
PrimExpr condition;
std::vector<int> warp_groups;
warp_groups.reserve(warp_group_ids.size());
......@@ -376,13 +388,12 @@ WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
return WarpSpecializeFrame(n);
}
TVM_REGISTER_NODE_TYPE(WarpSpecializeFrameNode);
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.WarpSpecialize", WarpSpecialize);
KernelLaunchFrameNode::RegisterReflection();
WarpSpecializeFrameNode::RegisterReflection();
});
}
} // namespace tl
} // namespace tvm
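
For context on how the `GlobalDef().def("tl.*", ...)` entries above are consumed: the Python frontend resolves them by name through tvm-ffi, and the same lookup is available from C++. A minimal sketch, assuming tvm-ffi exposes `Function::GetGlobal` as in mainline TVM v0.22 (the call site is illustrative):

```cpp
#include <tvm/ffi/function.h>

#include <iostream>
#include <optional>

// Sketch: resolve a function registered in the static-init blocks above.
// Assumes tvm::ffi::Function::GetGlobal returns std::optional<Function>.
void LookupBuilder() {
  std::optional<tvm::ffi::Function> fn =
      tvm::ffi::Function::GetGlobal("tl.KernelLaunch");
  if (!fn.has_value()) {
    std::cerr << "tl.KernelLaunch is not registered\n";
    return;
  }
  // A real caller would now invoke (*fn)(grid_size, block_size, attrs)
  // with the Array/Map arguments that KernelLaunch expects.
}
```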
......@@ -64,13 +64,12 @@ Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
}
forward_index =
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
auto n = make_object<LayoutNode>(input_size, forward_index);
auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
data_ = std::move(n);
}
Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
auto n = make_object<LayoutNode>(input_size, forward_index);
auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
data_ = std::move(n);
}
......@@ -130,7 +129,6 @@ Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr> &vars) const {
Array<PrimExpr> transformed = forward_index_.Map(
[&](const PrimExpr &e) { return Substitute(e, vmap); });
// Concatenate with the remaining elements from vars
Array<PrimExpr> result;
for (size_t i = 0; i < vars.size() - InputDim(); i++) {
......@@ -212,7 +210,7 @@ Fragment FragmentNode::DeReplicate() const {
factor = arith::ZeroAwareGCD(*rep_size, *idx_size);
}
if (factor == 1)
return GetRef<Fragment>(this);
return tvm::ffi::GetRef<Fragment>(this);
Map<Var, PrimExpr> vmap;
vmap.Set(ReplicationPlaceholder(), ReplicationPlaceholder() * factor +
......@@ -224,7 +222,7 @@ Fragment FragmentNode::DeReplicate() const {
}
Fragment FragmentNode::BindThreadRange(Range thread_range) const {
auto n = make_object<FragmentNode>(*this);
auto n = tvm::ffi::make_object<FragmentNode>(*this);
n->thread_range_ = thread_range;
return Fragment(n);
}
......@@ -336,8 +334,8 @@ Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
forward_thread = Substitute(forward_thread, vmap);
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
replicate_size);
auto n = tvm::ffi::make_object<FragmentNode>(input_size, forward_index,
forward_thread, replicate_size);
data_ = std::move(n);
}
......@@ -348,8 +346,8 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
forward_thread = Substitute(
forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}});
}
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
replicate_size);
auto n = tvm::ffi::make_object<FragmentNode>(input_size, forward_index,
forward_thread, replicate_size);
data_ = std::move(n);
}
......@@ -442,21 +440,6 @@ std::string FragmentNode::DebugOutput() const {
return ss.str();
}
bool LayoutNode::SEqualReduce(const LayoutNode *other,
SEqualReducer equal) const {
return equal(this->InputShape(), other->InputShape()) &&
equal(this->forward_index_, other->forward_index_);
}
bool FragmentNode::SEqualReduce(const FragmentNode *other,
SEqualReducer equal) const {
return equal(this->ReplicateExtent(), other->ReplicateExtent()) &&
equal(this->InputShape(), other->InputShape()) &&
equal(this->ThreadExtent(), other->ThreadExtent()) &&
equal(this->forward_index_, other->forward_index_) &&
equal(this->forward_thread_, other->forward_thread_);
}
bool LayoutNode::IsEqual(const LayoutNode *other, bool skip_index) const {
bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
......@@ -495,10 +478,7 @@ void FragmentNode::RegisterReflection() {
.def_ro("replicate_size", &FragmentNode::replicate_size_);
}
TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(FragmentNode);
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def_packed("tl.Layout",
......@@ -582,13 +562,13 @@ TVM_FFI_STATIC_INIT_BLOCK({
.def("tl.make_linear_layout", [](int stride, int continuous) {
return makeGemmLayoutLinear(stride, continuous);
});
});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
LayoutNode::RegisterReflection();
FragmentNode::RegisterReflection();
});
}
} // namespace tl
} // namespace tvm
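
Worth noting on the deletions above: the hand-written `SEqualReduce` overloads are gone because tvm-ffi derives structural equality and hashing from reflection metadata once a node registers its fields and opts into the tree-node eq/hash kind. A self-contained sketch of that contract, with illustrative field names:

```cpp
#include <tvm/ffi/object.h>
#include <tvm/ffi/reflection/registry.h>

// Sketch: reflection-driven structural equality replacing SEqualReduce.
class MyLayoutNode : public tvm::ffi::Object {
public:
  int64_t rows{0};
  int64_t cols{0};

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    // Every field registered here participates in the structural
    // equality/hash traversal.
    refl::ObjectDef<MyLayoutNode>()
        .def_ro("rows", &MyLayoutNode::rows)
        .def_ro("cols", &MyLayoutNode::cols);
  }

  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("demo.MyLayout", MyLayoutNode,
                                    tvm::ffi::Object);
  // Opt into recursive, field-wise equality/hashing.
  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
      kTVMFFISEqHashKindTreeNode;
};

TVM_FFI_STATIC_INIT_BLOCK() { MyLayoutNode::RegisterReflection(); }
```

The `StructuralEqual()` calls retained in `LayoutNode::IsEqual` then need no per-class reduce method.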
......@@ -8,8 +8,11 @@
#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/object.h>
#include <utility>
#include "../support/ffi_aliases.h"
namespace tvm {
namespace tl {
......@@ -44,11 +47,10 @@ public:
virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const;
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr const char *_type_key = "tl.Layout";
bool SEqualReduce(const LayoutNode *other, SEqualReducer equal) const;
static void RegisterReflection();
TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object);
TVM_FFI_DECLARE_OBJECT_INFO("tl.Layout", LayoutNode, Object);
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
protected:
virtual Map<Var, Range> getVarMap() const;
......@@ -65,7 +67,7 @@ public:
TVM_DLL Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index);
TVM_DLL Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);
TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Layout, ObjectRef, LayoutNode);
};
class FragmentNode : public LayoutNode {
......@@ -109,9 +111,9 @@ public:
static void RegisterReflection();
bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const;
static constexpr const char *_type_key = "tl.Fragment";
TVM_DECLARE_FINAL_OBJECT_INFO(FragmentNode, LayoutNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode);
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
protected:
Map<Var, Range> getVarMap() const final;
......@@ -132,7 +134,7 @@ public:
PrimExpr forward_thread, PrimExpr replicate_size,
Optional<Var> replicate_var);
TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fragment, Layout, FragmentNode);
};
Var InputPlaceholder(size_t idx);
......
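
One pattern to call out across these headers: `Layout`/`Fragment` adopt the `_NULLABLE` ref macro (an empty ref is legal and probed with `.defined()`), while the frames in `builder.cc` use `_NOTNULLABLE` and therefore gain explicit constructors that initialize the base through `UnsafeInit` before assigning `data_`. A compact contrast under the same illustrative-naming caveat:

```cpp
#include <tvm/ffi/object.h>
#include <utility>

class DemoNode : public tvm::ffi::Object {
public:
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("demo.Node", DemoNode, tvm::ffi::Object);
};

// Nullable: may be default-constructed empty; callers test .defined().
class NullableRef : public tvm::ffi::ObjectRef {
public:
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(NullableRef, tvm::ffi::ObjectRef,
                                             DemoNode);
};

// Non-nullable: no empty state, so ObjectPtr construction routes through
// the UnsafeInit tag first, mirroring KernelLaunchFrame in this diff
// (which also ICHECKs that the pointer is non-null).
class NonNullRef : public tvm::ffi::ObjectRef {
public:
  explicit NonNullRef(tvm::ffi::ObjectPtr<DemoNode> data)
      : tvm::ffi::ObjectRef(::tvm::ffi::UnsafeInit{}) {
    data_ = std::move(data);
  }
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(NonNullRef, tvm::ffi::ObjectRef,
                                                DemoNode);
};
```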
......@@ -6,6 +6,7 @@
#include "swizzle.h"
#include <tvm/node/node.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
......@@ -86,14 +87,16 @@ SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var,
forward_index =
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
auto n = tvm::ffi::make_object<SwizzledLayoutNode>(input_size, forward_index,
pattern);
data_ = std::move(n);
}
SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index,
SwizzlePattern pattern) {
auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
auto n = tvm::ffi::make_object<SwizzledLayoutNode>(input_size, forward_index,
pattern);
data_ = std::move(n);
}
......@@ -102,14 +105,5 @@ void SwizzledLayoutNode::RegisterReflection() {
refl::ObjectDef<SwizzledLayoutNode>();
}
bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode *other,
SEqualReducer equal) const {
return equal(this->InputShape(), other->InputShape()) &&
equal(this->forward_index_, other->forward_index_) &&
pattern_ == other->pattern_;
}
TVM_REGISTER_NODE_TYPE(SwizzledLayoutNode);
} // namespace tl
} // namespace tvm
......@@ -44,10 +44,9 @@ public:
Layout Inverse() const final;
std::string DebugOutput() const final;
bool IsEqual(const SwizzledLayoutNode *other, bool skip_index = false) const;
static constexpr const char *_type_key = "tl.SwizzledLayout";
bool SEqualReduce(const SwizzledLayoutNode *other, SEqualReducer equal) const;
static void RegisterReflection();
TVM_DECLARE_FINAL_OBJECT_INFO(SwizzledLayoutNode, LayoutNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.SwizzledLayout", SwizzledLayoutNode,
LayoutNode);
private:
SwizzlePattern pattern_;
......@@ -62,8 +61,8 @@ public:
Array<PrimExpr> forward_index, SwizzlePattern pattern);
TVM_DLL SwizzledLayout(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index, SwizzlePattern pattern);
TVM_DEFINE_OBJECT_REF_METHODS(SwizzledLayout, Layout, SwizzledLayoutNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SwizzledLayout, Layout,
SwizzledLayoutNode);
};
} // namespace tl
......
......@@ -189,7 +189,7 @@ public:
IterMark Mutate(const IterMark &mark) {
if (auto *op = mark->source.as<IterSumExprNode>()) {
return IterMark(Mutate(GetRef<IterSumExpr>(op)), mark->extent);
return IterMark(Mutate(tvm::ffi::GetRef<IterSumExpr>(op)), mark->extent);
} else {
return mark;
}
......
......@@ -9,6 +9,8 @@
#include <tvm/arith/iter_affine_map.h>
#include "../support/ffi_aliases.h"
namespace tvm {
namespace tl {
......
......@@ -42,7 +42,7 @@ using namespace tir;
* - The constructed node is stored in this->data_.
*/
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
ObjectPtr<AtomicAddNode> node = tvm::ffi::make_object<AtomicAddNode>();
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
......@@ -78,7 +78,7 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator owning the cloned AtomicAddNode.
*/
TileOperator AtomicAddNode::Clone() const {
auto op = make_object<AtomicAddNode>(*this);
auto op = tvm::ffi::make_object<AtomicAddNode>(*this);
if (par_op_.defined()) {
op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
}
......@@ -549,7 +549,7 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ AtomicAddNode::RegisterReflection(); });
TVM_FFI_STATIC_INIT_BLOCK() { AtomicAddNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
\ No newline at end of file
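
A last recurring idiom, visible in `AtomicAddNode::Clone` above: `tvm::ffi::make_object<T>(*this)` copy-constructs the node, yielding an independent copy whose plain fields can be rewritten, while `ObjectRef` members are shared by the shallow copy (hence the nested `par_op_->Clone()`). A minimal sketch with an illustrative type:

```cpp
#include <tvm/ffi/object.h>

class MyOpNode : public tvm::ffi::Object {
public:
  int num_stages{1};
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("demo.MyOp", MyOpNode, tvm::ffi::Object);
};

// Sketch of the copy-clone idiom: the copy constructor duplicates plain
// fields, so the clone can diverge without touching the original node.
tvm::ffi::ObjectPtr<MyOpNode> CloneWithStages(const MyOpNode &self,
                                              int num_stages) {
  auto op = tvm::ffi::make_object<MyOpNode>(self); // shallow field copy
  op->num_stages = num_stages;
  return op;
}
```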