"src/vscode:/vscode.git/clone" did not exist on "82b60de915f91a939f0f4e11e39f6e0279ce2dce"
Unverified Commit bbbf4207 authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
-set -eux
-# Get the CUDA version from the command line
-IMAGE="tilelang-builder:manylinux"
-docker build . -f "$(dirname "${BASH_SOURCE[0]}")/pypi.manylinux.Dockerfile" --tag ${IMAGE}
-script="sh maint/scripts/local_distribution.sh"
-docker run --rm -v $(pwd):/tilelang ${IMAGE} /bin/bash -c "$script"
+#!/usr/bin/env bash
+set -euxo pipefail
+# Build for local architecture
+CIBW_BUILD='cp38-*' cibuildwheel .
-set -eux
-# Get the CUDA version from the command line
-IMAGE="tilelang-builder:manylinux"
-docker build . -f "$(dirname "${BASH_SOURCE[0]}")/pypi.manylinux.Dockerfile" --tag ${IMAGE}
-script="sh maint/scripts/pypi_distribution.sh"
-docker run --rm -v $(pwd):/tilelang -w /tilelang ${IMAGE} /bin/bash -c "$script"
+#!/usr/bin/env bash
+set -euxo pipefail
+if docker buildx version >/dev/null 2>&1; then
+  if docker info >/dev/null 2>&1; then
+    docker run --rm --privileged tonistiigi/binfmt --install amd64,arm64 >/dev/null 2>&1 || true
+  fi
+  if ! docker buildx inspect multi >/dev/null 2>&1; then
+    docker buildx create --name multi --driver docker-container --use >/dev/null 2>&1 || true
+  else
+    docker buildx use multi >/dev/null 2>&1 || true
+  fi
+  docker buildx inspect --bootstrap >/dev/null 2>&1 || true
+  done
+  export CIBW_ARCHS='x86_64 aarch64'
+fi
+NO_VERSION_LABEL=ON CIBW_BUILD='cp38-*' cibuildwheel .
-FROM pytorch/manylinux2_28-builder:cuda12.1 AS builder_amd64
-ENV CUDA_VERSION=12.1 \
-    AUDITWHEEL_PLAT=manylinux_2_28_x86_64
-RUN pip3 install uv
-FROM pytorch/manylinuxaarch64-builder:cuda12.8 AS builder_arm64
-ENV CUDA_VERSION=12.8 \
-    AUDITWHEEL_PLAT=manylinux_2_28_aarch64
+FROM quay.io/pypa/manylinux2014_x86_64 AS builder_amd64
+RUN yum-config-manager --add-repo https://developer.download.nvidia.cn/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo
+ARG CUDA_VERSION=12.1
+ENV CUDA_VERSION=${CUDA_VERSION}
+FROM quay.io/pypa/manylinux_2_28_aarch64 AS builder_arm64
+RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
+ARG CUDA_VERSION=12.8
+ENV CUDA_VERSION=${CUDA_VERSION}
ARG TARGETARCH
FROM builder_${TARGETARCH}
ENV DEBIAN_FRONTEND=noninteractive \
    TZ=Etc/UTC
-RUN set -eux; \
-    uv venv -p 3.12 --seed /venv; \
-    git config --global --add safe.directory '/tilelang'
-ENV PATH="/venv/bin:$PATH" \
-    VIRTUAL_ENV=/venv
-RUN uv pip install build wheel
+ENV PATH="/usr/local/cuda/bin:${PATH}"
+ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
+RUN set -eux; \
+    pipx install cibuildwheel; \
+    git config --global --add safe.directory '/tilelang'
WORKDIR /tilelang
set -eux
-rm -rf dist
+rm -rf dist raw_dist
python -mpip install -U pip
python -mpip install -U build wheel auditwheel patchelf
......
@@ -2,27 +2,32 @@
name = "tilelang"
description = "A tile level programming language to generate high performance code."
readme = "README.md"
-requires-python = ">=3.8"
+requires-python = ">=3.9"
authors = [{ name = "TileLang Contributors" }, { name = "Tile-AI" }]
maintainers = [{ name = "Lei Wang", email = "leiwang1999@outlook.com" }]
license = "MIT"
keywords = ["BLAS", "CUDA", "HIP", "Code Generation", "TVM"]
classifiers = [
+    "Development Status :: 4 - Beta",
    "Environment :: GPU",
    "Operating System :: POSIX :: Linux",
+    "Operating System :: OS Independent",
    "Operating System :: MacOS",
-    "Programming Language :: Python :: 3.8",
+    "Programming Language :: C++",
+    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+    "Programming Language :: Python :: 3.14",
+    "Programming Language :: Python :: Implementation :: CPython",
    "Intended Audience :: Developers",
    "Intended Audience :: Science/Research",
-    "Scientific/Engineering :: Artificial Intelligence",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dynamic = ["version"]
dependencies = [
+    "apache-tvm-ffi==0.1.0",
    "cloudpickle",
    "ml-dtypes",
    "numpy>=1.23.5",
@@ -39,11 +44,7 @@ dependencies = [
fp4 = ["ml-dtypes>=0.5.1"]

[build-system]
-requires = [
-    "cython>=3.0.0",
-    "scikit-build-core",
-    "setuptools>=63",
-]
+requires = ["cython>=3.0.0", "scikit-build-core"]
build-backend = "scikit_build_core.build"

[tool.scikit-build]
@@ -58,11 +59,14 @@ metadata.version.provider = "version_provider"
metadata.version.provider-path = "."
experimental = true
+# build.verbose = true
+# logging.level = "DEBUG"

[tool.scikit-build.sdist]
+# See MANIFEST.in for details
include = [
-    "VERSION",
-    "LICENSE",
+    "./VERSION",
+    ".git_commit.txt",
+    "./LICENSE",
    "THIRDPARTYNOTICES.txt",
    "version_provider.py",
    "requirements*.txt",
@@ -70,7 +74,15 @@ include = [
    "CMakeLists.txt",
    "src/**",
    "cmake/**",
-    "3rdparty/**",
+    # The vendored 3rdparty contents in sdist should be same as wheel.
+    # Need full TVM to build from source.
+    "3rdparty/tvm",
+    # CUTLASS
+    "3rdparty/cutlass/include",
+    "3rdparty/cutlass/tools",
+    # Composable Kernel
+    "3rdparty/composable_kernel/include",
+    "3rdparty/composable_kernel/library",
    "testing/**",
    "examples/**",
]
@@ -79,8 +91,7 @@ exclude = [
    ".github",
    "**/.git",
    "**/.github",
-    "3rdparty/clang**",
-    "3rdparty/llvm**",
+    "3rdparty/**",
    "build",
]
...@@ -89,7 +100,17 @@ tilelang = "tilelang" ...@@ -89,7 +100,17 @@ tilelang = "tilelang"
"tilelang/src" = "src" "tilelang/src" = "src"
# NOTE: The mapping below places the contents of '3rdparty' inside 'tilelang/3rdparty' in the wheel. # NOTE: The mapping below places the contents of '3rdparty' inside 'tilelang/3rdparty' in the wheel.
# This is necessary to find TVM shared libraries at runtime. # This is necessary to find TVM shared libraries at runtime.
"tilelang/3rdparty" = "3rdparty" # The vendored 3rdparty contents in wheel should be same as sdist.
# TVM
"tilelang/3rdparty/tvm/src" = "3rdparty/tvm/src"
"tilelang/3rdparty/tvm/python" = "3rdparty/tvm/python"
"tilelang/3rdparty/tvm/version.py" = "3rdparty/tvm/version.py"
# CUTLASS
"tilelang/3rdparty/cutlass/include" = "3rdparty/cutlass/include"
"tilelang/3rdparty/cutlass/tools" = "3rdparty/cutlass/tools"
# Composable Kernel
"tilelang/3rdparty/composable_kernel/include" = "3rdparty/composable_kernel/include"
"tilelang/3rdparty/composable_kernel/library" = "3rdparty/composable_kernel/library"
[tool.yapf]
based_on_style = "yapf"
@@ -106,7 +127,7 @@ skip = [
]

[tool.ruff]
-target-version = "py38"
+target-version = "py39"
line-length = 100
output-format = "full"
@@ -170,32 +191,45 @@ build-frontend = "build"
environment = { PYTHONDEVMODE = "1", PYTHONUNBUFFERED = "1" }
environment-pass = [
    "CUDA_VERSION",
+    "NO_VERSION_LABEL",
+    "NO_TOOLCHAIN_VERSION",
+    "NO_GIT_VERSION",
    "COLUMNS",
+    "CMAKE_GENERATOR",
+    "CMAKE_BUILD_PARALLEL_LEVEL",
    "FORCE_COLOR",
    "CLICOLOR_FORCE",
]
before-build = "env -0 | sort -z | tr '\\0' '\\n'"
windows.before-build = "set"
-# Pin to glibc 2.17 for x86 and 2.28 for aarch64 for now
-manylinux-x86_64-image = "manylinux2014"
-manylinux-aarch64-image = "manylinux_2_28"
+test-command = [
+    "python -c 'import tilelang; print(tilelang.__version__)'",
+]

[tool.cibuildwheel.linux]
-environment = { PYTHONDEVMODE = "1", PYTHONUNBUFFERED = "1", PATH = "/usr/local/cuda/bin:$PATH" }
-repair-wheel-command = [
-    "auditwheel repair --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}",
-    "pipx run abi3audit --strict --report {wheel}",
-]
+environment.PYTHONDEVMODE = "1"
+environment.PYTHONUNBUFFERED = "1"
+environment.PATH = "/usr/local/cuda/bin:$PATH"
+environment.LD_LIBRARY_PATH = "/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
+# Pin to glibc 2.17 for x86 and 2.28 for aarch64 for now
+# TODO: upgrade to manylinux_2_28 at some time
+manylinux-x86_64-image = "manylinux2014" # CentOS 7
+manylinux-aarch64-image = "manylinux_2_28" # AlmaLinux 8
# Install CUDA runtime and stub driver library
# manylinux_2_28 uses gcc 14, which needs CUDA 12.8
before-all = """
set -eux
+cat /etc/*-release
+uname -a
case "$(uname -m)" in
"x86_64")
+DEFAULT_CUDA_VERSION="12.1"
yum-config-manager --add-repo https://developer.download.nvidia.cn/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo
;;
"aarch64")
+DEFAULT_CUDA_VERSION="12.8"
dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
;;
*)
@@ -203,7 +237,24 @@ case "$(uname -m)" in
;;
esac
-cudaver="$(echo "${CUDA_VERSION:-"12.4"}" | cut -d '.' -f-2)"
+cudaver="$(echo "${CUDA_VERSION:-$DEFAULT_CUDA_VERSION}" | cut -d '.' -f-2)"
v="${cudaver//./-}"
-yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-devel-${v}"
+yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-devel-${v}" nvidia-driver-cuda-libs
"""
repair-wheel-command = [
"auditwheel -v repair --exclude libtvm_ffi.so --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}",
"pipx run abi3audit --verbose --strict {wheel}",
]
[tool.cibuildwheel.macos]
repair-wheel-command = [
"delocate-wheel --verbose --ignore-missing-dependencies --no-sanitize-rpaths --require-archs {delocate_archs} -w {dest_dir} -v {wheel}",
"pipx run abi3audit --verbose --strict {wheel}",
]
[[tool.cibuildwheel.overrides]]
select = "*linux*x86_64*"
# CentOS 7 is too old to run import test. Do wheel installation test only.
test-command = [
"echo 'Wheel is installed successfully'",
]
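The before-all script above picks a per-architecture default CUDA version, keeps only the major.minor part of CUDA_VERSION, and turns the dot into a dash to form the yum package suffix. A small Python sketch of the same string handling, for illustration only (the real build uses the shell pipeline above):

from typing import Optional

def cuda_pkg_suffix(cuda_version: Optional[str], arch: str) -> str:
    # Mirrors the shell logic: DEFAULT_CUDA_VERSION per arch, `cut -d '.' -f-2`,
    # then `${cudaver//./-}`.
    defaults = {"x86_64": "12.1", "aarch64": "12.8"}
    version = cuda_version or defaults[arch]
    major_minor = ".".join(version.split(".")[:2])
    return major_minor.replace(".", "-")

# cuda_pkg_suffix("12.8.1", "x86_64") == "12-8", so the script installs
# "cuda-minimal-build-12-8", "cuda-driver-devel-12-8", and so on.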
# Requirements to run local build with `--no-build-isolation` or other developments
+apache-tvm-ffi~=0.1.0
build
cmake>=3.26
cython>=3.0.0
......
@@ -3,5 +3,5 @@ pre-commit
clang-format==21.1.2
clang-tidy==21.1.1
codespell[toml]==2.4.1
-ruff==0.14.1
+ruff==0.14.3
yapf==0.43.0
@@ -18,10 +18,11 @@ cython
docutils
dtlib
einops
+flash-linear-attention==0.3.2
packaging>=21.0
-pytest-xdist>=2.2.1
pytest-durations
pytest-timeout
+pytest-xdist>=2.2.1
pytest>=6.2.4
pyyaml
requests
......
# Runtime requirements
+apache-tvm-ffi~=0.1.0
cloudpickle
ml-dtypes
numpy>=1.23.5
@@ -7,4 +8,3 @@ torch
torch>=2.7; platform_system == 'Darwin'
tqdm>=4.62.3
typing-extensions>=4.10.0
-flash-linear-attention==0.3.2
\ No newline at end of file
@@ -7,6 +7,9 @@
#include "./transform/common/attr.h"
#include "op/builtin.h"
#include "tvm/ffi/any.h"
+#include <tvm/ffi/object.h>
+#include "support/ffi_aliases.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/script/ir_builder/tir/ir.h>
@@ -37,7 +40,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
  using namespace tvm::tir;
  Var var = Var(name, dom->dtype);
  // Create a frame that represents a loop over the given domain.
-  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
+  ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
  n->vars.push_back(var);
  n->doms.push_back(Range(0, dom));
  n->f_make_for_loop = [](const Array<Var> &vars, const Array<Range> &doms,
@@ -52,7 +55,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
ForFrame ParallelFor(const Array<PrimExpr> &extents,
                     const Map<String, ObjectRef> &annotations) {
  using namespace tvm::tir;
-  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
+  ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
  n->vars.reserve(extents.size());
  n->doms.reserve(extents.size());
  for (const auto &extent : extents) {
@@ -82,7 +85,7 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
                      const Array<Array<PrimExpr>> &sync,
                      const Array<Array<PrimExpr>> &groups) {
  using namespace tvm::tir;
-  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
+  ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
  DataType dtype = stop.dtype();
  n->vars.push_back(Var("v", dtype));
  n->doms.push_back(Range(std::move(start), stop));
@@ -113,7 +116,7 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
                       const PrimExpr &index, PrimExpr group_size) {
  using namespace tvm::tir;
  ICHECK(!domain.empty());
-  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
+  ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
  n->vars.reserve(domain.size());
  n->doms.reserve(domain.size());
  PrimExpr domain_size = domain[0];
@@ -193,8 +196,8 @@ public:
        "frames", &KernelLaunchFrameNode::frames);
  }

-  static constexpr const char *_type_key = "tl.KernelLaunchFrame";
-  TVM_DECLARE_FINAL_OBJECT_INFO(KernelLaunchFrameNode, TIRFrameNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.KernelLaunchFrame",
+                                    KernelLaunchFrameNode, TIRFrameNode);

public:
  TVM_DLL void EnterWithScope() final {
@@ -218,14 +221,20 @@
 */
class KernelLaunchFrame : public TIRFrame {
public:
-  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(KernelLaunchFrame, TIRFrame,
-                                                    KernelLaunchFrameNode);
+  explicit KernelLaunchFrame(ObjectPtr<KernelLaunchFrameNode> data)
+      : TIRFrame(::tvm::ffi::UnsafeInit{}) {
+    ICHECK(data != nullptr);
+    data_ = std::move(data);
+  }
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(KernelLaunchFrame, TIRFrame,
+                                                KernelLaunchFrameNode);
};

KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
                               const Optional<Array<PrimExpr>> &block_size_opt,
                               const Map<String, ffi::Any> &attrs) {
-  ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>();
+  ObjectPtr<KernelLaunchFrameNode> n =
+      tvm::ffi::make_object<KernelLaunchFrameNode>();

  // If the kernel is a CPU kernel, we don't need to launch any threads.
  bool is_cpu_kernel_frame =
@@ -289,16 +298,14 @@ KernelLaunchFrame KernelLaunch(const Array<PrimExpr> &grid_size,
  return KernelLaunchFrame(n);
}

-TVM_REGISTER_NODE_TYPE(KernelLaunchFrameNode);
-
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("tl.Parallel", ParallelFor)
      .def("tl.Pipelined", PipelinedFor)
      .def("tl.Persistent", PersistentFor)
      .def("tl.KernelLaunch", KernelLaunch);
-});
+}

class WarpSpecializeFrameNode : public TIRFrameNode {
public:
@@ -310,8 +317,8 @@ public:
        "frames", &WarpSpecializeFrameNode::frames);
  }

-  static constexpr const char *_type_key = "tl.WarpSpecializeFrame";
-  TVM_DECLARE_FINAL_OBJECT_INFO(WarpSpecializeFrameNode, TIRFrameNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.WarpSpecializeFrame",
+                                    WarpSpecializeFrameNode, TIRFrameNode);

public:
  TVM_DLL void EnterWithScope() final {
@@ -330,15 +337,20 @@ public:
class WarpSpecializeFrame : public TIRFrame {
public:
-  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WarpSpecializeFrame,
-                                                    TIRFrame,
-                                                    WarpSpecializeFrameNode);
+  explicit WarpSpecializeFrame(ObjectPtr<WarpSpecializeFrameNode> data)
+      : TIRFrame(::tvm::ffi::UnsafeInit{}) {
+    ICHECK(data != nullptr);
+    data_ = std::move(data);
+  }
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(WarpSpecializeFrame, TIRFrame,
+                                                WarpSpecializeFrameNode);
};

WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
                                   const PrimExpr &thread_idx,
                                   int warp_group_size = 128) {
-  ObjectPtr<WarpSpecializeFrameNode> n = make_object<WarpSpecializeFrameNode>();
+  ObjectPtr<WarpSpecializeFrameNode> n =
+      tvm::ffi::make_object<WarpSpecializeFrameNode>();
  PrimExpr condition;
  std::vector<int> warp_groups;
  warp_groups.reserve(warp_group_ids.size());
@@ -376,13 +388,12 @@ WarpSpecializeFrame WarpSpecialize(const Array<IntImm> &warp_group_ids,
  return WarpSpecializeFrame(n);
}

-TVM_REGISTER_NODE_TYPE(WarpSpecializeFrameNode);
-
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.WarpSpecialize", WarpSpecialize);
  KernelLaunchFrameNode::RegisterReflection();
  WarpSpecializeFrameNode::RegisterReflection();
-});
+}

} // namespace tl
} // namespace tvm
@@ -594,11 +594,11 @@ Layout MakeGemmVoltaBLayoutCongruous(int stride, int continuous) {
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
                             bool k_inner) {
-  if (k_inner)
+  if (k_inner && continuous % 32 == 0 && stride % 32 == 0)
    return MakeGemmVoltaABLayoutCrosswise(stride, continuous);
-  if (is_a && continuous % 64 == 0)
+  if (is_a && continuous % 64 == 0 && stride % 4 == 0)
    return MakeGemmVoltaALayoutCongruous(stride, continuous);
-  if (!is_a && continuous % 64 == 0)
+  if (!is_a && continuous % 64 == 0 && stride % 4 == 0)
    return MakeGemmVoltaBLayoutCongruous(stride, continuous);
  return makeGemmABLayoutPadded(stride, continuous, 16);
}
......
@@ -5,6 +5,7 @@
#include "layout.h"

#include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/logging.h>
#include <tvm/arith/pattern.h>
#include <tvm/tir/op.h>
@@ -64,13 +65,12 @@ Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
  }
  forward_index =
      forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
-  auto n = make_object<LayoutNode>(input_size, forward_index);
+  auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
  data_ = std::move(n);
}

Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
-  auto n = make_object<LayoutNode>(input_size, forward_index);
+  auto n = tvm::ffi::make_object<LayoutNode>(input_size, forward_index);
  data_ = std::move(n);
}
@@ -102,10 +102,24 @@ Array<PrimExpr> LayoutNode::OutputShape() const {
  for (size_t i = 0; i < ret.size(); i++) {
    auto ist = analyzer.int_set(forward_index_[i] + 1);
    if (arith::is_neg_inf(ist.min()) && arith::is_pos_inf(ist.max())) {
-      // X-OR Expression
-      ret.Set(i, input_size_[i]);
-    } else {
-      // CHECK(is_one(ist.min())) << ist.min();
+      // Analyzer couldn't form an IntervalSet (e.g. bitwise ops).
+      // Fall back to ConstIntBound to derive a safe extent.
+      auto cib = analyzer.const_int_bound(forward_index_[i]);
+      if (cib->min_value != arith::ConstIntBound::kNegInf &&
+          cib->max_value != arith::ConstIntBound::kPosInf &&
+          cib->min_value >= 0) {
+        // extent = max - min + 1, using 64-bit integer literal
+        ret.Set(i, Integer(cib->max_value - cib->min_value + 1));
+      } else {
+        // Last-resort conservative fallback to avoid OOB/crash
+        // Prefer to keep dimension from known input_size_ if available.
+        if (i < input_size_.size()) {
+          ret.Set(i, input_size_[i]);
+        } else {
+          ret.Set(i, Integer(1));
+        }
+      }
+    } else {
      ret.Set(i, ist.max());
    }
  }
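The ConstIntBound fallback above only needs a conservative extent when interval analysis gives up (typical for XOR-style swizzle indices): with a known non-negative constant bound, the extent is max - min + 1. A tiny Python illustration of that arithmetic with assumed operand ranges rather than TVM's analyzer:

def conservative_extent(min_value: int, max_value: int) -> int:
    # Mirrors ret.Set(i, Integer(cib->max_value - cib->min_value + 1)) above.
    assert 0 <= min_value <= max_value
    return max_value - min_value + 1

# Example: a swizzled index such as (i ^ j) with i in [0, 7] and j in [0, 3]
# stays in [0, 7], so the derived output extent is 8.
print(conservative_extent(0, 7))  # 8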
@@ -130,7 +144,6 @@ Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr> &vars) const {
  Array<PrimExpr> transformed = forward_index_.Map(
      [&](const PrimExpr &e) { return Substitute(e, vmap); });
-  // Concatenate with the remaining elements from vars
  Array<PrimExpr> result;
  for (size_t i = 0; i < vars.size() - InputDim(); i++) {
@@ -212,7 +225,7 @@ Fragment FragmentNode::DeReplicate() const {
    factor = arith::ZeroAwareGCD(*rep_size, *idx_size);
  }
  if (factor == 1)
-    return GetRef<Fragment>(this);
+    return tvm::ffi::GetRef<Fragment>(this);

  Map<Var, PrimExpr> vmap;
  vmap.Set(ReplicationPlaceholder(), ReplicationPlaceholder() * factor +
@@ -224,7 +237,7 @@ Fragment FragmentNode::DeReplicate() const {
}

Fragment FragmentNode::BindThreadRange(Range thread_range) const {
-  auto n = make_object<FragmentNode>(*this);
+  auto n = tvm::ffi::make_object<FragmentNode>(*this);
  n->thread_range_ = thread_range;
  return Fragment(n);
}
@@ -251,14 +264,17 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
  if (!is_static_shape) {
    // Runtime guards keep dynamic tails safe, so we allow NoCheck here and
    // warn.
-    LOG(WARNING) << "Layout::Inverse on symbolic layout, falling back to "
+    DLOG(WARNING) << "Layout::Inverse on symbolic layout, falling back to "
                     "NoCheck; symbolic dims: "
                  << symbolic_dims;
  }
  arith::IterMapResult res =
      arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer);
-  ICHECK(res->errors.empty())
-      << "Layout " << DebugOutput() << " has errors: " << res->errors;
+  if (!res->errors.empty()) {
+    std::ostringstream msg;
+    msg << "Layout " << DebugOutput() << " has errors: " << res->errors;
+    throw NormalizeIterException(msg.str());
+  }
  auto outputs_shape = OutputShape();
  Array<PrimExpr> outputs;
@@ -280,10 +296,170 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
  return {Layout(outputs_shape, backward_index), level};
}
Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const {
// Fast path: if shape is the same, return the original layout
if (StructuralEqual()(InputShape(), shape)) {
return ffi::GetRef<Layout>(this);
}
// Step 1. Prove the product of InputShape is equal to the product of shape
PrimExpr input_shape_product = Integer(1);
for (const auto &dim : InputShape()) {
input_shape_product *= dim;
}
PrimExpr shape_product = Integer(1);
for (const auto &dim : shape) {
shape_product *= dim;
}
// Use provided analyzer if present, otherwise a local fallback to avoid
// potential null dereference paths flagged by static analysis.
arith::Analyzer fallback_analyzer;
arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
ICHECK(az->CanProveEqual(input_shape_product, shape_product))
<< "InputShape() = " << InputShape() << " shape = " << shape;
// Step 2. Create new forward indices by reshaping
// For each dimension in the new shape, we create a placeholder variable
Array<Var> new_vars;
new_vars.reserve(shape.size());
for (size_t i = 0; i < shape.size(); ++i) {
auto var = Var(std::string("n_") + std::to_string(i), shape[i].dtype());
az->Bind(var, Range(0, shape[i]));
new_vars.push_back(var);
}
// Step 3. Compute the flat index from new shape indices
// flat_index = k0 * (s1 * s2 * ...) + k1 * (s2 * s3 * ...) + ... + kn
PrimExpr flat_index = Integer(0);
for (size_t i = 0; i < shape.size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < shape.size(); ++j) {
stride = stride * shape[j];
}
flat_index = flat_index + new_vars[i] * stride;
}
// Step 4. Convert flat index back to original shape indices
// For original shape [s0, s1, ..., sm]:
// i0 = flat_index // (s1 * s2 * ... * sm)
// i1 = (flat_index % (s1 * s2 * ... * sm)) // (s2 * s3 * ... * sm)
// ...
Array<PrimExpr> original_indices;
PrimExpr remaining = flat_index;
for (size_t i = 0; i < InputShape().size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < InputShape().size(); ++j) {
stride = stride * InputShape()[j];
}
original_indices.push_back(floordiv(remaining, stride));
remaining = floormod(remaining, stride);
}
// Step 5. Substitute original indices into forward_index_
Array<PrimExpr> new_forward_index;
for (const auto &fwd_expr : forward_index_) {
PrimExpr substituted = fwd_expr;
// Replace each InputPlaceholder(i) with original_indices[i]
for (size_t i = 0; i < InputShape().size(); ++i) {
substituted =
Substitute(substituted, {{InputPlaceholder(i), original_indices[i]}});
}
new_forward_index.push_back(az->Simplify(substituted));
}
for (size_t i = 0; i < new_vars.size(); ++i) {
new_forward_index =
Substitute(new_forward_index, {{new_vars[i], InputPlaceholder(i)}});
}
return Layout(shape, new_forward_index);
}
Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const {
// Fast path: identical input shape, return self
if (StructuralEqual()(InputShape(), shape)) {
return ffi::GetRef<Fragment>(this);
}
// 1) Prove total number of elements remains the same
PrimExpr input_prod = Integer(1);
for (const auto &d : InputShape())
input_prod *= d;
PrimExpr shape_prod = Integer(1);
for (const auto &d : shape)
shape_prod *= d;
// Use provided analyzer if present, otherwise a local fallback.
arith::Analyzer fallback_analyzer;
arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
ICHECK(az->CanProveEqual(input_prod, shape_prod))
<< "InputShape() = " << InputShape() << " shape = " << shape
<< " input fragment layout is = " << DebugOutput();
// 2) Build flat index from new-shape indices
Array<Var> new_vars;
new_vars.reserve(shape.size());
for (size_t i = 0; i < shape.size(); ++i) {
// Cannot use InputPlaceholder(i) here, because it would cause name capture
// (variable capture) with InputPlaceholder(i) in upper scopes. Therefore,
// we must create a fresh variable here to avoid confusion when
// substituting.
auto var = Var(std::string("n_") + std::to_string(i), shape[i].dtype());
az->Bind(var, Range(0, shape[i]));
new_vars.push_back(var);
}
PrimExpr flat = Integer(0);
for (size_t i = 0; i < shape.size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < shape.size(); ++j)
stride = stride * shape[j];
flat = flat + new_vars[i] * stride;
}
// 3) Recover original indices from flat index
Array<PrimExpr> orig_indices;
PrimExpr remain = flat;
for (size_t i = 0; i < InputShape().size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < InputShape().size(); ++j)
stride = stride * InputShape()[j];
orig_indices.push_back(floordiv(remain, stride));
remain = floormod(remain, stride);
}
// 4) Substitute old placeholders with expressions of new indices
Array<PrimExpr> new_forward_index;
for (const auto &e : forward_index_) {
PrimExpr cur = e;
for (size_t i = 0; i < InputShape().size(); ++i) {
cur = Substitute(cur, {{InputPlaceholder(i), orig_indices[i]}});
}
cur = az->Simplify(cur);
new_forward_index.push_back(cur);
}
PrimExpr new_forward_thread = forward_thread_;
for (size_t i = 0; i < InputShape().size(); ++i) {
new_forward_thread = Substitute(new_forward_thread,
{{InputPlaceholder(i), orig_indices[i]}});
}
new_forward_thread = az->Simplify(new_forward_thread);
for (size_t i = 0; i < new_vars.size(); ++i) {
auto var = new_vars[i];
new_forward_index =
Substitute(new_forward_index, {{var, InputPlaceholder(i)}});
new_forward_thread =
Substitute(new_forward_thread, {{var, InputPlaceholder(i)}});
}
Fragment reshaped(shape, new_forward_index, new_forward_thread,
ReplicateExtent(), std::nullopt);
if (thread_range_.defined()) {
reshaped = reshaped->BindThreadRange(thread_range_);
}
return reshaped;
}
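Both Reshape overloads above use the same index algebra: flatten the new-shape indices with row-major strides, then unflatten the flat index back into the original shape before substituting into forward_index_. A short Python sketch of that remapping on plain integers (no TVM types), just to make the stride arithmetic concrete:

from typing import List

def flatten(indices: List[int], shape: List[int]) -> int:
    # Row-major flatten: sum(index_i * prod(shape[i+1:])), as in Step 3 above.
    flat, stride = 0, 1
    for idx, dim in zip(reversed(indices), reversed(shape)):
        flat += idx * stride
        stride *= dim
    return flat

def unflatten(flat: int, shape: List[int]) -> List[int]:
    # Inverse of flatten via floordiv/floormod on trailing strides (Step 4).
    indices = []
    for i in range(len(shape)):
        stride = 1
        for d in shape[i + 1:]:
            stride *= d
        indices.append(flat // stride)
        flat %= stride
    return indices

# Reshaping a (4, 6) layout to (2, 12): the new index (1, 7) addresses the same
# element as the old index (3, 1), because both flatten to 19.
assert flatten([1, 7], [2, 12]) == 19
assert unflatten(19, [4, 6]) == [3, 1]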
Layout LayoutNode::Inverse() const {
  auto inverse_result = InverseWithLevel();
  return std::move(inverse_result.first);
}

PrimExpr infer_fragment_index(const Map<Var, Range> &input_iters,
                              const PrimExpr &forward_thread,
                              arith::Analyzer *analyzer) {
@@ -336,8 +512,8 @@ Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
      forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
  forward_thread = Substitute(forward_thread, vmap);

-  auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
-                                     replicate_size);
+  auto n = tvm::ffi::make_object<FragmentNode>(input_size, forward_index,
+                                               forward_thread, replicate_size);
  data_ = std::move(n);
}
@@ -348,8 +524,8 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
    forward_thread = Substitute(
        forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}});
  }
-  auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
-                                     replicate_size);
+  auto n = tvm::ffi::make_object<FragmentNode>(input_size, forward_index,
+                                               forward_thread, replicate_size);
  data_ = std::move(n);
}
@@ -442,21 +618,6 @@ std::string FragmentNode::DebugOutput() const {
  return ss.str();
}
bool LayoutNode::SEqualReduce(const LayoutNode *other,
SEqualReducer equal) const {
return equal(this->InputShape(), other->InputShape()) &&
equal(this->forward_index_, other->forward_index_);
}
bool FragmentNode::SEqualReduce(const FragmentNode *other,
SEqualReducer equal) const {
return equal(this->ReplicateExtent(), other->ReplicateExtent()) &&
equal(this->InputShape(), other->InputShape()) &&
equal(this->ThreadExtent(), other->ThreadExtent()) &&
equal(this->forward_index_, other->forward_index_) &&
equal(this->forward_thread_, other->forward_thread_);
}
bool LayoutNode::IsEqual(const LayoutNode *other, bool skip_index) const {
  bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
  ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
@@ -495,10 +656,7 @@ void FragmentNode::RegisterReflection() {
      .def_ro("replicate_size", &FragmentNode::replicate_size_);
}
-TVM_REGISTER_NODE_TYPE(LayoutNode);
-TVM_REGISTER_NODE_TYPE(FragmentNode);
-
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def_packed("tl.Layout",
@@ -560,12 +718,23 @@ TVM_FFI_STATIC_INIT_BLOCK({
                          element_size, k_inner);
            }
          })
.def("tl.make_volta_swizzled_layout",
[](int stride, int mat_continuous, bool is_a, bool k_inner) {
return makeGemmVoltaABLayout(stride, mat_continuous, is_a,
k_inner);
})
.def("tl.make_wgmma_swizzled_layout", .def("tl.make_wgmma_swizzled_layout",
[](int stride, int mat_continuous, int continuity, int element_size, [](int stride, int mat_continuous, int continuity, int element_size,
bool k_inner) { bool k_inner) {
return makeGemmABLayoutHopper(stride, mat_continuous, continuity, return makeGemmABLayoutHopper(stride, mat_continuous, continuity,
element_size, k_inner); element_size, k_inner);
}) })
.def("tl.make_tcgen05mma_swizzled_layout",
[](int stride, int mat_continuous, int continuity, int element_size,
bool k_inner) {
return makeGemmABLayoutSm100(stride, mat_continuous, continuity,
element_size, k_inner);
})
.def("tl.make_full_bank_swizzled_layout", .def("tl.make_full_bank_swizzled_layout",
[](int stride, int continuous, int element_size) { [](int stride, int continuous, int element_size) {
return makeFullBankSwizzleLayout(stride, continuous, element_size); return makeFullBankSwizzleLayout(stride, continuous, element_size);
...@@ -582,13 +751,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ ...@@ -582,13 +751,13 @@ TVM_FFI_STATIC_INIT_BLOCK({
.def("tl.make_linear_layout", [](int stride, int continuous) { .def("tl.make_linear_layout", [](int stride, int continuous) {
return makeGemmLayoutLinear(stride, continuous); return makeGemmLayoutLinear(stride, continuous);
}); });
-});
+}

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  LayoutNode::RegisterReflection();
  FragmentNode::RegisterReflection();
-});
+}

} // namespace tl
} // namespace tvm
@@ -8,8 +8,11 @@
#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
+#include <tvm/ffi/object.h>

#include <utility>

+#include "../support/ffi_aliases.h"

namespace tvm {
namespace tl {
@@ -38,17 +41,20 @@ public:
  virtual Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const;
  virtual Layout Inverse() const;
+  virtual Layout Reshape(const Array<PrimExpr> &shape,
+                         arith::Analyzer *analyzer) const;
  virtual std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const;
  virtual std::string DebugOutput() const;
  virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const;

-  static constexpr bool _type_has_method_sequal_reduce = true;
-  static constexpr const char *_type_key = "tl.Layout";
-  bool SEqualReduce(const LayoutNode *other, SEqualReducer equal) const;
  static void RegisterReflection();
-  TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object);
+  TVM_FFI_DECLARE_OBJECT_INFO("tl.Layout", LayoutNode, Object);
+  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
+      kTVMFFISEqHashKindTreeNode;

protected:
  virtual Map<Var, Range> getVarMap() const;
@@ -65,7 +71,7 @@ public:
  TVM_DLL Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index);
  TVM_DLL Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);

-  TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Layout, ObjectRef, LayoutNode);
};

class FragmentNode : public LayoutNode {
@@ -79,6 +85,9 @@ public:
  Array<PrimExpr> GetForwardVars() const final;

  Layout Inverse() const final;
+  Layout Reshape(const Array<PrimExpr> &shape, arith::Analyzer *analyzer) const;
  std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const final;

  PrimExpr ThreadExtent() const;
@@ -109,9 +118,9 @@ public:
  static void RegisterReflection();

-  bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const;
-  static constexpr const char *_type_key = "tl.Fragment";
-  TVM_DECLARE_FINAL_OBJECT_INFO(FragmentNode, LayoutNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode);
+  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
+      kTVMFFISEqHashKindTreeNode;

protected:
  Map<Var, Range> getVarMap() const final;
@@ -132,7 +141,7 @@ public:
                   PrimExpr forward_thread, PrimExpr replicate_size,
                   Optional<Var> replicate_var);

-  TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fragment, Layout, FragmentNode);
};
Var InputPlaceholder(size_t idx);
......
@@ -6,6 +6,7 @@
#include "swizzle.h"

+#include <tvm/node/node.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
@@ -86,14 +87,16 @@ SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var,
  forward_index =
      forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });

-  auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
+  auto n = tvm::ffi::make_object<SwizzledLayoutNode>(input_size, forward_index,
+                                                     pattern);
  data_ = std::move(n);
}

SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size,
                               Array<PrimExpr> forward_index,
                               SwizzlePattern pattern) {
-  auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
+  auto n = tvm::ffi::make_object<SwizzledLayoutNode>(input_size, forward_index,
+                                                     pattern);
  data_ = std::move(n);
}
@@ -102,14 +105,5 @@ void SwizzledLayoutNode::RegisterReflection() {
  refl::ObjectDef<SwizzledLayoutNode>();
}
bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode *other,
SEqualReducer equal) const {
return equal(this->InputShape(), other->InputShape()) &&
equal(this->forward_index_, other->forward_index_) &&
pattern_ == other->pattern_;
}
TVM_REGISTER_NODE_TYPE(SwizzledLayoutNode);
} // namespace tl
} // namespace tvm
@@ -44,10 +44,9 @@ public:
  Layout Inverse() const final;
  std::string DebugOutput() const final;
  bool IsEqual(const SwizzledLayoutNode *other, bool skip_index = false) const;
-  static constexpr const char *_type_key = "tl.SwizzledLayout";
-  bool SEqualReduce(const SwizzledLayoutNode *other, SEqualReducer equal) const;
  static void RegisterReflection();
-  TVM_DECLARE_FINAL_OBJECT_INFO(SwizzledLayoutNode, LayoutNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.SwizzledLayout", SwizzledLayoutNode,
+                                    LayoutNode);

private:
  SwizzlePattern pattern_;
@@ -62,8 +61,8 @@ public:
                         Array<PrimExpr> forward_index, SwizzlePattern pattern);
  TVM_DLL SwizzledLayout(Array<PrimExpr> input_size,
                         Array<PrimExpr> forward_index, SwizzlePattern pattern);
-  TVM_DEFINE_OBJECT_REF_METHODS(SwizzledLayout, Layout, SwizzledLayoutNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SwizzledLayout, Layout,
+                                             SwizzledLayoutNode);
};
} // namespace tl
......
@@ -115,6 +115,10 @@ Array<IterSplitExpr> get_unused_iters(const IterMark &mark,
  return results;
}
// Heuristic: detect per-iterator gaps ("unused" pieces) even when the iterator
// appears in fused forms across multiple index expressions. We first normalize
// every index into IterSumExpr, collect all splits per source Var, then
// consolidate them to avoid misclassifying a used split as unused.
Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
                                           const Array<IterVar> input_iters,
                                           Analyzer *analyzer) {
@@ -134,17 +138,25 @@ Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
  }

  for (const IterVar &iter : input_iters) {
-    IterMark iv_mark;
+    // Merge splits from all IterMark that share the same source Var as `iter`.
+    std::vector<IterSplitExpr> merged_splits;
    for (const IterMark &mark : collector.visited_) {
-      if (mark->source.as<Var>()->same_as(iter->var)) { // NOLINT(*)
-        iv_mark = mark;
-        break;
+      auto vexpr = mark->source.as<Var>();
+      if (vexpr && vexpr.value().same_as(iter->var)) {
+        auto it = collector.mark2splits_.find(mark);
+        if (it != collector.mark2splits_.end()) {
+          const auto &vec = it->second;
+          merged_splits.insert(merged_splits.end(), vec.begin(), vec.end());
+        }
      }
    }
-    if (iv_mark.defined()) {
-      auto splits =
-          get_unused_iters(iv_mark, collector.mark2splits_[iv_mark], analyzer);
-      // Put the small axis last
+
+    if (!merged_splits.empty()) {
+      // Use a unified mark (Var + full extent) to compute the missing pieces
+      // so that fused usages are honored as "used" and not reintroduced.
+      IterMark unified_mark(iter->var, iter->dom->extent);
+      auto splits = get_unused_iters(unified_mark, merged_splits, analyzer);
+      // Put the small axis last for a flattened ordering.
      results.insert(results.end(), splits.rbegin(), splits.rend());
    } else if (!is_one(iter->dom->extent)) {
      auto mark = IterMark(iter->var, iter->dom->extent);
@@ -189,7 +201,7 @@ public:
  IterMark Mutate(const IterMark &mark) {
    if (auto *op = mark->source.as<IterSumExprNode>()) {
-      return IterMark(Mutate(GetRef<IterSumExpr>(op)), mark->extent);
+      return IterMark(Mutate(tvm::ffi::GetRef<IterSumExpr>(op)), mark->extent);
    } else {
      return mark;
    }
......
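The reworked DivideUnusedIterators above merges the splits seen for each source Var and then asks get_unused_iters for the pieces of the iterator that no index expression touches; those gaps are appended as extra trailing axes. A toy Python sketch of the idea, assuming a single power-of-two split (TVM's IterMark/IterSplitExpr machinery handles the general case):

def split_used_unused(extent: int, used_factor: int) -> tuple:
    # For an iterator of `extent` that the indices only reference as i // used_factor,
    # the visible part has extent `extent // used_factor` and the low gap
    # (i % used_factor) of extent `used_factor` is the "unused" piece.
    assert extent % used_factor == 0
    return extent // used_factor, used_factor

# i has extent 16 but the index expressions only use i // 4: the used axis has
# extent 4 and the unused low piece (i % 4) also has extent 4; it is appended
# as an extra axis so all 16 values of i are still covered.
print(split_used_unused(16, 4))  # (4, 4)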
@@ -9,6 +9,8 @@
#include <tvm/arith/iter_affine_map.h>

+#include "../support/ffi_aliases.h"

namespace tvm {
namespace tl {
......
@@ -42,7 +42,7 @@ using namespace tir;
 * - The constructed node is stored in this->data_.
 */
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
-  ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
+  ObjectPtr<AtomicAddNode> node = tvm::ffi::make_object<AtomicAddNode>();
  Array<Range> rgs[2];
  Buffer bf[2];
  for (int i = 0; i < 2; i++) {
@@ -78,7 +78,7 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
 * @return TileOperator A TileOperator owning the cloned AtomicAddNode.
 */
TileOperator AtomicAddNode::Clone() const {
-  auto op = make_object<AtomicAddNode>(*this);
+  auto op = tvm::ffi::make_object<AtomicAddNode>(*this);
  if (par_op_.defined()) {
    op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
  }
@@ -549,7 +549,7 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

-TVM_FFI_STATIC_INIT_BLOCK({ AtomicAddNode::RegisterReflection(); });
+TVM_FFI_STATIC_INIT_BLOCK() { AtomicAddNode::RegisterReflection(); }

} // namespace tl
} // namespace tvm
\ No newline at end of file
@@ -25,8 +25,8 @@ public:
  IntImm memory_order;        ///< Memory order for atomic operations
  mutable ParallelOp par_op_; ///< Associated parallel operation

-  static constexpr const char *_type_key = "tl.AtomicAdd";
-  TVM_DECLARE_FINAL_OBJECT_INFO(AtomicAddNode, TileOperatorNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.AtomicAdd", AtomicAddNode,
+                                    TileOperatorNode);

  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
@@ -46,28 +46,6 @@ public:
        .def_ro("memory_order", &AtomicAddNode::memory_order);
  }
bool SEqualReduce(const AtomicAddNode *other, SEqualReducer equal) const {
return equal(src, other->src) && equal(dst, other->dst) &&
equal(src_range, other->src_range) &&
equal(dst_range, other->dst_range) &&
equal(use_tma, other->use_tma) &&
equal(coalesced_width, other->coalesced_width) &&
equal(memory_order, other->memory_order);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(src_range);
hash_reduce(dst_range);
hash_reduce(use_tma);
hash_reduce(coalesced_width);
hash_reduce(memory_order);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
protected:
  /// Create SIMT-style parallel loop structure
  For MakeSIMTLoop(arith::Analyzer *analyzer) const;
@@ -85,7 +63,8 @@ protected:
/// Wrapper class for atomic addition operations
class AtomicAdd : public TileOperator {
public:
-  TVM_DEFINE_OBJECT_REF_METHODS(AtomicAdd, TileOperator, AtomicAddNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicAdd, TileOperator,
+                                             AtomicAddNode);
  TVM_DLL AtomicAdd(Array<PrimExpr> args, BufferMap vmap);
  static const Op &Get();
};
......
@@ -155,6 +155,16 @@ TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
.set_num_inputs(14)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ts)
.set_num_inputs(13)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_init_tensor_memory)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
@@ -165,6 +175,11 @@ TIR_DEFINE_TL_BUILTIN(ptx_deallocate_tensor_memory)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_mma_sm70)
.set_num_inputs(13)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_ldmatrix)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
@@ -219,6 +234,11 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warpgroup_fence_operand)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(get_lane_idx)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
@@ -286,11 +306,16 @@ TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));
-TIR_DEFINE_TL_BUILTIN(initialize_descriptor)
+TIR_DEFINE_TL_BUILTIN(initialize_wgmma_descriptor)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(initialize_tcgen05_descriptor)
.set_num_inputs(7)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
@@ -301,5 +326,20 @@ TIR_DEFINE_TL_BUILTIN(atomicadd_elem_op)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(device_assert)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(device_assert_with_msg)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm