Unverified Commit bbbf4207 authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
set -eux
#!/usr/bin/env bash
set -euxo pipefail
# Get the CUDA version from the command line
IMAGE="tilelang-builder:manylinux"
docker build . -f "$(dirname "${BASH_SOURCE[0]}")/pypi.manylinux.Dockerfile" --tag ${IMAGE}
script="sh maint/scripts/local_distribution.sh"
docker run --rm -v $(pwd):/tilelang ${IMAGE} /bin/bash -c "$script"
# Build for local architecture
CIBW_BUILD='cp38-*' cibuildwheel .
set -eux
#!/usr/bin/env bash
set -euxo pipefail
# Get the CUDA version from the command line
IMAGE="tilelang-builder:manylinux"
docker build . -f "$(dirname "${BASH_SOURCE[0]}")/pypi.manylinux.Dockerfile" --tag ${IMAGE}
if docker buildx version >/dev/null 2>&1; then
if docker info >/dev/null 2>&1; then
docker run --rm --privileged tonistiigi/binfmt --install amd64,arm64 >/dev/null 2>&1 || true
fi
script="sh maint/scripts/pypi_distribution.sh"
if ! docker buildx inspect multi >/dev/null 2>&1; then
docker buildx create --name multi --driver docker-container --use >/dev/null 2>&1 || true
else
docker buildx use multi >/dev/null 2>&1 || true
fi
docker buildx inspect --bootstrap >/dev/null 2>&1 || true
done
docker run --rm -v $(pwd):/tilelang -w /tilelang ${IMAGE} /bin/bash -c "$script"
export CIBW_ARCHS='x86_64 aarch64'
fi
NO_VERSION_LABEL=ON CIBW_BUILD='cp38-*' cibuildwheel .
FROM pytorch/manylinux2_28-builder:cuda12.1 AS builder_amd64
ENV CUDA_VERSION=12.1 \
AUDITWHEEL_PLAT=manylinux_2_28_x86_64
RUN pip3 install uv
FROM quay.io/pypa/manylinux2014_x86_64 AS builder_amd64
FROM pytorch/manylinuxaarch64-builder:cuda12.8 AS builder_arm64
ENV CUDA_VERSION=12.8 \
AUDITWHEEL_PLAT=manylinux_2_28_aarch64
RUN yum-config-manager --add-repo https://developer.download.nvidia.cn/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo
ARG CUDA_VERSION=12.1
ENV CUDA_VERSION=${CUDA_VERSION}
FROM quay.io/pypa/manylinux_2_28_aarch64 AS builder_arm64
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
ARG CUDA_VERSION=12.8
ENV CUDA_VERSION=${CUDA_VERSION}
ARG TARGETARCH
FROM builder_${TARGETARCH}
ENV DEBIAN_FRONTEND=noninteractive \
TZ=Etc/UTC
RUN set -eux; \
uv venv -p 3.12 --seed /venv; \
git config --global --add safe.directory '/tilelang'
ENV PATH="/usr/local/cuda/bin:${PATH}"
ENV PATH="/venv/bin:$PATH" \
VIRTUAL_ENV=/venv
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
RUN uv pip install build wheel
RUN set -eux; \
pipx install cibuildwheel; \
git config --global --add safe.directory '/tilelang'
WORKDIR /tilelang
set -eux
rm -rf dist
rm -rf dist raw_dist
python -mpip install -U pip
python -mpip install -U build wheel auditwheel patchelf
......
......@@ -2,27 +2,32 @@
name = "tilelang"
description = "A tile level programming language to generate high performance code."
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.9"
authors = [{ name = "TileLang Contributors" }, { name = "Tile-AI" }]
maintainers = [{ name = "Lei Wang", email = "leiwang1999@outlook.com" }]
license = "MIT"
keywords = ["BLAS", "CUDA", "HIP", "Code Generation", "TVM"]
classifiers = [
"Development Status :: 4 - Beta",
"Environment :: GPU",
"Operating System :: POSIX :: Linux",
"Operating System :: OS Independent",
"Operating System :: MacOS",
"Programming Language :: Python :: 3.8",
"Programming Language :: C++",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: Implementation :: CPython",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dynamic = ["version"]
dependencies = [
"apache-tvm-ffi==0.1.0",
"cloudpickle",
"ml-dtypes",
"numpy>=1.23.5",
......@@ -39,11 +44,7 @@ dependencies = [
fp4 = ["ml-dtypes>=0.5.1"]
[build-system]
requires = [
"cython>=3.0.0",
"scikit-build-core",
"setuptools>=63",
]
requires = ["cython>=3.0.0", "scikit-build-core"]
build-backend = "scikit_build_core.build"
[tool.scikit-build]
......@@ -58,11 +59,14 @@ metadata.version.provider = "version_provider"
metadata.version.provider-path = "."
experimental = true
# build.verbose = true
# logging.level = "DEBUG"
[tool.scikit-build.sdist]
# See MANIFEST.in for details
include = [
"VERSION",
"LICENSE",
"./VERSION",
".git_commit.txt",
"./LICENSE",
"THIRDPARTYNOTICES.txt",
"version_provider.py",
"requirements*.txt",
......@@ -70,7 +74,15 @@ include = [
"CMakeLists.txt",
"src/**",
"cmake/**",
"3rdparty/**",
# The vendored 3rdparty contents in sdist should be same as wheel.
# Need full TVM to build from source.
"3rdparty/tvm",
# CUTLASS
"3rdparty/cutlass/include",
"3rdparty/cutlass/tools",
# Composable Kernel
"3rdparty/composable_kernel/include",
"3rdparty/composable_kernel/library",
"testing/**",
"examples/**",
]
......@@ -79,8 +91,7 @@ exclude = [
".github",
"**/.git",
"**/.github",
"3rdparty/clang**",
"3rdparty/llvm**",
"3rdparty/**",
"build",
]
......@@ -89,7 +100,17 @@ tilelang = "tilelang"
"tilelang/src" = "src"
# NOTE: The mapping below places the contents of '3rdparty' inside 'tilelang/3rdparty' in the wheel.
# This is necessary to find TVM shared libraries at runtime.
"tilelang/3rdparty" = "3rdparty"
# The vendored 3rdparty contents in wheel should be same as sdist.
# TVM
"tilelang/3rdparty/tvm/src" = "3rdparty/tvm/src"
"tilelang/3rdparty/tvm/python" = "3rdparty/tvm/python"
"tilelang/3rdparty/tvm/version.py" = "3rdparty/tvm/version.py"
# CUTLASS
"tilelang/3rdparty/cutlass/include" = "3rdparty/cutlass/include"
"tilelang/3rdparty/cutlass/tools" = "3rdparty/cutlass/tools"
# Composable Kernel
"tilelang/3rdparty/composable_kernel/include" = "3rdparty/composable_kernel/include"
"tilelang/3rdparty/composable_kernel/library" = "3rdparty/composable_kernel/library"
[tool.yapf]
based_on_style = "yapf"
......@@ -106,7 +127,7 @@ skip = [
]
[tool.ruff]
target-version = "py38"
target-version = "py39"
line-length = 100
output-format = "full"
......@@ -170,32 +191,45 @@ build-frontend = "build"
environment = { PYTHONDEVMODE = "1", PYTHONUNBUFFERED = "1" }
environment-pass = [
"CUDA_VERSION",
"NO_VERSION_LABEL",
"NO_TOOLCHAIN_VERSION",
"NO_GIT_VERSION",
"COLUMNS",
"CMAKE_GENERATOR",
"CMAKE_BUILD_PARALLEL_LEVEL",
"FORCE_COLOR",
"CLICOLOR_FORCE",
]
before-build = "env -0 | sort -z | tr '\\0' '\\n'"
windows.before-build = "set"
# Pin to glibc 2.17 for x86 and 2.28 for aarch64 for now
manylinux-x86_64-image = "manylinux2014"
manylinux-aarch64-image = "manylinux_2_28"
test-command = [
"python -c 'import tilelang; print(tilelang.__version__)'",
]
[tool.cibuildwheel.linux]
environment = { PYTHONDEVMODE = "1", PYTHONUNBUFFERED = "1", PATH = "/usr/local/cuda/bin:$PATH" }
repair-wheel-command = [
"auditwheel repair --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}",
"pipx run abi3audit --strict --report {wheel}",
]
environment.PYTHONDEVMODE = "1"
environment.PYTHONUNBUFFERED = "1"
environment.PATH = "/usr/local/cuda/bin:$PATH"
environment.LD_LIBRARY_PATH = "/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
# Pin to glibc 2.17 for x86 and 2.28 for aarch64 for now
# TODO: upgrade to manylinux_2_28 at some time
manylinux-x86_64-image = "manylinux2014" # CentOS 7
manylinux-aarch64-image = "manylinux_2_28" # AlmaLinux 8
# Install CUDA runtime and stub driver library
# manylinux_2_28 uses gcc 14, which needs CUDA 12.8
before-all = """
set -eux
cat /etc/*-release
uname -a
case "$(uname -m)" in
"x86_64")
DEFAULT_CUDA_VERSION="12.1"
yum-config-manager --add-repo https://developer.download.nvidia.cn/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo
;;
"aarch64")
DEFAULT_CUDA_VERSION="12.8"
dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
;;
*)
......@@ -203,7 +237,24 @@ case "$(uname -m)" in
;;
esac
cudaver="$(echo "${CUDA_VERSION:-"12.4"}" | cut -d '.' -f-2)"
cudaver="$(echo "${CUDA_VERSION:-$DEFAULT_CUDA_VERSION}" | cut -d '.' -f-2)"
v="${cudaver//./-}"
yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-devel-${v}"
yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-devel-${v}" nvidia-driver-cuda-libs
"""
repair-wheel-command = [
"auditwheel -v repair --exclude libtvm_ffi.so --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}",
"pipx run abi3audit --verbose --strict {wheel}",
]
[tool.cibuildwheel.macos]
repair-wheel-command = [
"delocate-wheel --verbose --ignore-missing-dependencies --no-sanitize-rpaths --require-archs {delocate_archs} -w {dest_dir} -v {wheel}",
"pipx run abi3audit --verbose --strict {wheel}",
]
[[tool.cibuildwheel.overrides]]
select = "*linux*x86_64*"
# CentOS 7 is too old to run import test. Do wheel installation test only.
test-command = [
"echo 'Wheel is installed successfully'",
]
# Requirements to run local build with `--no-build-isolation` or other developments
apache-tvm-ffi~=0.1.0
build
cmake>=3.26
cython>=3.0.0
......
......@@ -3,5 +3,5 @@ pre-commit
clang-format==21.1.2
clang-tidy==21.1.1
codespell[toml]==2.4.1
ruff==0.14.1
ruff==0.14.3
yapf==0.43.0
......@@ -18,10 +18,11 @@ cython
docutils
dtlib
einops
flash-linear-attention==0.3.2
packaging>=21.0
pytest-xdist>=2.2.1
pytest-durations
pytest-timeout
pytest-xdist>=2.2.1
pytest>=6.2.4
pyyaml
requests
......
# Runtime requirements
apache-tvm-ffi~=0.1.0
cloudpickle
ml-dtypes
numpy>=1.23.5
......@@ -7,4 +8,3 @@ torch
torch>=2.7; platform_system == 'Darwin'
tqdm>=4.62.3
typing-extensions>=4.10.0
flash-linear-attention==0.3.2
\ No newline at end of file
......@@ -7,6 +7,9 @@
#include "./transform/common/attr.h"
#include "op/builtin.h"
#include "tvm/ffi/any.h"
#include <tvm/ffi/object.h>
#include "support/ffi_aliases.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/script/ir_builder/tir/ir.h>
......@@ -37,7 +40,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
using namespace tvm::tir;
Var var = Var(name, dom->dtype);
// Create a frame that represents a loop over the given domain.
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.push_back(var);
n->doms.push_back(Range(0, dom));
n->f_make_for_loop = [](const Array<Var> &vars, const Array<Range> &doms,
......@@ -52,7 +55,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
ForFrame ParallelFor(const Array<PrimExpr> &extents,
const Map<String, ObjectRef> &annotations) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.reserve(extents.size());
n->doms.reserve(extents.size());
for (const auto &extent : extents) {
......@@ -82,7 +85,7 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
const Array<Array<PrimExpr>> &sync,
const Array<Array<PrimExpr>> &groups) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
DataType dtype = stop.dtype();
n->vars.push_back(Var("v", dtype));
n->doms.push_back(Range(std::move(start), stop));
......@@ -113,7 +116,7 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
const PrimExpr &index, PrimExpr group_size) {
using namespace tvm::tir;
ICHECK(!domain.empty());
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.reserve(domain.size());
n->doms.reserve(domain.size());
PrimExpr domain_size = domain[0];
......@@ -193,8 +196,8 @@ public:
"frames", &KernelLaunchFrameNode::frames);
}
static constexpr const char *_type_key = "tl.KernelLaunchFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(KernelLaunchFrameNode, TIRFrameNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.KernelLaunchFrame",
KernelLaunchFrameNode, TIRFrameNode);
public:
TVM_DLL void EnterWithScope() final {
......@@ -218,14 +221,20 @@ public:
*/
class KernelLaunchFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(KernelLaunchFrame, TIRFrame,
KernelLaunchFrameNode);
explicit KernelLaunchFrame(ObjectPtr<KernelLaunchFrameNode> data)
: TIRFrame(::tvm::ffi::UnsafeInit{}) {
ICHECK(data != nullptr);
data_ = std::move(data);
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(KernelLaunchFrame, TIRFrame,
KernelLaunchFrameNode);
};
KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
const Optional<Array<PrimExpr>> &block_size_opt,
const Map<String, ffi::Any> &attrs) {
ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>();
ObjectPtr<KernelLaunchFrameNode> n =
tvm::ffi::make_object<KernelLaunchFrameNode>();
// If the kernel is a CPU kernel, we don't need to launch any threads.
bool is_cpu_kernel_frame =
......@@ -289,16 +298,14 @@ KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
return KernelLaunchFrame(n);
}
TVM_REGISTER_NODE_TYPE(KernelLaunchFrameNode);
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("tl.Parallel", ParallelFor)
.def("tl.Pipelined", PipelinedFor)
.def("tl.Persistent", PersistentFor)
.def("tl.KernelLaunch", KernelLaunch);
});
}
class WarpSpecializeFrameNode : public TIRFrameNode {
public:
......@@ -310,8 +317,8 @@ public:
"frames", &WarpSpecializeFrameNode::frames);
}
static constexpr const char *_type_key = "tl.WarpSpecializeFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(WarpSpecializeFrameNode, TIRFrameNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.WarpSpecializeFrame",
WarpSpecializeFrameNode, TIRFrameNode);
public:
TVM_DLL void EnterWithScope() final {
......@@ -330,15 +337,20 @@ public:
class WarpSpecializeFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WarpSpecializeFrame,
TIRFrame,
WarpSpecializeFrameNode);
explicit WarpSpecializeFrame(ObjectPtr<WarpSpecializeFrameNode> data)
: TIRFrame(::tvm::ffi::UnsafeInit{}) {
ICHECK(data != nullptr);
data_ = std::move(data);
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(WarpSpecializeFrame, TIRFrame,
WarpSpecializeFrameNode);
};
WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
const PrimExpr &thread_idx,
int warp_group_size = 128) {
ObjectPtr<WarpSpecializeFrameNode> n = make_object<WarpSpecializeFrameNode>();
ObjectPtr<WarpSpecializeFrameNode> n =
tvm::ffi::make_object<WarpSpecializeFrameNode>();
PrimExpr condition;
std::vector<int> warp_groups;
warp_groups.reserve(warp_group_ids.size());
......@@ -376,13 +388,12 @@ WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
return WarpSpecializeFrame(n);
}
TVM_REGISTER_NODE_TYPE(WarpSpecializeFrameNode);
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.WarpSpecialize", WarpSpecialize);
KernelLaunchFrameNode::RegisterReflection();
WarpSpecializeFrameNode::RegisterReflection();
});
}
} // namespace tl
} // namespace tvm
......@@ -594,11 +594,11 @@ Layout MakeGemmVoltaBLayoutCongruous(int stride, int continuous) {
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
bool k_inner) {
if (k_inner)
if (k_inner && continuous % 32 == 0 && stride % 32 == 0)
return MakeGemmVoltaABLayoutCrosswise(stride, continuous);
if (is_a && continuous % 64 == 0)
if (is_a && continuous % 64 == 0 && stride % 4 == 0)
return MakeGemmVoltaALayoutCongruous(stride, continuous);
if (!is_a && continuous % 64 == 0)
if (!is_a && continuous % 64 == 0 && stride % 4 == 0)
return MakeGemmVoltaBLayoutCongruous(stride, continuous);
return makeGemmABLayoutPadded(stride, continuous, 16);
}
......
......@@ -5,6 +5,7 @@
#include "layout.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
#include <tvm/arith/pattern.h>
#include <tvm/tir/op.h>
......@@ -64,13 +65,12 @@ Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
}
forward_index =
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
auto n = make_object<LayoutNode>(input_size, forward_index);
auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
data_ = std::move(n);
}
Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
auto n = make_object<LayoutNode>(input_size, forward_index);
auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
data_ = std::move(n);
}
......@@ -102,10 +102,24 @@ Array<PrimExpr> LayoutNode::OutputShape() const {
for (size_t i = 0; i < ret.size(); i++) {
auto ist = analyzer.int_set(forward_index_[i] + 1);
if (arith::is_neg_inf(ist.min()) && arith::is_pos_inf(ist.max())) {
// X-OR Expression
ret.Set(i, input_size_[i]);
// Analyzer couldn't form an IntervalSet (e.g. bitwise ops).
// Fall back to ConstIntBound to derive a safe extent.
auto cib = analyzer.const_int_bound(forward_index_[i]);
if (cib->min_value != arith::ConstIntBound::kNegInf &&
cib->max_value != arith::ConstIntBound::kPosInf &&
cib->min_value >= 0) {
// extent = max - min + 1, using 64-bit integer literal
ret.Set(i, Integer(cib->max_value - cib->min_value + 1));
} else {
// Last-resort conservative fallback to avoid OOB/crash
// Prefer to keep dimension from known input_size_ if available.
if (i < input_size_.size()) {
ret.Set(i, input_size_[i]);
} else {
ret.Set(i, Integer(1));
}
}
} else {
// CHECK(is_one(ist.min())) << ist.min();
ret.Set(i, ist.max());
}
}
......@@ -130,7 +144,6 @@ Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr> &vars) const {
Array<PrimExpr> transformed = forward_index_.Map(
[&](const PrimExpr &e) { return Substitute(e, vmap); });
// Concatenate with the remaining elements from vars
Array<PrimExpr> result;
for (size_t i = 0; i < vars.size() - InputDim(); i++) {
......@@ -212,7 +225,7 @@ Fragment FragmentNode::DeReplicate() const {
factor = arith::ZeroAwareGCD(*rep_size, *idx_size);
}
if (factor == 1)
return GetRef<Fragment>(this);
return tvm::ffi::GetRef<Fragment>(this);
Map<Var, PrimExpr> vmap;
vmap.Set(ReplicationPlaceholder(), ReplicationPlaceholder() * factor +
......@@ -224,7 +237,7 @@ Fragment FragmentNode::DeReplicate() const {
}
Fragment FragmentNode::BindThreadRange(Range thread_range) const {
auto n = make_object<FragmentNode>(*this);
auto n = tvm::ffi::make_object<FragmentNode>(*this);
n->thread_range_ = thread_range;
return Fragment(n);
}
......@@ -251,14 +264,17 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
if (!is_static_shape) {
// Runtime guards keep dynamic tails safe, so we allow NoCheck here and
// warn.
LOG(WARNING) << "Layout::Inverse on symbolic layout, falling back to "
"NoCheck; symbolic dims: "
<< symbolic_dims;
DLOG(WARNING) << "Layout::Inverse on symbolic layout, falling back to "
"NoCheck; symbolic dims: "
<< symbolic_dims;
}
arith::IterMapResult res =
arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer);
ICHECK(res->errors.empty())
<< "Layout " << DebugOutput() << " has errors: " << res->errors;
if (!res->errors.empty()) {
std::ostringstream msg;
msg << "Layout " << DebugOutput() << " has errors: " << res->errors;
throw NormalizeIterException(msg.str());
}
auto outputs_shape = OutputShape();
Array<PrimExpr> outputs;
......@@ -280,10 +296,170 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
return {Layout(outputs_shape, backward_index), level};
}
Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const {
// Fast path: if shape is the same, return the original layout
if (StructuralEqual()(InputShape(), shape)) {
return ffi::GetRef<Layout>(this);
}
// Step 1. Prove the product of InputShape is equal to the product of shape
PrimExpr input_shape_product = Integer(1);
for (const auto &dim : InputShape()) {
input_shape_product *= dim;
}
PrimExpr shape_product = Integer(1);
for (const auto &dim : shape) {
shape_product *= dim;
}
// Use provided analyzer if present, otherwise a local fallback to avoid
// potential null dereference paths flagged by static analysis.
arith::Analyzer fallback_analyzer;
arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
ICHECK(az->CanProveEqual(input_shape_product, shape_product))
<< "InputShape() = " << InputShape() << " shape = " << shape;
// Step 2. Create new forward indices by reshaping
// For each dimension in the new shape, we create a placeholder variable
Array<Var> new_vars;
new_vars.reserve(shape.size());
for (size_t i = 0; i < shape.size(); ++i) {
auto var = Var(std::string("n_") + std::to_string(i), shape[i].dtype());
az->Bind(var, Range(0, shape[i]));
new_vars.push_back(var);
}
// Step 3. Compute the flat index from new shape indices
// flat_index = k0 * (s1 * s2 * ...) + k1 * (s2 * s3 * ...) + ... + kn
PrimExpr flat_index = Integer(0);
for (size_t i = 0; i < shape.size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < shape.size(); ++j) {
stride = stride * shape[j];
}
flat_index = flat_index + new_vars[i] * stride;
}
// Step 4. Convert flat index back to original shape indices
// For original shape [s0, s1, ..., sm]:
// i0 = flat_index // (s1 * s2 * ... * sm)
// i1 = (flat_index % (s1 * s2 * ... * sm)) // (s2 * s3 * ... * sm)
// ...
Array<PrimExpr> original_indices;
PrimExpr remaining = flat_index;
for (size_t i = 0; i < InputShape().size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < InputShape().size(); ++j) {
stride = stride * InputShape()[j];
}
original_indices.push_back(floordiv(remaining, stride));
remaining = floormod(remaining, stride);
}
// Step 5. Substitute original indices into forward_index_
Array<PrimExpr> new_forward_index;
for (const auto &fwd_expr : forward_index_) {
PrimExpr substituted = fwd_expr;
// Replace each InputPlaceholder(i) with original_indices[i]
for (size_t i = 0; i < InputShape().size(); ++i) {
substituted =
Substitute(substituted, {{InputPlaceholder(i), original_indices[i]}});
}
new_forward_index.push_back(az->Simplify(substituted));
}
for (size_t i = 0; i < new_vars.size(); ++i) {
new_forward_index =
Substitute(new_forward_index, {{new_vars[i], InputPlaceholder(i)}});
}
return Layout(shape, new_forward_index);
}
Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const {
// Fast path: identical input shape, return self
if (StructuralEqual()(InputShape(), shape)) {
return ffi::GetRef<Fragment>(this);
}
// 1) Prove total number of elements remains the same
PrimExpr input_prod = Integer(1);
for (const auto &d : InputShape())
input_prod *= d;
PrimExpr shape_prod = Integer(1);
for (const auto &d : shape)
shape_prod *= d;
// Use provided analyzer if present, otherwise a local fallback.
arith::Analyzer fallback_analyzer;
arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
ICHECK(az->CanProveEqual(input_prod, shape_prod))
<< "InputShape() = " << InputShape() << " shape = " << shape
<< " input fragment layout is = " << DebugOutput();
// 2) Build flat index from new-shape indices
Array<Var> new_vars;
new_vars.reserve(shape.size());
for (size_t i = 0; i < shape.size(); ++i) {
// Cannot use InputPlaceholder(i) here, because it would cause name capture
// (variable capture) with InputPlaceholder(i) in upper scopes. Therefore,
// we must create a fresh variable here to avoid confusion when
// substituting.
auto var = Var(std::string("n_") + std::to_string(i), shape[i].dtype());
az->Bind(var, Range(0, shape[i]));
new_vars.push_back(var);
}
PrimExpr flat = Integer(0);
for (size_t i = 0; i < shape.size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < shape.size(); ++j)
stride = stride * shape[j];
flat = flat + new_vars[i] * stride;
}
// 3) Recover original indices from flat index
Array<PrimExpr> orig_indices;
PrimExpr remain = flat;
for (size_t i = 0; i < InputShape().size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < InputShape().size(); ++j)
stride = stride * InputShape()[j];
orig_indices.push_back(floordiv(remain, stride));
remain = floormod(remain, stride);
}
// 4) Substitute old placeholders with expressions of new indices
Array<PrimExpr> new_forward_index;
for (const auto &e : forward_index_) {
PrimExpr cur = e;
for (size_t i = 0; i < InputShape().size(); ++i) {
cur = Substitute(cur, {{InputPlaceholder(i), orig_indices[i]}});
}
cur = az->Simplify(cur);
new_forward_index.push_back(cur);
}
PrimExpr new_forward_thread = forward_thread_;
for (size_t i = 0; i < InputShape().size(); ++i) {
new_forward_thread = Substitute(new_forward_thread,
{{InputPlaceholder(i), orig_indices[i]}});
}
new_forward_thread = az->Simplify(new_forward_thread);
for (size_t i = 0; i < new_vars.size(); ++i) {
auto var = new_vars[i];
new_forward_index =
Substitute(new_forward_index, {{var, InputPlaceholder(i)}});
new_forward_thread =
Substitute(new_forward_thread, {{var, InputPlaceholder(i)}});
}
Fragment reshaped(shape, new_forward_index, new_forward_thread,
ReplicateExtent(), std::nullopt);
if (thread_range_.defined()) {
reshaped = reshaped->BindThreadRange(thread_range_);
}
return reshaped;
}
Layout LayoutNode::Inverse() const {
auto inverse_result = InverseWithLevel();
return std::move(inverse_result.first);
}
PrimExpr infer_fragment_index(const Map<Var, Range> &input_iters,
const PrimExpr &forward_thread,
arith::Analyzer *analyzer) {
......@@ -336,8 +512,8 @@ Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
forward_thread = Substitute(forward_thread, vmap);
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
replicate_size);
auto n = tvm::ffi::make_object<FragmentNode>(input_size, forward_index,
forward_thread, replicate_size);
data_ = std::move(n);
}
......@@ -348,8 +524,8 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
forward_thread = Substitute(
forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}});
}
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
replicate_size);
auto n = tvm::ffi::make_object<FragmentNode>(input_size, forward_index,
forward_thread, replicate_size);
data_ = std::move(n);
}
......@@ -442,21 +618,6 @@ std::string FragmentNode::DebugOutput() const {
return ss.str();
}
bool LayoutNode::SEqualReduce(const LayoutNode *other,
SEqualReducer equal) const {
return equal(this->InputShape(), other->InputShape()) &&
equal(this->forward_index_, other->forward_index_);
}
bool FragmentNode::SEqualReduce(const FragmentNode *other,
SEqualReducer equal) const {
return equal(this->ReplicateExtent(), other->ReplicateExtent()) &&
equal(this->InputShape(), other->InputShape()) &&
equal(this->ThreadExtent(), other->ThreadExtent()) &&
equal(this->forward_index_, other->forward_index_) &&
equal(this->forward_thread_, other->forward_thread_);
}
bool LayoutNode::IsEqual(const LayoutNode *other, bool skip_index) const {
bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
......@@ -495,10 +656,7 @@ void FragmentNode::RegisterReflection() {
.def_ro("replicate_size", &FragmentNode::replicate_size_);
}
TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(FragmentNode);
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def_packed("tl.Layout",
......@@ -560,12 +718,23 @@ TVM_FFI_STATIC_INIT_BLOCK({
element_size, k_inner);
}
})
.def("tl.make_volta_swizzled_layout",
[](int stride, int mat_continuous, bool is_a, bool k_inner) {
return makeGemmVoltaABLayout(stride, mat_continuous, is_a,
k_inner);
})
.def("tl.make_wgmma_swizzled_layout",
[](int stride, int mat_continuous, int continuity, int element_size,
bool k_inner) {
return makeGemmABLayoutHopper(stride, mat_continuous, continuity,
element_size, k_inner);
})
.def("tl.make_tcgen05mma_swizzled_layout",
[](int stride, int mat_continuous, int continuity, int element_size,
bool k_inner) {
return makeGemmABLayoutSm100(stride, mat_continuous, continuity,
element_size, k_inner);
})
.def("tl.make_full_bank_swizzled_layout",
[](int stride, int continuous, int element_size) {
return makeFullBankSwizzleLayout(stride, continuous, element_size);
......@@ -582,13 +751,13 @@ TVM_FFI_STATIC_INIT_BLOCK({
.def("tl.make_linear_layout", [](int stride, int continuous) {
return makeGemmLayoutLinear(stride, continuous);
});
});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
LayoutNode::RegisterReflection();
FragmentNode::RegisterReflection();
});
}
} // namespace tl
} // namespace tvm
......@@ -8,8 +8,11 @@
#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/object.h>
#include <utility>
#include "../support/ffi_aliases.h"
namespace tvm {
namespace tl {
......@@ -38,17 +41,20 @@ public:
virtual Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const;
virtual Layout Inverse() const;
virtual Layout Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const;
virtual std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const;
virtual std::string DebugOutput() const;
virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const;
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr const char *_type_key = "tl.Layout";
bool SEqualReduce(const LayoutNode *other, SEqualReducer equal) const;
static void RegisterReflection();
TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object);
TVM_FFI_DECLARE_OBJECT_INFO("tl.Layout", LayoutNode, Object);
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
protected:
virtual Map<Var, Range> getVarMap() const;
......@@ -65,7 +71,7 @@ public:
TVM_DLL Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index);
TVM_DLL Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);
TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Layout, ObjectRef, LayoutNode);
};
class FragmentNode : public LayoutNode {
......@@ -79,6 +85,9 @@ public:
Array<PrimExpr> GetForwardVars() const final;
Layout Inverse() const final;
Layout Reshape(const Array<PrimExpr> &shape, arith::Analyzer *analyzer) const;
std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const final;
PrimExpr ThreadExtent() const;
......@@ -109,9 +118,9 @@ public:
static void RegisterReflection();
bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const;
static constexpr const char *_type_key = "tl.Fragment";
TVM_DECLARE_FINAL_OBJECT_INFO(FragmentNode, LayoutNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode);
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
protected:
Map<Var, Range> getVarMap() const final;
......@@ -132,7 +141,7 @@ public:
PrimExpr forward_thread, PrimExpr replicate_size,
Optional<Var> replicate_var);
TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fragment, Layout, FragmentNode);
};
Var InputPlaceholder(size_t idx);
......
......@@ -6,6 +6,7 @@
#include "swizzle.h"
#include <tvm/node/node.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
......@@ -86,14 +87,16 @@ SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var,
forward_index =
forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
auto n = tvm::ffi::make_object<SwizzledLayoutNode>(input_size, forward_index,
pattern);
data_ = std::move(n);
}
SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index,
SwizzlePattern pattern) {
auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
auto n = tvm::ffi::make_object<SwizzledLayoutNode>(input_size, forward_index,
pattern);
data_ = std::move(n);
}
......@@ -102,14 +105,5 @@ void SwizzledLayoutNode::RegisterReflection() {
refl::ObjectDef<SwizzledLayoutNode>();
}
bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode *other,
                                      SEqualReducer equal) const {
  // Structural equality: input shape first, then the forward index
  // expressions, and finally the swizzle pattern. Each comparison
  // short-circuits, preserving the reducer's evaluation order.
  if (!equal(this->InputShape(), other->InputShape())) {
    return false;
  }
  if (!equal(this->forward_index_, other->forward_index_)) {
    return false;
  }
  return pattern_ == other->pattern_;
}
TVM_REGISTER_NODE_TYPE(SwizzledLayoutNode);
} // namespace tl
} // namespace tvm
\ No newline at end of file
} // namespace tvm
......@@ -44,10 +44,9 @@ public:
Layout Inverse() const final;
std::string DebugOutput() const final;
bool IsEqual(const SwizzledLayoutNode *other, bool skip_index = false) const;
static constexpr const char *_type_key = "tl.SwizzledLayout";
bool SEqualReduce(const SwizzledLayoutNode *other, SEqualReducer equal) const;
static void RegisterReflection();
TVM_DECLARE_FINAL_OBJECT_INFO(SwizzledLayoutNode, LayoutNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.SwizzledLayout", SwizzledLayoutNode,
LayoutNode);
private:
SwizzlePattern pattern_;
......@@ -62,11 +61,11 @@ public:
Array<PrimExpr> forward_index, SwizzlePattern pattern);
TVM_DLL SwizzledLayout(Array<PrimExpr> input_size,
Array<PrimExpr> forward_index, SwizzlePattern pattern);
TVM_DEFINE_OBJECT_REF_METHODS(SwizzledLayout, Layout, SwizzledLayoutNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SwizzledLayout, Layout,
SwizzledLayoutNode);
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_LAYOUT_SWIZZLE_H_
\ No newline at end of file
#endif // TVM_TL_LAYOUT_SWIZZLE_H_
......@@ -115,6 +115,10 @@ Array<IterSplitExpr> get_unused_iters(const IterMark &mark,
return results;
}
// Heuristic: detect per-iterator gaps ("unused" pieces) even when the iterator
// appears in fused forms across multiple index expressions. We first normalize
// every index into IterSumExpr, collect all splits per source Var, then
// consolidate them to avoid misclassifying a used split as unused.
Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
const Array<IterVar> input_iters,
Analyzer *analyzer) {
......@@ -134,17 +138,25 @@ Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
}
for (const IterVar &iter : input_iters) {
IterMark iv_mark;
// Merge splits from all IterMark that share the same source Var as `iter`.
std::vector<IterSplitExpr> merged_splits;
for (const IterMark &mark : collector.visited_) {
if (mark->source.as<Var>()->same_as(iter->var)) { // NOLINT(*)
iv_mark = mark;
break;
auto vexpr = mark->source.as<Var>();
if (vexpr && vexpr.value().same_as(iter->var)) {
auto it = collector.mark2splits_.find(mark);
if (it != collector.mark2splits_.end()) {
const auto &vec = it->second;
merged_splits.insert(merged_splits.end(), vec.begin(), vec.end());
}
}
}
if (iv_mark.defined()) {
auto splits =
get_unused_iters(iv_mark, collector.mark2splits_[iv_mark], analyzer);
// Put the small axis last
if (!merged_splits.empty()) {
// Use a unified mark (Var + full extent) to compute the missing pieces
// so that fused usages are honored as "used" and not reintroduced.
IterMark unified_mark(iter->var, iter->dom->extent);
auto splits = get_unused_iters(unified_mark, merged_splits, analyzer);
// Put the small axis last for a flattened ordering.
results.insert(results.end(), splits.rbegin(), splits.rend());
} else if (!is_one(iter->dom->extent)) {
auto mark = IterMark(iter->var, iter->dom->extent);
......@@ -189,7 +201,7 @@ public:
IterMark Mutate(const IterMark &mark) {
if (auto *op = mark->source.as<IterSumExprNode>()) {
return IterMark(Mutate(GetRef<IterSumExpr>(op)), mark->extent);
return IterMark(Mutate(tvm::ffi::GetRef<IterSumExpr>(op)), mark->extent);
} else {
return mark;
}
......
......@@ -9,6 +9,8 @@
#include <tvm/arith/iter_affine_map.h>
#include "../support/ffi_aliases.h"
namespace tvm {
namespace tl {
......
......@@ -42,7 +42,7 @@ using namespace tir;
* - The constructed node is stored in this->data_.
*/
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
ObjectPtr<AtomicAddNode> node = tvm::ffi::make_object<AtomicAddNode>();
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
......@@ -78,7 +78,7 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator owning the cloned AtomicAddNode.
*/
TileOperator AtomicAddNode::Clone() const {
auto op = make_object<AtomicAddNode>(*this);
auto op = tvm::ffi::make_object<AtomicAddNode>(*this);
if (par_op_.defined()) {
op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
}
......@@ -549,7 +549,7 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ AtomicAddNode::RegisterReflection(); });
TVM_FFI_STATIC_INIT_BLOCK() { AtomicAddNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
\ No newline at end of file
......@@ -25,8 +25,8 @@ public:
IntImm memory_order; ///< Memory order for atomic operations
mutable ParallelOp par_op_; ///< Associated parallel operation
static constexpr const char *_type_key = "tl.AtomicAdd";
TVM_DECLARE_FINAL_OBJECT_INFO(AtomicAddNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.AtomicAdd", AtomicAddNode,
TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
......@@ -46,28 +46,6 @@ public:
.def_ro("memory_order", &AtomicAddNode::memory_order);
}
bool SEqualReduce(const AtomicAddNode *other, SEqualReducer equal) const {
  // Field-by-field structural comparison; the first mismatch
  // short-circuits to false, matching the original && chain order.
  if (!equal(src, other->src)) {
    return false;
  }
  if (!equal(dst, other->dst)) {
    return false;
  }
  if (!equal(src_range, other->src_range)) {
    return false;
  }
  if (!equal(dst_range, other->dst_range)) {
    return false;
  }
  if (!equal(use_tma, other->use_tma)) {
    return false;
  }
  if (!equal(coalesced_width, other->coalesced_width)) {
    return false;
  }
  return equal(memory_order, other->memory_order);
}
// Fold every field compared by SEqualReduce into the structural hash,
// in the same field order, so structurally equal nodes hash equally.
// NOTE: the fold order is part of the hash value — do not reorder.
void SHashReduce(SHashReducer hash_reduce) const {
  hash_reduce(src);
  hash_reduce(dst);
  hash_reduce(src_range);
  hash_reduce(dst_range);
  hash_reduce(use_tma);
  hash_reduce(coalesced_width);
  hash_reduce(memory_order);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
protected:
/// Create SIMT-style parallel loop structure
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
......@@ -85,7 +63,8 @@ protected:
/// Wrapper class for atomic addition operations
class AtomicAdd : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(AtomicAdd, TileOperator, AtomicAddNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicAdd, TileOperator,
AtomicAddNode);
TVM_DLL AtomicAdd(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
......@@ -93,4 +72,4 @@ public:
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_ATOMIC_ADD_H_
\ No newline at end of file
#endif // TVM_TL_OP_ATOMIC_ADD_H_
......@@ -155,6 +155,16 @@ TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
// Register the tcgen05 MMA intrinsics (presumably PTX tcgen05.mma for
// recent NVIDIA architectures — TODO confirm against codegen). Both are
// marked kOpaque: they have side effects the optimizer must not reorder
// or eliminate.
// _ss variant: 14 inputs.
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
    .set_num_inputs(14)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

// _ts variant: 13 inputs.
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ts)
    .set_num_inputs(13)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_init_tensor_memory)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
......@@ -165,6 +175,11 @@ TIR_DEFINE_TL_BUILTIN(ptx_deallocate_tensor_memory)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_mma_sm70)
.set_num_inputs(13)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_ldmatrix)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
......@@ -219,6 +234,11 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warpgroup_fence_operand)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(get_lane_idx)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
......@@ -286,11 +306,16 @@ TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(initialize_descriptor)
TIR_DEFINE_TL_BUILTIN(initialize_wgmma_descriptor)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(initialize_tcgen05_descriptor)
.set_num_inputs(7)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
......@@ -301,5 +326,20 @@ TIR_DEFINE_TL_BUILTIN(atomicadd_elem_op)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
// Device-side assertion builtin: 1 input (the condition). kOpaque so the
// assertion is never optimized away.
TIR_DEFINE_TL_BUILTIN(device_assert)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

// Variant carrying a message: 2 inputs (presumably condition + message —
// TODO confirm argument order against the lowering pass).
TIR_DEFINE_TL_BUILTIN(device_assert_with_msg)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

// tcgen05 MMA arrive/signal builtin: 1 input. Opaque because it performs
// a synchronization side effect.
TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment