Commit 57ab687c authored by Lei Wang, committed by GitHub

[Initialization] Migration of Codebase from Dev Branch into Main (#10)



* Add format.sh script for code formatting and linting

* docs update

* center align the title

* lint fix

* add ignore

* Add .gitignore for 3rdparty directory

* Add requirements-dev.txt, requirements-test.txt, and requirements.txt

* 3rdparty

* Add gemm.h, CMakeLists.txt, _ffi_api.py, __init__.py, runtime.h, reduce.h, loop_partition.h, utils.h, and loop_vectorize.h

* Refactor CMakeLists.txt and include statements

- Update CMakeLists.txt to use a newer version of CMake and add project name
- Remove unnecessary include directories

Fix include paths in layout.cc, codegen.cc, codegen.h, rt_mod.cc, frontend_legalize.cc, inject_pipeline.cc, layout_inference.cc, loop_vectorize.cc, and lower_tile_op.cc

- Update include paths to use relative paths instead of absolute paths

* Update submodule for 3rdparty/tvm

* update

* load dll first

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* git keep update

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* refactor code structure

* Update Readme

* CMakeLists Customized

* update readme

* update README

* update readme

* update usage

* use TVM_IMPORT_PYTHON_PATH to handle importing a custom TVM build's Python package

* annotate lower transform global func with `transform` prefix

* Migrate Simplify Pass from tilelang tvm branch

* enhance system environment handling with __init__ and CMake

* Initial commit

* CODE_OF_CONDUCT.md committed

* LICENSE committed

* README.md committed

* SECURITY.md committed

* SUPPORT.md committed

* CODE_OF_CONDUCT Commit

* LICENSE Commit

* SECURITY Commit

* SUPPORT Commit

* Modify Support

* Update README.md

* security ci update

* remove examples

* Update and implement clang-format

* add composable kernel components

* Migrate from latest update

* submodule update

* Test update

* Update License

* Spell check

* lint fix

* add clang-tidy to apply static analysis for c source

* update tilelang examples

* Update Install Docs

* Refactor filetree

* Enhance Install

* conflict resolved

* annotate_version

* Initial Update

* test fix

* install

* Implement setup.py

* lint fix

* Separate Init

* Separate test

* docker file commit

* add logo

* Update Readme and Examples

* update readme

* update logo

* Implement AMD Installation

* Add License

* Update AMD MI300x Benchmark

* update README

* update mi300 benchmark scripts

* update ignore

* enhance build script

* update image

* enhance setup.py to remove duplicated libraries

* remove debug files

* update readme

* update image

* update gemm examples

* update flashattention README

* readme update

* add cmake into requirements

* libinfo fix

* auto update submodule

* lint fix

* Fix AMD Build and Test

* Update check for transpose attribute for CDNA Arch

* typo fix for amd

* Implement Matmul Benchmark

* Refactor Code

* [TypoFix] Fix GEMM Example

* [Docs] Init Linear Attention README

* [TYPO] Typo fix

* [Lint] Lint Fix

* enhance example with intrinsics

* [Enhancement] Improve Buffer Collection during IR Parser

* [Dev] Introduce Current classmethod to get current frame

* submodule update

* fake test pass update

* support thread_extent_api

* code optimize

* Add GEMM function implementation for matrix multiplication

* Update logging format to reflect TileLang in logger messages

* Refactor CMakeLists.txt for improved readability and set default build type to Release

* Support Gemm SS Primitives Implementation

* [README] Upload Tile Language Logo (#5)

* update logo

* Update README.md to enhance formatting and center the title

---------
Co-authored-by: microsoft-github-operations[bot] <55726097+microsoft-github-operations[bot]@users.noreply.github.com>
Co-authored-by: Microsoft Open Source <microsoftopensource@users.noreply.github.com>
Co-authored-by: Yu Cheng <yu.cheng@pku.edu.cn>
parent 64f17c2f
#!/usr/bin/env bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Usage:
# # Do work and commit your work.
# # Format files that differ from origin/main.
# bash format.sh
# # Commit changed files with message 'Run yapf and ruff'
#
#
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
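#
# A minimal usage sketch (the file path below is illustrative):
#   bash format.sh                          # process only files changed vs. the merge base with main
#   bash format.sh --files tilelang/foo.py  # process only the listed files
#   bash format.sh --all                    # process the whole tree (3rdparty/ and build/ excluded)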
# Cause the script to exit if a single command fails
set -eo pipefail
# this stops git rev-parse from failing if we run this from the .git directory
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
ROOT="$(git rev-parse --show-toplevel)"
builtin cd "$ROOT" || exit 1
YAPF_VERSION=$(yapf --version | awk '{print $2}')
RUFF_VERSION=$(ruff --version | awk '{print $2}')
CODESPELL_VERSION=$(codespell --version)
# # params: tool name, tool version, required version
tool_version_check() {
if [[ $2 != $3 ]]; then
echo "Wrong $1 version installed: $3 is required, not $2."
exit 1
fi
}
tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)"
echo 'tile-lang yapf: Check Start'
YAPF_FLAGS=(
'--recursive'
'--parallel'
)
YAPF_EXCLUDES=(
'--exclude' 'build/**'
'--exclude' '3rdparty/**'
)
# Format specified files
format() {
yapf --in-place "${YAPF_FLAGS[@]}" "$@"
}
# Format files that differ from main branch. Ignores dirs that are not slated
# for autoformat yet.
format_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause yapf to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE are used to ensure we only format files that
# exist on both branches.
if git show-ref --verify --quiet refs/remotes/origin/main; then
BASE_BRANCH="origin/main"
else
BASE_BRANCH="main"
fi
MERGEBASE="$(git merge-base $BASE_BRANCH HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
fi
}
# Format all files
format_all() {
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" .
}
## This flag formats individual files. --files *must* be the first command line
## arg to use this option.
if [[ "$1" == '--files' ]]; then
format "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is formatted.
elif [[ "$1" == '--all' ]]; then
format_all
else
# Format only the files that changed relative to the merge base with main.
format_changed
fi
echo 'tile-lang yapf: Done'
echo 'tile-lang codespell: Check Start'
# check spelling of specified files
spell_check() {
codespell "$@"
}
spell_check_all(){
codespell --toml pyproject.toml
}
# Spelling check of files that differ from main branch.
spell_check_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause ruff to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE are used to ensure we only lint files that
# exist on both branches.
if git show-ref --verify --quiet refs/remotes/origin/main; then
BASE_BRANCH="origin/main"
else
BASE_BRANCH="main"
fi
MERGEBASE="$(git merge-base $BASE_BRANCH HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
codespell
fi
}
# Run Codespell
## This flag runs spell check of individual files. --files *must* be the first command line
## arg to use this option.
if [[ "$1" == '--files' ]]; then
spell_check "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is linted.
elif [[ "$1" == '--all' ]]; then
spell_check_all
else
# Check spelling only in the files that changed relative to the merge base with main.
spell_check_changed
fi
echo 'tile-lang codespell: Done'
echo 'tile-lang ruff: Check Start'
# Lint specified files
lint() {
ruff check "$@"
}
# Lint files that differ from main branch. Ignores dirs that are not slated
# for autolint yet.
lint_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause ruff to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE are used to ensure we only lint files that
# exist on both branches.
if git show-ref --verify --quiet refs/remotes/origin/main; then
BASE_BRANCH="origin/main"
else
BASE_BRANCH="main"
fi
MERGEBASE="$(git merge-base $BASE_BRANCH HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
ruff check
fi
}
# Run Ruff
### This flag lints individual files. --files *must* be the first command line
### arg to use this option.
if [[ "$1" == '--files' ]]; then
lint "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is linted.
elif [[ "$1" == '--all' ]]; then
lint python testing
else
# Lint only the files that changed relative to the merge base with main.
lint_changed
fi
echo 'tile-lang ruff: Done'
echo 'tile-lang clang-format: Check Start'
# If clang-format is available, run it; otherwise, skip
if command -v clang-format &>/dev/null; then
CLANG_FORMAT_VERSION=$(clang-format --version | awk '{print $3}')
tool_version_check "clang-format" "$CLANG_FORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)"
CLANG_FORMAT_FLAGS=("-i")
# Apply clang-format to specified files
clang_format() {
clang-format "${CLANG_FORMAT_FLAGS[@]}" "$@"
}
# Format all C/C++ files in the repo, excluding specified directories
clang_format_all() {
find . -type f \( -name '*.c' -o -name '*.cc' -o -name '*.cpp' -o -name '*.h' -o -name '*.hpp' \) \
-not -path "./3rdparty/*" \
-not -path "./build/*" \
-exec clang-format -i {} +
}
# Format changed C/C++ files relative to main
clang_format_changed() {
if git show-ref --verify --quiet refs/remotes/origin/main; then
BASE_BRANCH="origin/main"
else
BASE_BRANCH="main"
fi
MERGEBASE="$(git merge-base $BASE_BRANCH HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.c' '*.cc' '*.cpp' '*.h' '*.hpp' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.c' '*.cc' '*.cpp' '*.h' '*.hpp' | xargs clang-format -i
fi
}
if [[ "$1" == '--files' ]]; then
# If --files is given, format only the provided files
clang_format "${@:2}"
elif [[ "$1" == '--all' ]]; then
# If --all is given, format all eligible C/C++ files
clang_format_all
else
# Otherwise, format only changed C/C++ files
clang_format_changed
fi
else
echo "clang-format not found. Skipping C/C++ formatting."
fi
echo 'tile-lang clang-format: Done'
# Check if there are any uncommitted changes after all formatting steps.
# If there are, ask the user to review and stage them.
if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:'
echo
git --no-pager diff --name-only
exit 1
fi
# Check if clang-tidy is installed and get the version
if command -v clang-tidy &>/dev/null; then
CLANG_TIDY_VERSION=$(clang-tidy --version | head -n 1 | awk '{print $3}')
tool_version_check "clang-tidy" "$CLANG_TIDY_VERSION" "$(grep clang-tidy requirements-dev.txt | cut -d'=' -f3)"
else
echo "clang-tidy not found. Skipping C++ static analysis."
CLANG_TIDY_AVAILABLE=false
fi
# Function to run clang-tidy
clang_tidy() {
clang-tidy "$@" -- -std=c++17
}
# Run clang-tidy on all C/C++ files
clang_tidy_all() {
find . -type f \( -name '*.c' -o -name '*.cc' -o -name '*.cpp' -o -name '*.h' -o -name '*.hpp' \) \
-not -path "./3rdparty/*" -not -path "./build/*" \
| xargs -n 1 clang-tidy -- -std=c++17
}
# Run clang-tidy on changed C/C++ files relative to main
clang_tidy_changed() {
if git show-ref --verify --quiet refs/remotes/origin/main; then
BASE_BRANCH="origin/main"
else
BASE_BRANCH="main"
fi
MERGEBASE="$(git merge-base $BASE_BRANCH HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.c' '*.cc' '*.cpp' '*.h' '*.hpp' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.c' '*.cc' '*.cpp' '*.h' '*.hpp' | xargs -n 1 clang-tidy -- -std=c++17
fi
}
# Add clang-tidy support to the main script logic
echo 'tile-lang clang-tidy: Check Start'
if [[ "$CLANG_TIDY_AVAILABLE" != false ]]; then
if [[ "$1" == '--files' ]]; then
# If --files is given, analyze only the provided files
clang_tidy "${@:2}"
elif [[ "$1" == '--all' ]]; then
# If --all is given, analyze all eligible C/C++ files
clang_tidy_all
else
# Otherwise, analyze only changed C/C++ files
clang_tidy_changed
fi
else
echo "clang-tidy is not available. Skipping static analysis."
fi
echo 'tile-lang clang-tidy: Done'
if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:'
echo
git --no-pager diff --name-only
exit 1
fi
echo 'tile-lang: All checks passed'
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
echo "Starting installation script..."
# Step 1: Install Python requirements
echo "Installing Python requirements from requirements.txt..."
pip install -r requirements.txt
if [ $? -ne 0 ]; then
echo "Error: Failed to install Python requirements."
exit 1
else
echo "Python requirements installed successfully."
fi
# Step 2: Define LLVM version and architecture
LLVM_VERSION="10.0.1"
IS_AARCH64=false
EXTRACT_PATH="3rdparty"
echo "LLVM version set to ${LLVM_VERSION}."
echo "Is AARCH64 architecture: $IS_AARCH64"
# Step 3: Determine the correct Ubuntu version based on LLVM version
UBUNTU_VERSION="16.04"
if [[ "$LLVM_VERSION" > "17.0.0" ]]; then
UBUNTU_VERSION="22.04"
elif [[ "$LLVM_VERSION" > "16.0.0" ]]; then
UBUNTU_VERSION="20.04"
elif [[ "$LLVM_VERSION" > "13.0.0" ]]; then
UBUNTU_VERSION="18.04"
fi
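# NOTE: the comparisons above use lexicographic string ordering, which is adequate for the
# pinned LLVM_VERSION here but is not a general semantic-version comparison.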
echo "Ubuntu version for LLVM set to ${UBUNTU_VERSION}."
# Step 4: Set download URL and file name for LLVM
BASE_URL="https://github.com/llvm/llvm-project/releases/download/llvmorg-${LLVM_VERSION}"
if $IS_AARCH64; then
FILE_NAME="clang+llvm-${LLVM_VERSION}-aarch64-linux-gnu.tar.xz"
else
FILE_NAME="clang+llvm-${LLVM_VERSION}-x86_64-linux-gnu-ubuntu-${UBUNTU_VERSION}.tar.xz"
fi
DOWNLOAD_URL="${BASE_URL}/${FILE_NAME}"
echo "Download URL for LLVM: ${DOWNLOAD_URL}"
# Step 5: Create extraction directory
echo "Creating extraction directory at ${EXTRACT_PATH}..."
mkdir -p "$EXTRACT_PATH"
if [ $? -ne 0 ]; then
echo "Error: Failed to create extraction directory."
exit 1
else
echo "Extraction directory created successfully."
fi
# Step 6: Download LLVM
echo "Downloading $FILE_NAME from $DOWNLOAD_URL..."
curl -L -o "${EXTRACT_PATH}/${FILE_NAME}" "$DOWNLOAD_URL"
if [ $? -ne 0 ]; then
echo "Error: Download failed!"
exit 1
else
echo "Download completed successfully."
fi
# Step 7: Extract LLVM
echo "Extracting $FILE_NAME to $EXTRACT_PATH..."
tar -xJf "${EXTRACT_PATH}/${FILE_NAME}" -C "$EXTRACT_PATH"
if [ $? -ne 0 ]; then
echo "Error: Extraction failed!"
exit 1
else
echo "Extraction completed successfully."
fi
# Step 8: Determine LLVM config path
LLVM_CONFIG_PATH="$(realpath ${EXTRACT_PATH}/$(basename ${FILE_NAME} .tar.xz)/bin/llvm-config)"
echo "LLVM config path determined as: $LLVM_CONFIG_PATH"
# Step 9: Clone and build TVM
echo "Cloning TVM repository and initializing submodules..."
# clone and build tvm
git submodule update --init --recursive
if [ -d build ]; then
rm -rf build
fi
mkdir build
cp 3rdparty/tvm/cmake/config.cmake build
cd build
echo "Configuring TVM build with LLVM and CUDA paths..."
echo "set(USE_LLVM $LLVM_CONFIG_PATH)" >> config.cmake && echo "set(USE_CUDA /usr/local/cuda)" >> config.cmake
echo "Running CMake for TileLang..."
cmake ..
if [ $? -ne 0 ]; then
echo "Error: CMake configuration failed."
exit 1
fi
echo "Building TileLang with make..."
make -j
if [ $? -ne 0 ]; then
echo "Error: TileLang build failed."
exit 1
else
echo "TileLang build completed successfully."
fi
cd ..  # return to the repository root
# Step 10: Set environment variables
echo "Configuring environment variables for TVM..."
echo "export PYTHONPATH=$(pwd):\$PYTHONPATH" >> ~/.bashrc
echo "export CUDA_DEVICE_ORDER=PCI_BUS_ID" >> ~/.bashrc
# Step 11: Source .bashrc to apply changes
echo "Applying environment changes by sourcing .bashrc..."
source ~/.bashrc
if [ $? -ne 0 ]; then
echo "Error: Failed to source .bashrc."
exit 1
else
echo "Environment configured successfully."
fi
echo "Installation script completed successfully."
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# install requirements
pip install -r requirements.txt
# determine if root
USER_IS_ROOT=false
if [ "$EUID" -eq 0 ]; then
USER_IS_ROOT=true
fi
if $USER_IS_ROOT; then
# Fetch the GPG key for the LLVM repository and add it to the trusted keys
wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc
# Check if the repository is already present in the sources.list
if ! grep -q "http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main" /etc/apt/sources.list; then
# Add the LLVM repository to sources.list
echo "deb http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main" >> /etc/apt/sources.list
echo "deb-src http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main" >> /etc/apt/sources.list
else
# Print a message if the repository is already added
echo "The repository is already added."
fi
# Update package lists and install llvm-16
apt-get update
apt-get install -y llvm-16
else
# Fetch the GPG key for the LLVM repository and add it to the trusted keys using sudo
wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc
# Check if the repository is already present in the sources.list
if ! grep -q "http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main" /etc/apt/sources.list; then
# Add the LLVM repository to sources.list using sudo
echo "deb http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main" | sudo tee -a /etc/apt/sources.list
echo "deb-src http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main" | sudo tee -a /etc/apt/sources.list
else
# Print a message if the repository is already added
echo "The repository is already added."
fi
# Update package lists and install llvm-16 using sudo
sudo apt-get update
sudo apt-get install -y llvm-16
fi
# Clone and build TVM
echo "Cloning TVM repository and initializing submodules..."
# clone and build tvm
git submodule update --init --recursive
if [ -d build ]; then
rm -rf build
fi
mkdir build
cp 3rdparty/tvm/cmake/config.cmake build
cd build
echo "Configuring TVM build with LLVM and CUDA paths..."
echo "set(USE_LLVM $LLVM_CONFIG_PATH)" >> config.cmake && echo "set(USE_ROCM /opt/rocm)" >> config.cmake
echo "Running CMake for TileLang..."
cmake ..
if [ $? -ne 0 ]; then
echo "Error: CMake configuration failed."
exit 1
fi
echo "Building TileLang with make..."
make -j
if [ $? -ne 0 ]; then
echo "Error: TileLang build failed."
exit 1
else
echo "TileLang build completed successfully."
fi
cd ..  # return to the repository root
# Define the lines to be added
TVM_HOME_ENV="export TVM_HOME=$(pwd)/3rdparty/tvm"
TILELANG_PYPATH_ENV="export PYTHONPATH=\$TVM_HOME/python:$(pwd):\$PYTHONPATH"
CUDA_DEVICE_ORDER_ENV="export CUDA_DEVICE_ORDER=PCI_BUS_ID"
# Check and add the first line if not already present
if ! grep -qxF "$TVM_HOME_ENV" ~/.bashrc; then
echo "$TVM_HOME_ENV" >> ~/.bashrc
echo "Added TVM_HOME to ~/.bashrc"
else
echo "TVM_HOME is already set in ~/.bashrc"
fi
# Check and add the second line if not already present
if ! grep -qxF "$TILELANG_PYPATH_ENV" ~/.bashrc; then
echo "$TILELANG_PYPATH_ENV" >> ~/.bashrc
echo "Added PYTHONPATH to ~/.bashrc"
else
echo "PYTHONPATH is already set in ~/.bashrc"
fi
# Check and add the third line if not already present
if ! grep -qxF "$CUDA_DEVICE_ORDER_ENV" ~/.bashrc; then
echo "$CUDA_DEVICE_ORDER_ENV" >> ~/.bashrc
echo "Added CUDA_DEVICE_ORDER to ~/.bashrc"
else
echo "CUDA_DEVICE_ORDER is already set in ~/.bashrc"
fi
# Reload ~/.bashrc to apply the changes
source ~/.bashrc
echo "Installation script completed successfully."
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
echo "Add MIT license boilerplate..."
PWD="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
# To source code root
pushd "${PWD}/../../" > /dev/null
EXITCODE=0
for SRC_FILE in $(find . -path './3rdparty' -prune -false -o -path './build' -prune -false -o -type f -not -name \
'*apply_mit_liscense.sh' -not -name '*check_mit_liscense.sh' -and \( -name '*.cpp' -or -name '*.h*' -or -name '*.cu' -or -name '*.in' \) ); do
sed -i '/\/\/\s*Microsoft\s*(c)/Id' ${SRC_FILE}
if !(grep -q "Copyright (c) Microsoft Corporation." "${SRC_FILE}"); then
cat maint/scripts/mit_liscense1.txt ${SRC_FILE} > ${SRC_FILE}.new
mv ${SRC_FILE}.new ${SRC_FILE}
fi
done
for SRC_FILE in $(find . -path './3rdparty' -prune -false -o -path './build' -prune -false -o -type f -not -name \
'*apply_mit_liscense.sh' -not -name '*check_mit_liscense.sh' -and \( -name 'CMakeLists.txt' -or -name '*.cmake' \
-or -name '*.py' -or -name '*.dockerfile' -or -name '*.yaml' \) ); do
sed -i '/\#\s*Microsoft\s*(c)/Id' ${SRC_FILE}
if !(grep -q "Copyright (c) Microsoft Corporation" "${SRC_FILE}"); then
cat maint/scripts/mit_liscense2.txt ${SRC_FILE} > ${SRC_FILE}.new
mv ${SRC_FILE}.new ${SRC_FILE}
fi
done
for SRC_FILE in $(find . -path './3rdparty' -prune -false -o -path './build' -prune -false -o -type f -not -name \
'*apply_mit_liscense.sh' -not -name '*check_mit_liscense.sh' -name '*.sh' ); do
sed -i '/\#\s*Microsoft\s*(c)/Id' ${SRC_FILE}
if !(grep -q "Copyright (c) Microsoft Corporation" "${SRC_FILE}"); then
line=$(head -n 1 ${SRC_FILE})
if [[ $line == "#!/bin/bash"* ]]; then
(echo ${line}; echo ''; cat maint/scripts/mit_liscense2.txt; echo "$(tail -n +2 "${SRC_FILE}")" ) > ${SRC_FILE}.new
else
cat maint/scripts/mit_liscense2.txt ${SRC_FILE} > ${SRC_FILE}.new
fi
mv ${SRC_FILE}.new ${SRC_FILE}
fi
done
echo "Done."
popd > /dev/null
exit $EXITCODE
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
echo "Check MIT License boilerplate..."
PWD="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
# To source code root
pushd "${PWD}/../../" > /dev/null
EXITCODE=0
for SRC_FILE in $(find . -path './3rdparty' -prune -false -o -path './build' -prune -false -o -type f -not -name '*apply_mit_license.sh' \
-not -name '*check_mit_license.sh' -and \( -name 'CMakeLists.txt' -or -name '*.cpp' -or -name '*.cu' -or -name '*.h' -or -name '*.hpp' \
-or -name '*.py' -or -name '*.sh' -or -name '*.dockerfile' -or -name '*.yaml' \) ); do
# Skip files that already contain the Apache License
if grep -q "Apache License" "${SRC_FILE}"; then
continue
fi
if !(grep -q "Copyright (c) Microsoft Corporation." "${SRC_FILE}") || !(grep -q "Licensed under the MIT License." "${SRC_FILE}") \
|| (grep -q -i -P "Microsoft( |)\(c\)" "${SRC_FILE}"); then
echo "[ERROR] Require: MIT License boilerplate" "${SRC_FILE}"
EXITCODE=1
fi
done
echo "Done."
popd > /dev/null
exit $EXITCODE
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# if dist and build directories exist, remove them
if [ -d dist ]; then
rm -r dist
fi
python -m build --wheel -o dist
if [ $? -ne 0 ]; then
echo "Error: Failed to build the wheel."
exit 1
else
echo "Wheel built successfully."
fi
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# if dist and build directories exist, remove them
if [ -d dist ]; then
rm -r dist
fi
if [ -d build ]; then
rm -r build
fi
PYPI_BUILD=TRUE python setup.py bdist_wheel --plat-name=manylinux1_x86_64
[build-system]
requires = [
"cmake>=3.26",
"packaging",
"setuptools>=61",
"setuptools-scm>=8.0",
"wheel",
]
build-backend = "setuptools.build_meta"
[tool.yapf]
based_on_style = "yapf"
column_limit = 100
indent_width = 4
[tool.codespell]
ignore-words-list = "nd, te, ist, LOD, offen"
skip = [
"build",
"3rdparty",
"dist",
".venv"
]
[tool.ruff.lint]
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
# "UP",
# flake8-bugbear
"B",
# flake8-simplify
"SIM",
# isort
# "I",
]
ignore = [
# Module level import not at top of file
"E402",
# star imports
"F405", "F403",
# ambiguous name
"E741",
# line too long
"E501",
# key in dict.keys()
"SIM118",
# memory leaks
"B019",
# No such file or directory
"E902",
]
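# NOTE: the [tool.*] sections above are consumed by the tooling driven from format.sh:
# yapf reads [tool.yapf], `codespell --toml pyproject.toml` reads [tool.codespell],
# and ruff reads [tool.ruff.lint].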
# formatting
yapf==0.40.2
toml==0.10.2
tomli==2.0.1
ruff==0.6.5
codespell==2.3.0
clang-format==15.0.7
# build requirements
cmake>=3.26
# runtime requirements
cffi
cpplint
Cython
decorator
docutils
dtlib
numpy>=1.23.5
pytest>=6.2.4
pytest_xdist>=2.2.1
packaging>=21.0
PyYAML
tqdm>=4.62.3
typing_extensions>=4.10.0
requests
attrs
cloudpickle
ml_dtypes
psutil
scipy
tornado
torch
thefuzz
tabulate
wheel
setuptools
# formatting
yapf==0.40.2
toml==0.10.2
tomli==2.0.1
ruff==0.6.5
codespell==2.3.0
clang-format==15.0.7
# build requirements
cmake>=3.26
# runtime requirements
cffi
cpplint
Cython
decorator
docutils
dtlib
numpy>=1.23.5
pytest>=6.2.4
pytest_xdist>=2.2.1
packaging>=21.0
PyYAML
tqdm>=4.62.3
typing_extensions>=4.10.0
requests
attrs
cloudpickle
ml_dtypes
psutil
scipy
tornado
torch
thefuzz
tabulate
wheel
setuptools
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import io
import subprocess
import shutil
from setuptools import setup, find_packages, Extension
from setuptools.command.build_py import build_py
from setuptools.command.sdist import sdist
import distutils.dir_util
from typing import List
import re
import tarfile
from io import BytesIO
import os
import sys
import urllib.request
from distutils.version import LooseVersion
import platform
import multiprocessing
from setuptools.command.build_ext import build_ext
# Environment variable toggles (expected values: "True"/"False")
PYPI_BUILD = os.environ.get("PYPI_BUILD", "False").lower() == "true"
PACKAGE_NAME = "tilelang"
ROOT_DIR = os.path.dirname(__file__)
# TileLang only supports Linux platform
assert sys.platform.startswith("linux"), "TileLang only supports Linux platform (including WSL)."
def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath)
def get_requirements(file_path: str = "requirements.txt") -> List[str]:
"""Get Python package dependencies from requirements.txt."""
with open(get_path(file_path)) as f:
requirements = f.read().strip().split("\n")
return requirements
def find_version(version_file_path: str) -> str:
"""Extract version information from the given filepath.
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
"""
# Read and store the version information from the VERSION file
# Use 'strip()' to remove any leading/trailing whitespace or newline characters
if not os.path.exists(version_file_path):
raise FileNotFoundError(f"Version file not found at {version_file_path}")
with open(version_file_path, "r") as version_file:
version = version_file.read().strip()
return version
def get_nvcc_cuda_version():
"""Get the CUDA version from nvcc.
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
"""
nvcc_output = subprocess.check_output(["nvcc", "-V"], universal_newlines=True)
output = nvcc_output.split()
release_idx = output.index("release") + 1
nvcc_cuda_version = LooseVersion(output[release_idx].split(",")[0])
return nvcc_cuda_version
def get_tilelang_version(with_cuda=True, with_system_info=True) -> str:
version = find_version(get_path(".", "VERSION"))
local_version_parts = []
if with_system_info:
local_version_parts.append(get_system_info().replace("-", "."))
if with_cuda:
cuda_version = str(get_nvcc_cuda_version())
cuda_version_str = cuda_version.replace(".", "")[:3]
local_version_parts.append(f"cu{cuda_version_str}")
if local_version_parts:
version += f"+{'.'.join(local_version_parts)}"
return version
def get_system_info():
system = platform.system().lower()
if system == "linux":
try:
with open("/etc/os-release") as f:
os_release = f.read()
version_id_match = re.search(r'VERSION_ID="(\d+\.\d+)"', os_release)
if version_id_match:
version_id = version_id_match.group(1)
distro = "ubuntu"
return f"{distro}-{version_id}"
except FileNotFoundError:
pass
return system
def read_readme() -> str:
"""Read the README file if present."""
p = get_path("README.md")
if os.path.isfile(p):
return io.open(get_path("README.md"), "r", encoding="utf-8").read()
else:
return ""
def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty"):
"""
Downloads and extracts the specified version of LLVM for the given platform.
Args:
version (str): The version of LLVM to download.
is_aarch64 (bool): True if the target platform is aarch64, False otherwise.
extract_path (str): The directory path where the archive will be extracted.
Returns:
str: The path where the LLVM archive was extracted.
"""
ubuntu_version = "16.04"
if version >= "16.0.0":
ubuntu_version = "20.04"
elif version >= "13.0.0":
ubuntu_version = "18.04"
base_url = (f"https://github.com/llvm/llvm-project/releases/download/llvmorg-{version}")
file_name = f"clang+llvm-{version}-{'aarch64-linux-gnu' if is_aarch64 else f'x86_64-linux-gnu-ubuntu-{ubuntu_version}'}.tar.xz"
download_url = f"{base_url}/{file_name}"
# Download the file
print(f"Downloading {file_name} from {download_url}")
with urllib.request.urlopen(download_url) as response:
if response.status != 200:
raise Exception(f"Download failed with status code {response.status}")
file_content = response.read()
# Ensure the extract path exists
os.makedirs(extract_path, exist_ok=True)
# if the file already exists, remove it
if os.path.exists(os.path.join(extract_path, file_name)):
os.remove(os.path.join(extract_path, file_name))
# Extract the file
print(f"Extracting {file_name} to {extract_path}")
with tarfile.open(fileobj=BytesIO(file_content), mode="r:xz") as tar:
tar.extractall(path=extract_path)
print("Download and extraction completed successfully.")
return os.path.abspath(os.path.join(extract_path, file_name.replace(".tar.xz", "")))
package_data = {
"tilelang": ["py.typed"],
}
LLVM_VERSION = "10.0.1"
IS_AARCH64 = False # Set to True if on an aarch64 platform
EXTRACT_PATH = "3rdparty" # Default extraction path
def update_submodules():
"""Updates git submodules."""
try:
subprocess.check_call(["git", "submodule", "update", "--init", "--recursive"])
except subprocess.CalledProcessError as error:
raise RuntimeError("Failed to update submodules") from error
def build_csrc(llvm_config_path):
"""Configures and builds TVM."""
if not os.path.exists("build"):
os.makedirs("build")
os.chdir("build")
# Copy the config.cmake as a baseline
if not os.path.exists("config.cmake"):
shutil.copy("../3rdparty/tvm/cmake/config.cmake", "config.cmake")
# Set LLVM path and enable CUDA in config.cmake
with open("config.cmake", "a") as config_file:
config_file.write(f"set(USE_LLVM {llvm_config_path})\n")
config_file.write("set(USE_CUDA /usr/local/cuda)\n")
# Run CMake and make
try:
subprocess.check_call(["cmake", ".."])
num_jobs = multiprocessing.cpu_count()
subprocess.check_call(["make", f"-j{num_jobs}"])
except subprocess.CalledProcessError as error:
raise RuntimeError("Failed to build TileLang C Source") from error
def setup_llvm_for_tvm():
"""Downloads and extracts LLVM, then configures TVM to use it."""
# Assume the download_and_extract_llvm function and its dependencies are defined elsewhere in this script
extract_path = download_and_extract_llvm(LLVM_VERSION, IS_AARCH64, EXTRACT_PATH)
llvm_config_path = os.path.join(extract_path, "bin", "llvm-config")
return extract_path, llvm_config_path
class TileLangBuilPydCommand(build_py):
"""Customized setuptools install command - builds TVM after setting up LLVM."""
def run(self):
build_py.run(self)
self.run_command("build_ext")
build_ext_cmd = self.get_finalized_command("build_ext")
build_temp_dir = build_ext_cmd.build_temp
        ext_modules = build_ext_cmd.extensions  # list all extension modules
for ext in ext_modules:
            extdir = build_ext_cmd.get_ext_fullpath(ext.name)  # get the full output path of the extension module
print(f"Extension {ext.name} output directory: {extdir}")
ext_output_dir = os.path.dirname(extdir)
print(f"Extension output directory (parent): {ext_output_dir}")
TILELANG_SRC = [
"src/tl_templates",
]
for item in TILELANG_SRC:
source_dir = os.path.join(ROOT_DIR, item)
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
if os.path.isdir(source_dir):
self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir)
else:
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
shutil.copy2(source_dir, target_dir)
# Copy the built TVM to the package directory
TVM_PREBUILD_ITEMS = [
f"{ext_output_dir}/libtvm_runtime.so",
f"{ext_output_dir}/libtvm.so",
f"{ext_output_dir}/libtilelang.so",
f"{ext_output_dir}/libtilelang_module.so",
]
for item in TVM_PREBUILD_ITEMS:
source_lib_file = os.path.join(ROOT_DIR, item)
# only copy the file
file_name = os.path.basename(item)
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, file_name)
target_dir = os.path.dirname(target_dir)
target_dir = os.path.join(target_dir, "lib")
if not os.path.exists(target_dir):
os.makedirs(target_dir)
if os.path.exists(source_lib_file):
shutil.copy2(source_lib_file, target_dir)
# remove the original file
os.remove(source_lib_file)
else:
print(f"INFO: {source_lib_file} does not exist.")
TVM_CONFIG_ITEMS = [
f"{build_temp_dir}/config.cmake",
]
for item in TVM_CONFIG_ITEMS:
source_dir = os.path.join(ROOT_DIR, item)
# only copy the file
file_name = os.path.basename(item)
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, file_name)
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
if os.path.exists(source_dir):
shutil.copy2(source_dir, target_dir)
else:
print(f"INFO: {source_dir} does not exist.")
        TVM_PACKAGE_ITEMS = [
"3rdparty/tvm/src",
"3rdparty/tvm/python",
"3rdparty/tvm/licenses",
"3rdparty/tvm/conftest.py",
"3rdparty/tvm/CONTRIBUTORS.md",
"3rdparty/tvm/KEYS",
"3rdparty/tvm/LICENSE",
"3rdparty/tvm/README.md",
"3rdparty/tvm/mypy.ini",
"3rdparty/tvm/pyproject.toml",
"3rdparty/tvm/version.py",
]
        for item in TVM_PACKAGE_ITEMS:
source_dir = os.path.join(ROOT_DIR, item)
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
if os.path.isdir(source_dir):
self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir)
else:
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
shutil.copy2(source_dir, target_dir)
# Copy CUTLASS to the package directory
CUTLASS_PREBUILD_ITEMS = [
"3rdparty/cutlass",
]
for item in CUTLASS_PREBUILD_ITEMS:
source_dir = os.path.join(ROOT_DIR, item)
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
if os.path.isdir(source_dir):
self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir)
else:
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
shutil.copy2(source_dir, target_dir)
        # copy composable kernel to the package directory
CK_PREBUILD_ITEMS = [
"3rdparty/composable_kernel",
]
for item in CK_PREBUILD_ITEMS:
source_dir = os.path.join(ROOT_DIR, item)
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
if os.path.isdir(source_dir):
self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir)
else:
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
shutil.copy2(source_dir, target_dir)
        # copy TileLang top-level config files to the package directory
TL_CONFIG_ITEMS = ["CMakeLists.txt", "VERSION", "README.md", "LICENSE"]
for item in TL_CONFIG_ITEMS:
source_dir = os.path.join(ROOT_DIR, item)
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
if os.path.isdir(source_dir):
self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir)
else:
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
shutil.copy2(source_dir, target_dir)
class TileLangSdistCommand(sdist):
"""Customized setuptools sdist command - includes the pyproject.toml file."""
def make_distribution(self):
self.distribution.metadata.name = PACKAGE_NAME
self.distribution.metadata.version = get_tilelang_version(
with_cuda=False, with_system_info=False)
super().make_distribution()
class CMakeExtension(Extension):
"""
A specialized setuptools Extension class for building a CMake project.
:param name: Name of the extension module.
:param sourcedir: Directory containing the top-level CMakeLists.txt.
"""
def __init__(self, name, sourcedir=""):
# We pass an empty 'sources' list because
# the actual build is handled by CMake, not setuptools.
super().__init__(name=name, sources=[])
# Convert the source directory to an absolute path
# so that CMake can correctly locate the CMakeLists.txt.
self.sourcedir = os.path.abspath(sourcedir)
class CMakeBuild(build_ext):
"""
Custom build_ext command for CMake-based projects.
This class overrides the 'run' method to ensure that CMake is available,
and then iterates over all extensions defined as CMakeExtension,
delegating the actual build logic to 'build_cmake'.
"""
def run(self):
# Check if CMake is installed and accessible by attempting to run 'cmake --version'.
try:
subprocess.check_output(["cmake", "--version"])
except OSError as e:
# If CMake is not found, raise an error.
raise RuntimeError("CMake must be installed to build the following extensions") from e
update_submodules()
# Build each extension (of type CMakeExtension) using our custom method.
for ext in self.extensions:
self.build_cmake(ext)
def build_cmake(self, ext):
"""
Build a single CMake-based extension.
:param ext: The extension (an instance of CMakeExtension).
"""
# Setup LLVM for TVM and retrieve the path to llvm-config.
# We assume the function returns (_, llvm_config_path).
_, llvm_config_path = setup_llvm_for_tvm()
# Determine the directory where the final .so or .pyd library should go.
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
# Prepare arguments for the CMake configuration step.
# -DCMAKE_LIBRARY_OUTPUT_DIRECTORY sets where built libraries go
# -DPYTHON_EXECUTABLE ensures that the correct Python is used
cmake_args = [
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
f"-DPYTHON_EXECUTABLE={sys.executable}",
]
# Create the temporary build directory (if it doesn't exist).
build_temp = os.path.abspath(self.build_temp)
os.makedirs(build_temp, exist_ok=True)
# Copy the default 'config.cmake' from the source tree into our build directory.
src_config_cmake = os.path.join(ext.sourcedir, "3rdparty", "tvm", "cmake", "config.cmake")
dst_config_cmake = os.path.join(build_temp, "config.cmake")
shutil.copy(src_config_cmake, dst_config_cmake)
# Append some configuration variables to 'config.cmake'.
# Here, we set USE_LLVM and USE_CUDA, for example.
with open(dst_config_cmake, "a") as config_file:
config_file.write(f"set(USE_LLVM {llvm_config_path})\n")
config_file.write("set(USE_CUDA /usr/local/cuda)\n")
# Run CMake to configure the project with the given arguments.
subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp)
# Build the project in "Release" mode with all available CPU cores ("-j").
subprocess.check_call(["cmake", "--build", ".", "--config", "Release", "-j"],
cwd=build_temp)
setup(
name=PACKAGE_NAME,
version=(get_tilelang_version(with_cuda=False, with_system_info=False)
if PYPI_BUILD else get_tilelang_version()),
packages=find_packages(where="."),
package_dir={"": "."},
author="Microsoft Research",
description="A tile level programming language to generate high performance code.",
long_description=read_readme(),
long_description_content_type="text/markdown",
platforms=[
"Environment :: GPU :: NVIDIA CUDA",
"Operating System :: POSIX :: Linux",
],
license="MIT",
keywords="BLAS, CUDA, HIP, Code Generation, TVM",
url="https://github.com/microsoft/TileLang",
classifiers=[
"Programming Language :: Python :: 3.8",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
],
python_requires=">=3.8",
install_requires=get_requirements(),
package_data=package_data,
include_package_data=False,
ext_modules=[CMakeExtension("TileLangCXX", sourcedir=".")],
cmdclass={
"build_py": TileLangBuilPydCommand,
"sdist": TileLangSdistCommand,
"build_ext": CMakeBuild,
},
)
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/ir.cc
* \brief Extension for the tvm script frontend.
*
*/
#include <tvm/script/ir_builder/tir/ir.h>
namespace tvm {
namespace tl {
using namespace script::ir_builder::tir;
ForFrame ParallelFor(Array<PrimExpr> extents,
Map<String, ObjectRef> annotations) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
n->vars.reserve(extents.size());
n->doms.reserve(extents.size());
for (const auto &extent : extents) {
DataType dtype = extent.dtype();
n->vars.push_back(Var("v", extent.dtype()));
n->doms.push_back(Range(make_const(dtype, 0), extent));
}
n->f_make_for_loop = [annotations](Array<Var> vars, Array<Range> doms,
Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size());
int n = vars.size();
for (int i = n - 1; i >= 0; --i) {
Range dom = doms[i];
Var var = vars[i];
body =
For(var, dom->min, dom->extent, ForKind::kParallel, std::move(body),
/*thread_binding=*/NullOpt, /*annotations=*/annotations);
}
return body;
};
return ForFrame(n);
}
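// Illustrative: ParallelFor({M, N}, annotations) creates fresh loop vars over [0, M) and
// [0, N); f_make_for_loop later wraps the body in nested kParallel loops (the last extent
// becomes the innermost loop), attaching `annotations` to every loop.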
ForFrame PipelinedFor(PrimExpr start, PrimExpr stop, int num_stages,
Array<PrimExpr> order, Array<PrimExpr> stages,
Array<Array<PrimExpr>> sync,
Array<Array<PrimExpr>> groups) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
DataType dtype = stop.dtype();
n->vars.push_back(Var("v", dtype));
n->doms.push_back(Range(start, stop));
n->f_make_for_loop = [=](Array<Var> vars, Array<Range> doms,
Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size());
int n = vars.size();
ICHECK(n == 1);
Map<String, ObjectRef> anno;
if (num_stages > 0)
anno.Set("num_stages", PrimExpr(num_stages));
if (order.size() > 0)
anno.Set("tl_pipeline_order", order);
if (stages.size() > 0)
anno.Set("tl_pipeline_stage", stages);
if (sync.size() > 0)
anno.Set("tl_pipeline_sync", sync);
if (groups.size() > 0)
anno.Set("tl_pipeline_group", groups);
body = For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial,
std::move(body),
/*thread_binding=*/NullOpt, /*annotations=*/anno);
return body;
};
return ForFrame(n);
}
/*!
* \brief A frame that represents a kernel launch.
*
* \sa KernelLaunchFrameNode
*/
class KernelLaunchFrameNode : public TIRFrameNode {
public:
Array<TIRFrame> frames;
void VisitAttrs(tvm::AttrVisitor *v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("frames", &frames);
}
static constexpr const char *_type_key = "tl.KernelLaunchFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(KernelLaunchFrameNode, TIRFrameNode);
public:
TVM_DLL void EnterWithScope() final {
for (auto frame = frames.begin(); frame != frames.end(); ++frame)
(*frame)->EnterWithScope();
}
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
TVM_DLL void ExitWithScope() final {
for (auto frame = frames.rbegin(); frame != frames.rend(); ++frame)
(*frame)->ExitWithScope();
}
};
/*!
* \brief Managed reference to KernelLaunchFrameNode.
*
* \sa KernelLaunchFrameNode
*/
class KernelLaunchFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(KernelLaunchFrame, TIRFrame,
KernelLaunchFrameNode);
};
KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
Array<PrimExpr> block_size,
Map<String, ObjectRef> attrs) {
ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>();
ICHECK(grid_size.size() <= 3);
if (grid_size.size() > 0)
n->frames.push_back(LaunchThread("blockIdx.x", grid_size[0]));
if (grid_size.size() > 1)
n->frames.push_back(LaunchThread("blockIdx.y", grid_size[1]));
if (grid_size.size() > 2)
n->frames.push_back(LaunchThread("blockIdx.z", grid_size[2]));
if (block_size.defined()) {
ICHECK(block_size.size() <= 3);
if (block_size.size() > 0)
n->frames.push_back(LaunchThread("threadIdx.x", block_size[0]));
if (block_size.size() > 1)
n->frames.push_back(LaunchThread("threadIdx.y", block_size[1]));
if (block_size.size() > 2)
n->frames.push_back(LaunchThread("threadIdx.z", block_size[2]));
} else {
n->frames.push_back(Block(""));
}
if (attrs.defined()) {
auto empty_block = Block("");
empty_block->annotations = attrs;
n->frames.push_back(empty_block);
} else {
n->frames.push_back(Block(""));
}
return KernelLaunchFrame(n);
}
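// Illustrative: KernelLaunch({gx, gy}, {tx}, attrs) stacks LaunchThread frames for
// blockIdx.x, blockIdx.y and threadIdx.x, followed by an (optionally annotated) empty
// Block frame; EnterWithScope enters the frames in order and ExitWithScope unwinds them
// in reverse.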
TVM_REGISTER_NODE_TYPE(KernelLaunchFrameNode);
TVM_REGISTER_GLOBAL("tl.Parallel").set_body_typed(ParallelFor);
TVM_REGISTER_GLOBAL("tl.Pipelined").set_body_typed(PipelinedFor);
TVM_REGISTER_GLOBAL("tl.KernelLaunch").set_body_typed(KernelLaunch);
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file layout/gemm_layouts.cc
* \brief Define Layout used in MMA and other operations.
*
*/
#include <tvm/tir/stmt_functor.h>
#include <cmath>
#include "layout.h"
namespace tvm {
namespace tl {
static IterVar make_itervar(std::string name, PrimExpr dom) {
Var var = Var(name);
return IterVar(Range(0, dom), var, IterVarType::kDataPar);
}
Fragment makeGemmFragment8x8() {
IterVar i = make_itervar("i", 8);
IterVar j = make_itervar("j", 8);
IterVar rep = make_itervar("rep", 1);
PrimExpr forward_thread = FloorDiv(j->var, 2) + 4 * i;
PrimExpr index = FloorMod(j->var, 2);
return Fragment({i, j}, {index}, forward_thread, rep);
}
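// Worked example for the 8x8 fragment above: element (i=3, j=5) maps to
// thread j/2 + 4*i = 2 + 12 = 14 and local register index j%2 = 1.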
/*
From https://github.com/RadeonOpenCompute/amd_matrix_instruction_calculator
./matrix_calculator.py --architecture cdna1 --instruction v_mfma_f32_16x16x16f16
--detail-instruction
*/
Fragment makeGemmFragmentAB16x16CDNA() {
IterVar i = make_itervar("i", 16);
IterVar j = make_itervar("j", 16);
IterVar rep = make_itervar("rep", 1);
PrimExpr forward_thread = 16 * FloorDiv(j->var, 4) + i;
PrimExpr index = FloorMod(j->var, 4);
return Fragment({i, j}, {index}, forward_thread, rep);
}
Fragment makeGemmFragmentAB16x16CDNATransposed() {
IterVar i = make_itervar("i", 16);
IterVar j = make_itervar("j", 16);
IterVar rep = make_itervar("rep", 1);
PrimExpr forward_thread = 16 * FloorDiv(i->var, 4) + j;
PrimExpr index = FloorMod(i->var, 4);
return Fragment({i, j}, {index}, forward_thread, rep);
}
Fragment makeGemmFragmentC16x16CDNA() {
IterVar i = make_itervar("i", 16);
IterVar j = make_itervar("j", 16);
IterVar rep = make_itervar("rep", 1);
PrimExpr forward_thread = 16 * FloorDiv(j->var, 4) + i;
PrimExpr index = FloorMod(j->var, 4);
return Fragment({i, j}, {index}, forward_thread, rep);
}
Fragment makeGemmFragment8x8Transposed() {
IterVar i = make_itervar("i", 8);
IterVar j = make_itervar("j", 8);
IterVar rep = make_itervar("rep", 1);
PrimExpr forward_thread = FloorDiv(i->var, 2) + 4 * j;
PrimExpr index = FloorMod(i->var, 2);
return Fragment({i, j}, {index}, forward_thread, rep);
}
Fragment makeGemmFragment8x16() {
IterVar i = make_itervar("i", 8);
IterVar j = make_itervar("j", 16);
IterVar rep = make_itervar("rep", 1);
PrimExpr forward_thread = FloorDiv(j->var, 4) + 4 * i;
PrimExpr index = FloorMod(j->var, 4);
return Fragment({i, j}, {index}, forward_thread, rep);
}
Fragment makeGemmFragmentC_F64(const int block_m, const int block_n, const int warp_m,
const int warp_n) {
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0);
ICHECK(warp_n % 16 == 0);
auto base_layout = makeGemmFragment8x8();
auto warp_layout = base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
auto block_layout = warp_layout->Repeat({warp_m / 8, warp_n / 8}, false, false);
return block_layout;
}
Fragment makeGemmFragmentC(const int block_m, const int block_n, const int warp_m, const int warp_n,
const int element_size) {
if (element_size == 64) return makeGemmFragmentC_F64(block_m, block_n, warp_m, warp_n);
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
auto base_layout = makeGemmFragment8x8()->Repeat({2, 1}, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
auto block_layout = warp_layout->Repeat({warp_m / 16, warp_n / 8}, false, false);
return block_layout;
}
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n, const int warp_m, const int warp_n,
const int element_size) {
if (element_size == 64) LOG(FATAL) << "Not supported";
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
auto base_layout = makeGemmFragmentC16x16CDNA()->Repeat({1, 1}, false);
auto warp_layout = base_layout->Repeat({warp_m / 16, warp_n / 16}, false, true);
auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
return block_layout;
}
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n, const int warp_m,
const int warp_n, const int element_size) {
ICHECK(block_m % warp_m == 0);
// ICHECK(block_n == warp_n);
ICHECK(warp_m % 16 == 0);
auto warp_layout =
makeGemmFragment8x8()->Repeat({2, warp_n / 8}, false, false); // 16 x N (1 warp)
auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false); // 16*Y x N (Y warp)
return block_layout->Repeat({warp_m / 16, 1}, false, false);
}
Fragment makeGemmFragmentA(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n, const int element_size) {
// assume not transposed
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0);
ICHECK(block_k % 16 == 0);
// Only support 8-bit and 16-bit
ICHECK(element_size == 8 || element_size == 16);
if (element_size == 8) {
auto base_layout = makeGemmFragment8x16()->Repeat({2, 2}, false, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)->Replicate(block_n / warp_n);
auto block_layout = warp_layout->Repeat({warp_m / 16, block_k / 32}, false, false);
return block_layout;
} else if (element_size == 16) {
auto base_layout = makeGemmFragment8x8()->Repeat({2, 2}, false, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)->Replicate(block_n / warp_n);
auto block_layout = warp_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
return block_layout;
} else {
ICHECK(0);
return Fragment();
}
}
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n, bool transposed) {
  // handles both transposed and non-transposed A operands
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0);
ICHECK(block_k % 16 == 0);
if (transposed) {
auto base_layout = makeGemmFragmentAB16x16CDNATransposed()->Repeat({1, 1}, false, false);
auto warp_layout = base_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true)->Replicate(block_n / warp_n);
return block_layout;
} else {
auto base_layout = makeGemmFragmentAB16x16CDNA()->Repeat({1, 1}, false, false);
auto warp_layout = base_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
auto block_layout =
warp_layout->Repeat({block_m / warp_m, 1}, true, true)->Replicate(block_n / warp_n);
return block_layout;
}
}
Fragment makeGemmFragmentB(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n) {
// transposed
ICHECK(warp_n % 8 == 0);
ICHECK(block_k % 16 == 0);
auto base_layout = makeGemmFragment8x8Transposed()->Repeat({2, 1}, false, false);
auto warp_layout = base_layout->Replicate(block_m / warp_m)->Repeat({1, block_n / warp_n}, true);
auto block_layout = warp_layout->Repeat({block_k / 16, warp_n / 8}, false, true);
return block_layout;
}
Fragment makeGemmFragment32x32(int element_size) {
IterVar i = make_itervar("i", 32);
IterVar j = make_itervar("j", 32);
IterVar rep = make_itervar("rep", 1);
ICHECK(element_size == 16 || element_size == 32);
if (element_size == 16) {
PrimExpr thd = FloorMod(i, 4) + FloorDiv(FloorMod(i, 16), 8) * 4 +
FloorDiv(FloorMod(j, 16), 8) * 8 + FloorDiv(i, 16) * 16;
PrimExpr idx = FloorMod(j, 4) + FloorDiv(j, 16) * 4 + FloorDiv(FloorMod(i, 8), 4) * 8 +
FloorDiv(FloorMod(j, 8), 4) * 16;
return Fragment({i, j}, {idx}, thd, rep);
} else {
PrimExpr thd = FloorMod(i, 2) + 2 * FloorDiv(FloorMod(j, 4), 2) +
FloorDiv(FloorMod(i, 16), 8) * 4 + FloorDiv(FloorMod(j, 16), 8) * 8 +
FloorDiv(i, 16) * 16;
PrimExpr idx = FloorMod(j, 2) + 2 * FloorDiv(FloorMod(i, 4), 2) + FloorDiv(j, 16) * 4 +
FloorDiv(FloorMod(i, 8), 4) * 8 + FloorDiv(FloorMod(j, 8), 4) * 16;
return Fragment({i, j}, {idx}, thd, rep);
}
}
Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n, const int warp_m,
const int warp_n, int element_size) {
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 32 == 0);
ICHECK(warp_n % 32 == 0);
auto base_layout = makeGemmFragment32x32(element_size);
auto warp_layout = base_layout->Repeat({warp_m / 32, warp_n / 32}, false, false);
auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true);
return block_layout;
}
Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n) {
// assume not transposed
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 32 == 0);
ICHECK(block_k % 4 == 0);
// this is a special case
IterVar i = make_itervar("i", 32);
IterVar j = make_itervar("j", 4);
IterVar rep = make_itervar("rep", 2);
PrimExpr thd = FloorDiv(FloorMod(i, 16), 8) * 4 + 16 * FloorDiv(i, 16) + FloorMod(i, 4) + 8 * rep;
PrimExpr idx = j + FloorDiv(FloorMod(i, 8), 4) * 4;
Fragment base_layout = Fragment({i, j}, {idx}, thd, rep);
auto warp_layout = base_layout->Repeat({warp_m / 32, block_k / 4}, false, false);
auto block_layout = warp_layout->Replicate(block_n / warp_n)->Repeat({block_m / warp_m, 1}, true);
return block_layout;
}
PrimExpr xor2x2(const PrimExpr& i, const PrimExpr& j) { return FloorMod(i + j, 2); }
PrimExpr xor4x4(const PrimExpr& i, const PrimExpr& j) {
PrimExpr i0 = FloorMod(i, 2);
PrimExpr j0 = FloorMod(j, 2);
PrimExpr i1 = FloorDiv(i, 2);
PrimExpr j1 = FloorDiv(j, 2);
return 2 * xor2x2(i1, j1) + xor2x2(i0, j0);
}
PrimExpr xor8x8(const PrimExpr& i, const PrimExpr j) {
PrimExpr i0 = FloorMod(i, 2);
PrimExpr j0 = FloorMod(j, 2);
PrimExpr i1 = FloorDiv(i, 2);
PrimExpr j1 = FloorDiv(j, 2);
return 2 * xor4x4(i1, j1) + xor2x2(i0, j0);
}
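// These helpers reproduce bitwise XOR on 1/2/3-bit indices via mod-2 arithmetic, e.g.
// xor8x8(5, 3) = 2 * xor4x4(2, 1) + xor2x2(1, 1) = 2 * 3 + 0 = 6 == 5 ^ 3.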
Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size) {
// Swizzle 2 bit
Var i = InputPlaceholder(0);
Var j = InputPlaceholder(1);
int vector_size = 128 / element_size;
ICHECK(stride % 8 == 0);
ICHECK(continuous % (vector_size * 4) == 0);
PrimExpr ts = FloorDiv(i, 8);
PrimExpr s = FloorMod(i, 8);
PrimExpr tc = FloorDiv(FloorDiv(j, vector_size), 4);
PrimExpr c = FloorMod(FloorDiv(j, vector_size), 4);
PrimExpr vec = FloorMod(j, vector_size);
PrimExpr c_swizzle = xor4x4(c, FloorDiv(s, 2));
PrimExpr index = vec + (c_swizzle + s * 4) * vector_size;
return Layout(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
}
Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size) {
// Swizzle 3 bit
Var i = InputPlaceholder(0);
Var j = InputPlaceholder(1);
int vector_size = 128 / element_size;
ICHECK(stride % 8 == 0);
ICHECK(continuous % (vector_size * 8) == 0);
PrimExpr ts = FloorDiv(i, 8);
PrimExpr s = FloorMod(i, 8);
PrimExpr tc = FloorDiv(FloorDiv(j, vector_size), 8);
PrimExpr c = FloorMod(FloorDiv(j, vector_size), 8);
PrimExpr vec = FloorMod(j, vector_size);
PrimExpr c_swizzle = xor8x8(c, s);
PrimExpr index = vec + (c_swizzle + s * 8) * vector_size;
return Layout(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
}
// For implementation details, see bitblas::tl::mfma_layout::make_mfma_swizzle_layout
Layout makeMatrixCoreSwizzleLayout(int stride, int continuous, int element_size, int kPack=1) {
const int numBanks = 32;
const int bankBitWidth = 32;
const int SIMDWidth = 16;
const int vecSize = 4 * kPack;
const int innerDimLength = continuous;
const int typeWidthInBit = element_size;
const int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
const int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
const int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
IterVar row = make_itervar("row", stride);
IterVar col = make_itervar("col", continuous);
PrimExpr phase = FloorMod(row / perPhase, maxPhase);
PrimExpr colOffSwizzled = ((col / vecSize) ^ phase) * vecSize;
PrimExpr colOffOrdered = FloorMod(col, vecSize);
PrimExpr colOff = colOffSwizzled + colOffOrdered;
return Layout(Array{row, col}, {row, colOff});
}
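// Worked example of the constants above, assuming fp16 (element_size = 16),
// kPack = 1 and continuous = 64: elemsPerOneBanksRow = 32 * 32 / 16 = 64, so
// perPhase = 1 and maxPhase = min(16, 64 / 4) = 16; each row then XORs its
// 4-element vector-column index with (row % 16) before scaling back to elements.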
Layout makeGemmABLayoutF64_Kinner(int stride, int continuous) {
// Swizzle<2, 0, 4>
Var i = InputPlaceholder(0);
Var j = InputPlaceholder(1);
PrimExpr tc = FloorDiv(j, 16);
PrimExpr ts = FloorDiv(i, 4);
PrimExpr c = FloorMod(j, 16);
PrimExpr s = FloorMod(i, 4);
PrimExpr swizzled_c = FloorDiv(c, 4) * 4 + xor4x4(FloorMod(c, 4), s);
PrimExpr index = swizzled_c + s * 16;
return Layout(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
}
Layout makeGemmABLayoutF64_Kouter(int stride, int continuous) {
// Swizzle<2, 2, 2>
Var i = InputPlaceholder(0);
Var j = InputPlaceholder(1);
PrimExpr tc = FloorDiv(j, 16);
PrimExpr ts = FloorDiv(i, 4);
PrimExpr c = FloorMod(j, 16);
PrimExpr s = FloorMod(i, 4);
PrimExpr swizzled_c = FloorMod(c, 4) + xor4x4(FloorDiv(c, 4), s) * 4;
PrimExpr index = swizzled_c + s * 16;
return Layout(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
}
// The default linear (row-major) layout for tensor access
Layout makeGemmLayoutLinear(int stride, int continuous) {
IterVar i = make_itervar("i", stride);
IterVar j = make_itervar("j", continuous);
return Layout(Array{i, j}, {i * continuous + j});
}
Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size) {
IterVar i = make_itervar("i", stride);
IterVar j = make_itervar("j", continuous);
int padded = continuous;
// Add 128 bits of padding when the last dim is a multiple of 256 bits
if ((element_size * continuous) % 256 == 0) padded += 128 / element_size;
return Layout(Array{i, j}, {i * padded + j});
}
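// Example of the padding rule, assuming fp16 (element_size = 16) and
// continuous = 64: 16 * 64 = 1024 bits is a multiple of 256, so the row stride is
// padded from 64 to 64 + 128 / 16 = 72 elements and the offset becomes
// i * 72 + j, which staggers rows across shared-memory banks.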
Layout MakeGemmVoltaABLayoutCrosswise(int stride, int continuous) {
ICHECK(stride % 32 == 0 && continuous % 32 == 0);
IterVar i = make_itervar("i", stride);
IterVar j = make_itervar("j", continuous);
PrimExpr vec_contiguous_idx = FloorDiv(j, 4);
PrimExpr vec_strided_within_tile = FloorMod(vec_contiguous_idx, 8);
PrimExpr bit2 = FloorMod(FloorDiv(FloorMod(i, 32), 16) + FloorDiv(FloorMod(i, 16), 8) +
FloorDiv(vec_strided_within_tile, 4),
2);
PrimExpr bit1 =
xor2x2(FloorDiv(FloorMod(i, 8), 4), FloorDiv(FloorMod(vec_strided_within_tile, 4), 2));
PrimExpr permuted_vec_contiguous = FloorDiv(i, 16) * 16 + FloorMod(i, 4) * 4 + bit2 * 2 + bit1;
PrimExpr offset = FloorMod(j, 4) + permuted_vec_contiguous * 4 + vec_contiguous_idx * stride * 4;
return Layout(Array{i, j}, {offset});
}
Layout MakeGemmVoltaALayoutCongruous(int stride, int continuous) {
ICHECK(stride % 4 == 0 && continuous % 64 == 0);
IterVar i = make_itervar("i", stride);
IterVar j = make_itervar("j", continuous);
PrimExpr vec_contiguous_idx = FloorDiv(j, 8);
PrimExpr vec_strided_idx = i;
PrimExpr tile_contiguous_idx = FloorDiv(vec_contiguous_idx, 8);
PrimExpr tile_strided_idx = FloorDiv(vec_strided_idx, 4);
PrimExpr tile_contiguous_residual = FloorMod(vec_contiguous_idx, 8);
PrimExpr tile_strided_residual = FloorMod(vec_strided_idx, 4);
PrimExpr permuted_strided_within_tile = FloorDiv(tile_contiguous_residual, 2);
PrimExpr permuted_contiguous_within_tile =
FloorMod(tile_contiguous_residual, 2) * 4 +
xor4x4(tile_strided_residual, permuted_strided_within_tile);
PrimExpr element_strided = permuted_strided_within_tile + tile_strided_idx * 4;
PrimExpr element_contiguous =
FloorMod(j, 8) + (permuted_contiguous_within_tile + tile_contiguous_idx * 8) * 8;
PrimExpr offset = element_strided * continuous + element_contiguous;
return Layout(Array{i, j}, {offset});
}
Layout MakeGemmVoltaBLayoutCongruous(int stride, int continuous) {
ICHECK(stride % 4 == 0 && continuous % 64 == 0);
IterVar i = make_itervar("i", stride);
IterVar j = make_itervar("j", continuous);
PrimExpr vec_contiguous_idx = FloorDiv(j, 8);
PrimExpr vec_strided_idx = i;
PrimExpr tile_contiguous_idx = FloorDiv(vec_contiguous_idx, 8);
PrimExpr tile_strided_idx = FloorDiv(vec_strided_idx, 4);
PrimExpr tile_contiguous_residual = FloorMod(vec_contiguous_idx, 8);
PrimExpr tile_strided_residual = FloorMod(vec_strided_idx, 4);
PrimExpr permuted_strided_within_tile = FloorMod(tile_contiguous_residual, 4);
PrimExpr permuted_contiguous_within_tile =
FloorDiv(tile_contiguous_residual, 4) * 4 +
xor4x4(tile_strided_residual, permuted_strided_within_tile);
PrimExpr element_strided = permuted_strided_within_tile + tile_strided_idx * 4;
PrimExpr element_contiguous =
FloorMod(j, 8) + (permuted_contiguous_within_tile + tile_contiguous_idx * 8) * 8;
PrimExpr offset = element_strided * continuous + element_contiguous;
return Layout(Array{i, j}, {offset});
}
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a, int kfactor) {
if (kfactor == 2) return MakeGemmVoltaABLayoutCrosswise(stride, continuous);
if (is_a && continuous % 64 == 0) return MakeGemmVoltaALayoutCongruous(stride, continuous);
if (!is_a && continuous % 64 == 0) return MakeGemmVoltaBLayoutCongruous(stride, continuous);
return makeGemmABLayoutPadded(stride, continuous, 16);
}
Layout makeGemmABLayout(int stride, int continuous, int element_size, int kfactor) {
if (element_size == 64) {
if (kfactor == 1 && continuous % 16 == 0) // float64 KxN
return makeGemmABLayoutF64_Kouter(stride, continuous);
if (kfactor == 2 && continuous % 16 == 0) // float64 NxK
return makeGemmABLayoutF64_Kinner(stride, continuous);
return makeGemmABLayoutPadded(stride, continuous, element_size);
}
int vector_size = 128 / element_size;
if (kfactor == 1 && element_size == 8) // int8 KxN
return makeGemmABLayoutPadded(stride, continuous, element_size);
else if (continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(stride, continuous, element_size);
else if (continuous % (vector_size * 4) == 0)
return makeHalfBankSwizzleLayout(stride, continuous, element_size);
else {
return makeGemmABLayoutPadded(stride, continuous, element_size);
}
}
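// Dispatch summary for the non-fp64 path above, e.g. fp16 (vector_size = 8):
//   continuous % 64 == 0 -> full-bank swizzle (3-bit)
//   continuous % 32 == 0 -> half-bank swizzle (2-bit)
//   otherwise            -> padded linear layout
// int8 with kfactor == 1 always falls back to the padded layout.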
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, int kPack) {
int vector_size = 128 / element_size;
if (continuous % (vector_size * 4) == 0)
return makeMatrixCoreSwizzleLayout(stride, continuous, element_size, kPack);
else {
return makeGemmABLayoutPadded(stride, continuous, element_size);
}
}
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
 * \file layout/layout.cc
 * \brief Implementation of the Layout and Fragment nodes.
 */
#include "layout.h"
#include <tvm/arith/pattern.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include "arith/pattern_match.h"
#include "utils.h"
namespace tvm {
namespace tl {
using namespace tir;
static Var getPlaceholder(const std::string& s) {
static std::unordered_map<std::string, Var> map;
if (map.find(s) == map.end()) {
map[s] = Var(s);
}
return map[s];
}
Var ReplicationPlaceholder() { return getPlaceholder("_rep"); }
Var InputPlaceholder(size_t idx) { return getPlaceholder(std::string{'_', char('i' + idx)}); }
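// The placeholders are cached by name, so repeated calls return the same Var and
// expressions built independently still substitute consistently, e.g.:
//   InputPlaceholder(0)       // Var("_i")
//   InputPlaceholder(1)       // Var("_j")
//   ReplicationPlaceholder()  // Var("_rep")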
Map<Var, Range> LayoutNode::getVarMap() const {
Map<Var, Range> map;
for (size_t i = 0; i < InputDim(); i++) {
map.Set(InputPlaceholder(i), {0, input_size_[i]});
}
return map;
}
Map<Var, Range> FragmentNode::getVarMap() const {
auto map = LayoutNode::getVarMap();
map.Set(ReplicationPlaceholder(), {0, ReplicateExtent()});
return map;
}
LayoutNode::LayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
input_size_ = input_size;
arith::Analyzer analyzer;
UpdateAnalyzer(&analyzer);
forward_index_ = forward_index.Map([&](const PrimExpr& e) { return analyzer.Simplify(e); });
}
Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
Map<Var, PrimExpr> vmap;
Array<PrimExpr> input_size;
for (size_t i = 0; i < forward_var.size(); i++) {
vmap.Set(forward_var[i]->var, InputPlaceholder(i));
CHECK(is_zero(forward_var[i]->dom->min));
input_size.push_back(forward_var[i]->dom->extent);
}
forward_index = forward_index.Map([&](const PrimExpr& e) { return Substitute(e, vmap); });
auto n = make_object<LayoutNode>(input_size, forward_index);
data_ = std::move(n);
}
Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
auto n = make_object<LayoutNode>(input_size, forward_index);
data_ = std::move(n);
}
void LayoutNode::VisitAttrs(AttrVisitor* v) {
v->Visit("input_size", &input_size_);
v->Visit("forward_index", &forward_index_);
}
void LayoutNode::UpdateAnalyzer(arith::Analyzer* analyzer) const {
for (const auto& [var, dom] : getVarMap()) {
analyzer->Bind(var, dom);
}
}
Array<PrimExpr> LayoutNode::OutputShape() const {
Array<PrimExpr> ret(OutputDim(), 1);
arith::Analyzer analyzer;
UpdateAnalyzer(&analyzer);
for (size_t i = 0; i < ret.size(); i++) {
auto ist = analyzer.int_set(forward_index_[i] + 1);
if (arith::is_neg_inf(ist.min()) && arith::is_pos_inf(ist.max())) {
// XOR-style expressions have unbounded int sets; fall back to the input extent
ret.Set(i, input_size_[i]);
} else {
CHECK(is_one(ist.min())) << ist.min();
ret.Set(i, ist.max());
}
}
return ret;
}
Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr>& vars) const {
if (vars.empty()) return forward_index_;
ICHECK_EQ(vars.size(), InputDim());
Map<Var, PrimExpr> vmap;
for (size_t i = 0; i < InputDim(); i++) {
vmap.Set(InputPlaceholder(i), vars[i]);
}
return forward_index_.Map([&](const PrimExpr& e) { return Substitute(e, vmap); });
}
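// Minimal sketch of Forward(): for the linear layout built by
// makeGemmLayoutLinear(8, 16) (forward index _i * 16 + _j), Forward({2, 3})
// substitutes the placeholders and yields {2 * 16 + 3}, i.e. 35; an empty argument
// array returns the stored forward_index_ unchanged.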
Fragment FragmentNode::Repeat(const Array<PrimExpr>& repeats, bool repeat_on_thread,
bool lower_dim_first) const {
ICHECK_EQ(repeats.size(), InputDim());
Array<PrimExpr> new_input_size;
Map<Var, PrimExpr> vmap;
for (size_t i = 0; i < InputDim(); i++) {
new_input_size.push_back(input_size_[i] * repeats[i]);
vmap.Set(InputPlaceholder(i), FloorMod(InputPlaceholder(i), InputShape()[i]));
}
PrimExpr repeats_index = 0, repeat_stride = 1;
if (lower_dim_first) {
for (int i = InputDim() - 1; i >= 0; i--) {
repeats_index += repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]);
repeat_stride *= repeats[i];
}
} else {
for (size_t i = 0; i < InputDim(); i++) {
repeats_index += repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]);
repeat_stride *= repeats[i];
}
}
if (repeat_on_thread) {
PrimExpr thread_size = ThreadExtent();
auto new_forward_index =
forward_index_.Map([&](const PrimExpr& e) { return Substitute(e, vmap); });
auto new_forward_thread = Substitute(forward_thread_, vmap) + thread_size * repeats_index;
return Fragment(new_input_size, new_forward_index, new_forward_thread, replicate_size_,
NullOpt);
} else {
ICHECK(OutputDim() == 1);
PrimExpr frag_len = OutputShape()[0];
Array<PrimExpr> new_forward_index = {Substitute(forward_index_[0], vmap) +
frag_len * repeats_index};
PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
return Fragment(new_input_size, new_forward_index, new_forward_thread, replicate_size_,
NullOpt);
}
}
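// Sketch of the two repeat modes, for an illustrative base fragment with
// InputShape {16, 16}, ThreadExtent 32 and 8 elements per thread:
//   Repeat({2, 2}, /*repeat_on_thread=*/true)  -> shape {32, 32}, 128 threads,
//                                                 still 8 elements per thread
//   Repeat({2, 2}, /*repeat_on_thread=*/false) -> shape {32, 32}, 32 threads,
//                                                 32 elements per thread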
Fragment FragmentNode::Replicate(int repeats) const {
ICHECK(repeats >= 1);
Map<Var, PrimExpr> vmap;
vmap.Set(ReplicationPlaceholder(), FloorMod(ReplicationPlaceholder(), ReplicateExtent()));
PrimExpr new_forward_thread =
Substitute(forward_thread_, vmap) +
ThreadExtent() * FloorDiv(ReplicationPlaceholder(), ReplicateExtent());
return Fragment(input_size_, forward_index_, new_forward_thread, ReplicateExtent() * repeats,
NullOpt);
}
Fragment FragmentNode::DeReplicate() const {
ICHECK(OutputDim() == 1);
arith::Analyzer analyzer;
UpdateAnalyzer(&analyzer);
int factor = 1;
auto rep_size = as_const_int(ReplicateExtent());
auto idx_size = as_const_int(OutputShape()[0]);
if (rep_size && idx_size) {
factor = arith::ZeroAwareGCD(*rep_size, *idx_size);
}
if (factor == 1) return GetRef<Fragment>(this);
Map<Var, PrimExpr> vmap;
vmap.Set(ReplicationPlaceholder(),
ReplicationPlaceholder() * factor + FloorMod(forward_index_[0], factor));
PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
Array<PrimExpr> new_forward_index = {FloorDiv(forward_index_[0], factor)};
return Fragment(input_size_, new_forward_index, new_forward_thread, int(*rep_size) / factor,
NullOpt);
}
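// Example of DeReplicate(): for a fragment with ReplicateExtent 4 and 8 elements
// per thread, factor = gcd(4, 8) = 4, so the low two bits of the local index are
// folded into the replication variable, leaving 2 elements per thread and a
// replication extent of 1.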
Layout LayoutNode::Inverse() const {
arith::Analyzer analyzer;
arith::IterMapResult res = arith::DetectIterMap(forward_index_, getVarMap(), 1,
arith::IterMapLevel::Bijective, &analyzer);
ICHECK(res->errors.empty()) << res->errors;
auto outputs_shape = OutputShape();
Array<PrimExpr> outputs;
for (size_t i = 0; i < OutputDim(); i++) {
outputs.push_back(InputPlaceholder(i));
}
auto inv = arith::InverseAffineIterMap(res->indices, outputs);
Array<PrimExpr> backward_index;
for (size_t i = 0; i < InputDim(); i++) {
if (inv.find(InputPlaceholder(i)) != inv.end()) {
backward_index.push_back(inv[InputPlaceholder(i)]);
} else {
backward_index.push_back(0);
}
}
return Layout(outputs_shape, backward_index);
}
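// Sketch of Inverse() on a bijective layout: the row-major 8x16 layout
// (i, j) -> i * 16 + j inverts to a layout over the flat index k in [0, 128)
// with backward index {floordiv(k, 16), floormod(k, 16)}; non-bijective layouts
// fail the ICHECK on the iter-map errors above.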
PrimExpr infer_fragment_index(const Map<Var, Range>& input_iters, const PrimExpr& forward_thread,
arith::Analyzer* analyzer) {
Array<arith::IterSplitExpr> splits =
DivideUnusedIterators({forward_thread}, ToIterVars(input_iters), analyzer);
Array<arith::IterSplitExpr> split_without_rep;
for (const auto& split : splits) {
CHECK(split->source->source.as<Var>());
if (split->source->source.as<Var>().value().same_as(ReplicationPlaceholder())) continue;
split_without_rep.push_back(split);
}
return MakeFlattenedExpression(split_without_rep);
}
FragmentNode::FragmentNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
PrimExpr forward_thread, PrimExpr replicate_size) {
input_size_ = input_size;
replicate_size_ = replicate_size;
arith::Analyzer analyzer;
UpdateAnalyzer(&analyzer);
forward_thread_ = analyzer.Simplify(forward_thread);
if (forward_index.empty()) {
forward_index = {infer_fragment_index(getVarMap(), forward_thread_, &analyzer)};
}
forward_index_ = forward_index.Map([&](const PrimExpr& e) { return analyzer.Simplify(e); });
}
Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
PrimExpr forward_thread, IterVar thread_replicate) {
Map<Var, PrimExpr> vmap;
Array<PrimExpr> input_size;
PrimExpr replicate_size = 1;
for (size_t i = 0; i < forward_var.size(); i++) {
vmap.Set(forward_var[i]->var, InputPlaceholder(i));
CHECK(is_zero(forward_var[i]->dom->min));
input_size.push_back(forward_var[i]->dom->extent);
}
if (thread_replicate.defined()) {
ICHECK(is_zero(thread_replicate->dom->min));
replicate_size = thread_replicate->dom->extent;
vmap.Set(thread_replicate->var, ReplicationPlaceholder());
}
forward_index = forward_index.Map([&](const PrimExpr& e) { return Substitute(e, vmap); });
forward_thread = Substitute(forward_thread, vmap);
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread, replicate_size);
data_ = std::move(n);
}
Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
PrimExpr forward_thread, PrimExpr replicate_size, Optional<Var> replicate_var) {
if (replicate_var.defined()) {
forward_thread =
Substitute(forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}});
}
auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread, replicate_size);
data_ = std::move(n);
}
void FragmentNode::VisitAttrs(tvm::AttrVisitor* v) {
LayoutNode::VisitAttrs(v);
v->Visit("forward_thread", &forward_thread_);
v->Visit("replicate_size", &replicate_size_);
}
PrimExpr FragmentNode::ThreadExtent() const {
arith::Analyzer analyzer;
UpdateAnalyzer(&analyzer);
auto ist = analyzer.int_set(forward_thread_ + 1);
CHECK(is_one(ist.min()));
return ist.max();
}
PrimExpr FragmentNode::ForwardThread(const Array<PrimExpr>& vars,
const Optional<PrimExpr>& rep_var) const {
Map<Var, PrimExpr> vmap;
ICHECK_EQ(vars.size(), InputDim());
for (size_t i = 0; i < InputDim(); i++) {
vmap.Set(InputPlaceholder(i), vars[i]);
}
if (rep_var.defined()) vmap.Set(ReplicationPlaceholder(), rep_var.value());
return Substitute(forward_thread_, vmap);
}
Layout FragmentNode::Inverse() const {
auto input_size_copy = input_size_;
input_size_copy.push_back(ReplicateExtent());
auto forward_index_copy = forward_index_;
forward_index_copy.push_back(
Substitute(forward_thread_, {{ReplicationPlaceholder(), InputPlaceholder(InputDim())}}));
auto fwd = Layout(input_size_copy, forward_index_copy);
auto bwd = fwd->Inverse();
return bwd;
}
Fragment FragmentNode::CondenseReplicateVar() const {
arith::Analyzer analyzer;
auto input_iters = getVarMap();
input_iters.Set(ReplicationPlaceholder(), {0, ReplicateExtent()});
PrimExpr new_forward_thread;
IterVar new_thread_replicate;
std::tie(new_forward_thread, new_thread_replicate) = CompressIterator(
forward_thread_, ToIterVars(input_iters), ReplicationPlaceholder(), &analyzer);
return Fragment(input_size_, forward_index_, new_forward_thread,
new_thread_replicate->dom->extent, new_thread_replicate->var);
}
void LayoutNode::DebugOutput() const {
LOG_DEBUG << "Layout Shape: " << InputShape() << " -> " << OutputShape();
LOG_DEBUG << "Layout Index: " << forward_index_;
}
void FragmentNode::DebugOutput() const {
LOG_DEBUG << "Fragment Shape: " << InputShape() << " -> " << OutputShape();
LOG_DEBUG << "Fragment Replicate: " << ReplicateExtent();
LOG_DEBUG << "Fragment ThreadExtent: " << ThreadExtent();
LOG_DEBUG << "Fragment Index: " << forward_index_;
LOG_DEBUG << "Fragment ThreadIndex: " << forward_thread_;
}
bool LayoutNode::SEqualReduce(const LayoutNode* other, SEqualReducer equal) const {
return equal(this->InputShape(), other->InputShape()) &&
equal(this->forward_index_, other->forward_index_);
}
bool FragmentNode::SEqualReduce(const FragmentNode* other, SEqualReducer equal) const {
return equal(this->ReplicateExtent(), other->ReplicateExtent()) &&
equal(this->InputShape(), other->InputShape()) &&
equal(this->ThreadExtent(), other->ThreadExtent()) &&
equal(this->forward_index_, other->forward_index_) &&
equal(this->forward_thread_, other->forward_thread_);
}
TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(FragmentNode);
TVM_REGISTER_GLOBAL("tl.Layout").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Layout(Array<IterVar>(args[0]), Array<PrimExpr>(args[1]));
});
TVM_REGISTER_GLOBAL("tl.Layout_input_shape").set_body_typed([](Layout layout) {
return layout->InputShape();
});
TVM_REGISTER_GLOBAL("tl.Layout_output_shape").set_body_typed([](Layout layout) {
return layout->OutputShape();
});
TVM_REGISTER_GLOBAL("tl.Layout_inverse").set_body_typed([](Layout layout) {
return layout->Inverse();
});
TVM_REGISTER_GLOBAL("tl.Layout_index").set_body_typed([](Layout layout) {
return layout->GetForwardIndex();
});
TVM_REGISTER_GLOBAL("tl.Fragment").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Fragment(args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_GLOBAL("tl.Fragment_thread_size").set_body_typed([](Fragment fragment) {
return fragment->ThreadExtent();
});
TVM_REGISTER_GLOBAL("tl.Fragment_thread").set_body_typed([](Fragment fragment) {
return fragment->GetForwardThread();
});
TVM_REGISTER_GLOBAL("tl.Fragment_repeat")
.set_body_typed([](Fragment fragment, Array<PrimExpr> repeats, bool repeat_on_thread,
bool lower_dim_first) {
return fragment->Repeat(repeats, repeat_on_thread, lower_dim_first);
});
TVM_REGISTER_GLOBAL("tl.Fragment_replicate").set_body_typed([](Fragment fragment, int repeats) {
return fragment->Replicate(repeats);
});
TVM_REGISTER_GLOBAL("tl.Fragment_condense_rep_var").set_body_typed([](Fragment fragment) {
return fragment->CondenseReplicateVar();
});
TVM_REGISTER_GLOBAL("tl.make_swizzled_layout")
.set_body_typed([](int stride, int continuous, int element_size) {
return makeGemmABLayout(stride, continuous, element_size, 0);
});
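// The globals above are the FFI entry points consumed by the Python bindings. A
// minimal C++ sketch of invoking one of them through the runtime registry
// (assuming TVM's Registry::Get API):
//   const tvm::runtime::PackedFunc* f = tvm::runtime::Registry::Get("tl.Layout_inverse");
//   ICHECK(f != nullptr);
//   Layout inv = (*f)(layout);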
} // namespace tl
} // namespace tvm