Commit 57ab687c authored by Lei Wang, committed by GitHub

[Initialization] Migration of Codebase from Dev Branch into Main (#10)



* Add format.sh script for code formatting and linting

* docs update

* center align the title

* lint fix

* add ignore

* Add .gitignore for 3rdparty directory

* Add requirements-dev.txt, requirements-test.txt, and requirements.txt

* 3rdparty

* Add gemm.h, CMakeLists.txt, _ffi_api.py, __init__.py, runtime.h, reduce.h, loop_partition.h, utils.h, and loop_vectorize.h

* Refactor CMakeLists.txt and include statements

- Update CMakeLists.txt to use a newer version of CMake and add project name
- Remove unnecessary include directories

Fix include paths in layout.cc, codegen.cc, codegen.h, rt_mod.cc, frontend_legalize.cc, inject_pipeline.cc, layout_inference.cc, loop_vectorize.cc, and lower_tile_op.cc

- Update include paths to use relative paths instead of absolute paths

* Update submodule for 3rdparty/tvm

* update

* load dll first

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* git keep update

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* refactor code structure

* Update Readme

* CMakeLists Customized

* update readme

* update README

* update readme

* update usage

* with TVM_IMPORT_PYTHON_PATH to handle own tvm build python import

* annotate lower transform global func with `transform` prefix

* Migrate Simplify Pass from tilelang tvm branch

* enhance system environment handling with __init__ and CMake

* Initial commit

* CODE_OF_CONDUCT.md committed

* LICENSE committed

* README.md committed

* SECURITY.md committed

* SUPPORT.md committed

* CODE_OF_CONDUCT Commit

* LICENSE Commit

* SECURITY Commit

* SUPPORT Commit

* Modify Support

* Update README.md

* security ci update

* remove examples

* Update and implement clang-format

* add composable kernel components

* Migrate from latest update

* submodule update

* Test update

* Update License

* Spell check

* lint fix

* add clang-tidy to apply static analysis for c source

* update tilelang examples

* Update Install Docs

* Refactor filetree

* Enhance Install

* conflict resolved

* annotate_version

* Initial Update

* test fix

* install

* Implement setup.py

* lint fix

* Separate Init

* Separate test

* docker file commit

* add logo

* Update Readme and Examples

* update readme

* update logo

* Implement AMD Installation

* Add License

* Update AMD MI300x Benchmark

* update README

* update mi300 benchmark scripts

* update ignore

* enhance build script

* update image

* enhance setup.py to remove duplicated libraries

* remove debug files

* update readme

* update image

* update gemm examples

* update flashattention README

* readme update

* add cmake into requirements

* libinfo fix

* auto update submodule

* lint fix

* Fix AMD Build and Test

* Update check for transpose attribute for CDNA Arch

* typo fix for amd

* Implement Matmul Benchmark

* Refactor Code

* [TypoFix] Fix GEMM Example

* [Docs] Init Linear Attention README

* [TYPO] Typo fix

* [Lint] Lint Fix

* enhance example with intrinsics

* [Enhancement] Improve Buffer Collection during IR Parser

* [Dev] Introduce Current classmethod to get current frame

* submodule update

* fake test pass update

* support thread_extent_api

* code optimize

* Add GEMM function implementation for matrix multiplication

* Update logging format to reflect TileLang in logger messages

* Refactor CMakeLists.txt for improved readability and set default build type to Release

* Support Gemm SS Primitives Implementation

* [README] Upload Tile Language Logo (#5)

* update logo

* Update README.md to enhance formatting and center the title

---------
Co-authored-by: microsoft-github-operations[bot] <55726097+microsoft-github-operations[bot]@users.noreply.github.com>
Co-authored-by: Microsoft Open Source <microsoftopensource@users.noreply.github.com>
Co-authored-by: Yu Cheng <yu.cheng@pku.edu.cn>
parent 64f17c2f
Checks: >
  clang-analyzer-*,
  cppcoreguidelines-*,
  modernize-*,
  performance-*,
  readability-*
WarningsAsErrors: '*'
HeaderFilterRegex: '^(?!.*(3rdparty|build)).*$'
# For most projects, this workflow file will not need changing; you simply need
# to commit it to your repository.
#
# You may wish to alter this file to override the set of languages analyzed,
# or to provide custom queries or build logic.
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL"
on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]
  schedule:
    - cron: '36 11 * * 5'
jobs:
  analyze:
    name: Analyze
    # Runner size impacts CodeQL analysis time. To learn more, please see:
    #   - https://gh.io/recommended-hardware-resources-for-running-codeql
    #   - https://gh.io/supported-runners-and-hardware-resources
    #   - https://gh.io/using-larger-runners
    # Consider using larger runners for possible analysis time improvements.
    runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
    timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }}
    permissions:
      # required for all workflows
      security-events: write
      # only required for workflows in private repositories
      actions: read
      contents: read
    strategy:
      fail-fast: false
      matrix:
        language: [ 'python' ]
        # CodeQL supports [ 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift' ]
        # Use only 'java-kotlin' to analyze code written in Java, Kotlin or both
        # Use only 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
        # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      # Initializes the CodeQL tools for scanning.
      - name: Initialize CodeQL
        uses: github/codeql-action/init@v3
        with:
          languages: ${{ matrix.language }}
          # If you wish to specify custom queries, you can do so here or in a config file.
          # By default, queries listed here will override any specified in a config file.
          # Prefix the list here with "+" to use these queries and those in the config file.
          # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
          # queries: security-extended,security-and-quality

      # Autobuild attempts to build any compiled languages (C/C++, C#, Go, Java, or Swift).
      # If this step fails, then you should remove it and run the build manually (see below)
      - name: Autobuild
        uses: github/codeql-action/autobuild@v3

      # ℹ️ Command-line programs to run using the OS shell.
      # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun

      # If the Autobuild fails above, remove it and uncomment the following three lines.
      # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance.
      # - run: |
      #     echo "Run, Build Application using script"
      #     ./location_of_script_within_repo/buildscript.sh

      - name: Perform CodeQL Analysis
        uses: github/codeql-action/analyze@v3
        with:
          category: "/language:${{matrix.language}}"
name: Dependent Bot Action
on:
  pull_request:
    branches: [main]
  workflow_dispatch:

jobs:
  bot-task:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.x'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
# Compiled Object files
*.slo
*.lo
*.o
*.obj
*.pyc
# Precompiled Headers
*.gch
*.pch
# emacs
*~
# vim
*.swp
*.swo
debug/
build/
dist/
__pycache__
nnfusion.tar.gz
# makeenv and test intermediate files
tmp/
venv/
.vscode/
.vs/
# VisualGDB files
VisualGDB/
toolchain.cmake
# docbuild artifacts
doc/sphinx/build/*
doc/doxygen/*.xml
doc/doxygen/*.html
doc/doxygen/man/*
doc/doxygen/latex/*
doc/doxygen/xml/*
doc/doxygen/html/*
# git merge
*.orig
\#*
\.#*
# idea
.idea/*
# python egg
*.egg-info
# Macos
**/.DS_Store
nnfusion_rt/
models/frozenmodels/
# log
*.log
# pkl
*.pkl_*
# .pytest_cache
.pytest_cache
# .hypothesis
.hypothesis
# .ruff_cache
.ruff_cache
# build sdist
build_sdist/
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/TileLang/cutlass
[submodule "3rdparty/tvm"]
path = 3rdparty/tvm
url = https://github.com/TileLang/tvm
[submodule "3rdparty/composable_kernel"]
path = 3rdparty/composable_kernel
url = https://github.com/ROCm/composable_kernel
Subproject commit 1c45ca35dd5c215e0c1db1f40f01556f467f52a8
Subproject commit a2954a8fdd9a73852f2c1ddea97d0e8a579cfb25
Subproject commit b372d9ca2159a1afd5439990f68bfa29578a8bac
# Copyright(c) Microsoft Corporation.
# Licensed under the MIT License.
# Learn a lot from the MLC - LLM Project
# https://github.com/mlc-ai/mlc-llm/blob/main/CMakeLists.txt
cmake_minimum_required(VERSION 3.18)
project(TILE_LANG C CXX)
# Set default build type to Release if not provided
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
endif()
# Define a custom macro for globbing files with conditional CONFIGURE_DEPENDS
if(${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.12.0")
  macro(tilelang_file_glob glob variable)
    file(${glob} ${variable} CONFIGURE_DEPENDS ${ARGN})
  endmacro()
else()
  macro(tilelang_file_glob glob variable)
    file(${glob} ${variable} ${ARGN})
  endmacro()
endif()
# Handle TVM prebuild path or use default configuration
if(DEFINED TVM_PREBUILD_PATH)
  message(STATUS "TVM_PREBUILD_PATH: ${TVM_PREBUILD_PATH}")
  if(EXISTS ${TVM_PREBUILD_PATH}/config.cmake)
    include(${TVM_PREBUILD_PATH}/config.cmake)
  endif()
else()
  if(EXISTS ${CMAKE_BINARY_DIR}/config.cmake)
    include(${CMAKE_BINARY_DIR}/config.cmake)
  elseif(EXISTS ${CMAKE_SOURCE_DIR}/config.cmake)
    include(${CMAKE_SOURCE_DIR}/config.cmake)
  endif()
  # Set default build type to RelWithDebInfo if not provided
  if(NOT CMAKE_BUILD_TYPE)
    set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "Build type" FORCE)
    message(STATUS "Setting default build type to ${CMAKE_BUILD_TYPE}")
  endif()
endif()
# include cmake modules
include(CheckCXXCompilerFlag)
# Enable static runtime build if required
if(TILE_LANG_INSTALL_STATIC_LIB)
  set(BUILD_STATIC_RUNTIME ON)
endif()
# Enforce CUDA standard
if(USE_CUDA)
  set(CMAKE_CUDA_STANDARD 17)
endif()
# Enforce HIP standard
if(USE_ROCM)
  set(CMAKE_HIP_STANDARD 17)
  check_cxx_compiler_flag("-std=c++17" SUPPORT_CXX17)
  set(CMAKE_CXX_FLAGS "-D__HIP_PLATFORM_AMD__ ${CMAKE_CXX_FLAGS}")
endif()
# Enforce C++ standard
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
# Locate TVM source directory
if(NOT DEFINED TVM_SOURCE_DIR)
  if(DEFINED ENV{TVM_SOURCE_DIR})
    set(TVM_SOURCE_DIR "$ENV{TVM_SOURCE_DIR}")
  else()
    set(TVM_SOURCE_DIR ${PROJECT_SOURCE_DIR}/3rdparty/tvm)
  endif()
endif()
# Handle TVM prebuild or build TVM from source
if(DEFINED TVM_PREBUILD_PATH)
  message(STATUS "Using prebuilt TVM from ${TVM_PREBUILD_PATH}")
  add_library(tvm SHARED IMPORTED)
  set_target_properties(tvm PROPERTIES
    IMPORTED_LOCATION "${TVM_PREBUILD_PATH}/libtvm.so"
    INTERFACE_INCLUDE_DIRECTORIES "${TVM_PREBUILD_PATH}/../include"
  )
  add_library(tvm_runtime SHARED IMPORTED)
  set_target_properties(tvm_runtime PROPERTIES
    IMPORTED_LOCATION "${TVM_PREBUILD_PATH}/libtvm_runtime.so"
    INTERFACE_INCLUDE_DIRECTORIES "${TVM_PREBUILD_PATH}/../include"
  )
else()
  message(STATUS "Building TVM from source at ${TVM_SOURCE_DIR}")
  add_subdirectory(${TVM_SOURCE_DIR} tvm EXCLUDE_FROM_ALL)
endif()
# Collect source files
tilelang_file_glob(GLOB TILE_LANG_SRCS
  src/*.cc
  src/layout/*.cc
  src/transform/*.cc
  src/op/*.cc
  src/target/utils.cc
)
# Include CUDA source files if CUDA is enabled
if(USE_CUDA)
  tilelang_file_glob(GLOB TILE_LANG_CUDA_SRCS
    src/runtime/*.cc
    src/target/codegen_cuda.cc
    src/target/rt_mod_cuda.cc
  )
  list(APPEND TILE_LANG_SRCS ${TILE_LANG_CUDA_SRCS})
endif()
# Include ROCm source files if ROCm is enabled
if(USE_ROCM)
  tilelang_file_glob(GLOB TILE_LANG_HIP_SRCS
    src/target/codegen_hip.cc
    src/target/rt_mod_hip.cc
  )
  list(APPEND TILE_LANG_SRCS ${TILE_LANG_HIP_SRCS})
endif()
message(STATUS "Collected source files: ${TILE_LANG_SRCS}")
# Add TileLang object library
add_library(tilelang_objs OBJECT ${TILE_LANG_SRCS})
# Include directories for TileLang
set(TILE_LANG_INCLUDES
  ${TVM_SOURCE_DIR}/include
  ${TVM_SOURCE_DIR}/src
  ${TVM_SOURCE_DIR}/3rdparty/dlpack/include
  ${TVM_SOURCE_DIR}/3rdparty/dmlc-core/include
)
# Find CUDA Toolkit
if(USE_CUDA)
  find_package(CUDAToolkit REQUIRED)
  if(NOT CUDAToolkit_FOUND)
    message(FATAL_ERROR "CUDA Toolkit not found. Please set CUDAToolkit_ROOT.")
  endif()
  message(STATUS "CUDA Toolkit includes: ${CUDAToolkit_INCLUDE_DIRS}")
  list(APPEND TILE_LANG_INCLUDES ${CUDAToolkit_INCLUDE_DIRS})
endif(USE_CUDA)
# Find ROCM Toolkit
if(USE_ROCM)
  find_rocm(${USE_ROCM})
  message(STATUS "USE_ROCM: ${USE_ROCM}")
  if(ROCM_FOUND)
    # always set the includedir
    # avoid global retrigger of cmake
    include_directories(SYSTEM ${ROCM_INCLUDE_DIRS})
    add_definitions(-D__HIP_PLATFORM_HCC__=1)
  else()
    message(FATAL_ERROR "ROCM Toolkit not found. Please set HIP_ROOT.")
  endif(ROCM_FOUND)
  message(STATUS "ROCM Toolkit includes: ${ROCM_INCLUDE_DIRS}")
  list(APPEND TILE_LANG_INCLUDES ${ROCM_INCLUDE_DIRS})
endif(USE_ROCM)
# Define compile-time macros
set(TILE_LANG_COMPILE_DEFS
  DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>
  __STDC_FORMAT_MACROS=1
  PICOJSON_USE_INT64
)
# Set target properties for object library
target_include_directories(tilelang_objs PRIVATE ${TILE_LANG_INCLUDES})
target_compile_definitions(tilelang_objs PRIVATE ${TILE_LANG_COMPILE_DEFS})
target_compile_definitions(tilelang_objs PRIVATE -DTILE_LANG_EXPORTS)
# Shared library
add_library(tilelang SHARED $<TARGET_OBJECTS:tilelang_objs>)
target_link_libraries(tilelang PUBLIC tvm_runtime)
# Static library
add_library(tilelang_static STATIC $<TARGET_OBJECTS:tilelang_objs>)
add_dependencies(tilelang_static tvm_runtime)
set_target_properties(tilelang_static PROPERTIES OUTPUT_NAME tilelang)
# Debug build type-specific definitions
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
  target_compile_definitions(tilelang PRIVATE "TVM_LOG_DEBUG")
  target_compile_definitions(tilelang_objs PRIVATE "TVM_LOG_DEBUG")
  target_compile_definitions(tilelang_static PRIVATE "TVM_LOG_DEBUG")
endif()
# Module shared library
add_library(tilelang_module SHARED $<TARGET_OBJECTS:tilelang_objs>)
target_link_libraries(tilelang_module PUBLIC tvm)
# Install targets
if(TILE_LANG_INSTALL_STATIC_LIB)
  install(TARGETS tilelang_static tvm_runtime
    LIBRARY DESTINATION lib${LIB_SUFFIX}
  )
else()
  if(DEFINED TVM_PREBUILD_PATH)
    install(TARGETS tilelang tilelang_module
      RUNTIME DESTINATION bin
      LIBRARY DESTINATION lib${LIB_SUFFIX}
    )
  else()
    install(TARGETS tvm_runtime tilelang tilelang_module
      RUNTIME DESTINATION bin
      LIBRARY DESTINATION lib${LIB_SUFFIX}
    )
  endif()
endif()
# Microsoft Open Source Code of Conduct
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
Resources:
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
# Contributing
It would be awesome if you want to contribute something to TileLang!
- [Contributing](CONTRIBUTING.md#contributing)
- [Reporting Bugs](CONTRIBUTING.md#reporting-bugs)
- [Asking Questions](CONTRIBUTING.md#asking-questions)
- [Submitting Pull Requests](CONTRIBUTING.md#submitting-pull-requests)
- [Repository Setup](CONTRIBUTING.md#repository-setup)
- [Running Tests](CONTRIBUTING.md#running-tests)
## Reporting Bugs
If you run into any weird behavior while using TileLang, feel free to open a new issue in this repository! Please run a **search before opening** a new issue, to make sure that someone else hasn't already reported or solved the bug you've found.
Any issue you open must include:
- Code snippet that reproduces the bug with a minimal setup.
- A clear explanation of what the issue is.
## Asking Questions
Please ask questions by opening an issue.
## Submitting Pull Requests
All pull requests are welcome and greatly appreciated! Issues in need of a solution are marked with a [`♥ help`](https://github.com/ianstormtaylor/BitBLAS/issues?q=is%3Aissue+is%3Aopen+label%3A%22%E2%99%A5+help%22) label if you're looking for somewhere to start.
Please run `./format.sh` before submitting a pull request to make sure that your code is formatted correctly.
Please include tests and docs with every pull request!
## Repository Setup
To run the build, you need to have the TileLang repository cloned to your computer. After that, `cd` into the directory where you cloned it and install the dependencies with `python`:
```bash
python setup.py install
```
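After installation, you can quickly verify that the package imports correctly. A minimal sketch; the `__version__` attribute is an assumption (the repo ships a `VERSION` file containing `0.0.1.dev`, which version metadata is typically populated from), so check it against your installed build:

```python
import tilelang

# Assumed attribute: verify against the installed package; the VERSION
# file in this commit contains "0.0.1.dev".
print(tilelang.__version__)
```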
## Running Tests
To run the tests, start by building the project as described in the [Repository Setup](CONTRIBUTING.md#repository-setup) section.
Then you can rerun the tests with:
```bash
python -m pytest testing
```
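For reference, a new test file under `testing/` might look like the following sketch. The file name is illustrative, and `matmul` stands in for the kernel factory defined in the README's Quick Start example (the `my_kernels` import is hypothetical, not a module in this repo):

```python
# testing/test_gemm_quickstart.py -- illustrative name, not an actual repo file.
import torch
import tilelang
from tilelang import Profiler

# Hypothetical import: use the `matmul` factory from the README's Quick Start
# example, pasted in or imported from wherever you keep it.
from my_kernels import matmul


def test_matmul_fp16():
    # Lower the kernel and wrap it in a Profiler (output buffer is argument 2).
    func = matmul(1024, 1024, 1024, 128, 128, 32)
    rt_mod, params = tilelang.lower(func)
    profiler = Profiler(rt_mod, params, result_idx=[2])

    a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

    # The kernel output should match PyTorch within fp16 tolerances.
    torch.testing.assert_close(profiler(a, b), a @ b, rtol=1e-2, atol=1e-2)
```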
MIT License
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
include VERSION
include CMakeLists.txt
include requirements.txt
include requirements-test.txt
include requirements-dev.txt
recursive-include src *
recursive-include 3rdparty *
recursive-exclude 3rdparty/clang* *
recursive-exclude 3rdparty/llvm* *
<div align="center">

# Tile Language

<img src=./images/logo-row.svg />

</div>

Tile Language (**tile-lang**) is a concise domain-specific language designed to streamline the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention). By employing a Pythonic syntax with an underlying compiler infrastructure on top of [TVM](https://tvm.apache.org/), tile-lang allows developers to focus on productivity without sacrificing the low-level optimizations necessary for state-of-the-art performance.

## Tested Devices

Although tile-lang aims to be portable across a range of devices, it has been specifically tested and validated on the following:
- **NVIDIA GPUs**:
  - H100 (**with Auto TMA/WGMMA Support**)
  - A100
  - V100
  - RTX 4090
  - RTX 3090
  - RTX A6000
- **AMD GPUs**:
  - MI250 (**with Auto MatrixCore Support**)
  - MI300X (**with Async Copy Support**)
## OP Implementation Examples
**tile-lang** provides the building blocks to implement a wide variety of operators. Some examples include:
- [Matrix Multiplication](./examples/gemm/)
- [Dequantization GEMM](./examples/dequantize_gemm/)
- [Flash Attention](./examples/flash_attention/)
- [Flash Linear Attention](./examples/linear_attention/)
Within the `examples` directory, you will also find additional complex kernels, such as convolutions and forward/backward passes for FlashAttention.
## Benchmark Summary
TileLang achieves exceptional performance across a variety of computational patterns. Below are selected results showcasing its capabilities:
- Operator Performance Vs. Baselines on H100
<div>
<img src="./images/op_benchmark_h100.png" alt="operator performance on H100" />
</div>
- MatrixCore FP16 GEMM Performance Vs. Baselines on MI300X
<div>
<img src="./images/op_benchmark_mi300_fp16_gemm_normalized_latency.png" alt="gemm fp16 performance on MI300X" />
</div>
## Installation
### Method 1: Install with Pip
The quickest way to get started is to install the latest release from PyPI:
```bash
pip install tilelang
```
Alternatively, you can install directly from the GitHub repository:
```bash
pip install git+https://github.com/microsoft/TileLang
```
Or install locally:
```bash
pip install . # with -e option if you want to install in editable mode
```
### Method 2: Build from Source
We currently provide three ways to install **tile-lang** from source:
- [Install from Source (using your own TVM installation)](./docs/Installation.md#install-from-source-with-your-own-tvm-installation)
- [Install from Source (using the bundled TVM submodule)](./docs/Installation.md#install-from-source-with-our-tvm-submodule)
- [Install Using the Provided Script](./docs/Installation.md#install-with-provided-script)
## Quick Start
In this section, you’ll learn how to write and execute a straightforward GEMM (matrix multiplication) kernel using tile-lang, followed by techniques for layout optimizations, pipelining, and L2-cache–friendly swizzling.
### Basic GEMM Example
Below is a minimal example showing how to define and run a matrix multiplication kernel in tile-lang. This serves as a gentle introduction to the language’s key concepts.
```python
import tilelang
from tilelang import Profiler
import tilelang.language as T


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((K, N), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        # Define a GPU kernel launch configuration:
        #   - Grid dimensions: (ceildiv(N, block_N), ceildiv(M, block_M))
        #   - Threads per block: 128
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            # Allocate on-chip memory (shared and fragment buffers)
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Initialize the accumulation buffer
            T.clear(C_local)

            # Primary compute loop, pipelined across chunks of size block_K
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy a tile of A into shared memory
                T.copy(A[by * block_M, k * block_K], A_shared)
                # Copy a tile of B into shared memory
                T.copy(B[k * block_K, bx * block_N], B_shared)
                # Perform a tile-level GEMM on the shared buffers into C_local
                T.gemm(A_shared, B_shared, C_local)

            # Write the accumulated result from local memory back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(1024, 1024, 1024, 128, 128, 32)
rt_mod, params = tilelang.lower(func)

# 2. Create a Profiler object for running performance and correctness tests
profiler = Profiler(rt_mod, params, result_idx=[2])

# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

# Run the kernel through the Profiler
c = profiler(a, b)

# Reference multiplication using PyTorch
ref_c = a @ b

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# 4. Retrieve and inspect the generated CUDA source (optional)
cuda_source = rt_mod.imported_modules[0].get_source()
print("Generated CUDA kernel:\n", cuda_source)
```
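The Profiler is also intended for performance measurement. As a version-independent way to time the compiled kernel, you can wrap the call in CUDA events; below is a minimal sketch, reusing the `profiler`, `a`, and `b` objects from the snippet above:

```python
import torch

# Warm up the JIT-compiled kernel before timing.
for _ in range(10):
    profiler(a, b)
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
for _ in range(100):
    profiler(a, b)
end.record()
torch.cuda.synchronize()

# Average latency per kernel launch, in milliseconds.
print(f"Average kernel latency: {start.elapsed_time(end) / 100:.3f} ms")
```

Dividing `2 * M * N * K` by the measured latency (in seconds) gives an effective FLOP/s figure for comparison against the baselines in the benchmark section.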
### Enhanced Example with Annotations (Layout, L2 Cache Swizzling, and Pipelining, etc.)

Below is an example that demonstrates more advanced features: layout annotation, parallelized copy, and swizzle for improved L2 cache locality. This snippet shows how to adapt your kernel to maximize performance on complex hardware.
```python
import tilelang.language as T
# `make_mma_swizzle_layout` is a Python-defined layout function designed
# specifically for MMA operations. It stays consistent with the NVIDIA
# CUTLASS library's swizzling to avoid bank conflicts and maximize performance.
from tilelang.intrinsics import (
    make_mma_swizzle_layout as make_swizzle_layout,
)
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((K, N), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        # Kernel configuration remains similar
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Apply layout optimizations or define your own layout
            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
            })

            # Enable rasterization for better L2 cache locality
            T.use_swizzle(panel_size=10, enable=True)

            # Clear local accumulation
            T.clear(C_local)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy a tile of A into shared memory
                T.copy(A[by * block_M, k * block_K], A_shared)

                # Demonstrate a parallelized copy from global to shared for B
                for ko, j in T.Parallel(block_K, block_N):
                    B_shared[ko, j] = B[k * block_K + ko, bx * block_N + j]

                # Perform a tile-level GEMM on the shared buffers
                T.gemm(A_shared, B_shared, C_local)

            # Copy the result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```

### Dive Deep into TileLang Beyond GEMM
In addition to GEMM, we provide a variety of examples to showcase the versatility and power of TileLang, including:
- [Dequantize GEMM](./examples/dequantize_gemm/): Achieve high-performance dequantization through **fine-grained control over per-thread operations**, using layout transformations and intrinsics to accelerate dequantized GEMM; many of these features have been adopted as default behaviors in [BitBLAS](https://github.com/microsoft/BitBLAS).
- [FlashAttention](./examples/flash_attention/): Enable cross-operator fusion with simple and intuitive syntax; an auto-tuning example is also included.
- [LinearAttention](./examples/linear_attention/): Examples include RetNet and Mamba implementations.
- [Convolution](./examples/convolution/): Implementations of Convolution with IM2Col.
More operators will continuously be added.
---
TileLang is now used in the [BitBLAS](https://github.com/microsoft/BitBLAS) project.
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.
## Trademarks
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party's policies.
## Acknowledgements
We learned a lot from the [TVM](https://github.com/apache/tvm) community and would like to thank them for their contributions.

This project was initiated by [Yining Shi](https://github.com/nox-410) and continued by [Lei Wang](https://github.com/LeiWang1999) and [Yu Cheng](https://github.com/chengyupku). It was completed under the guidance of [Yuqing Xia](https://github.com/xiayuqing0622), [Lingxiao Ma](https://github.com/xysmlx) and [Jilong Xue](https://github.com/jlxue) from the [MSRA System Research Group](https://www.microsoft.com/en-us/research/group/systems-and-networking-research-group-asia/).
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.4 BLOCK -->
## Security
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
## Reporting Security Issues
**Please do not report security vulnerabilities through public GitHub issues.**
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue
This information will help us triage your report more quickly.
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
## Preferred Languages
We prefer all communications to be in English.
## Policy
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
<!-- END MICROSOFT SECURITY.MD BLOCK -->
# Support
Welcome to the TileLang support page! TileLang extends Apache TVM with a more accessible approach to writing high-performance GPU kernels. It currently supports CUDA targets including Ampere (sm_80+), Turing (sm_75), and Volta (sm_70) architectures. Whether you are working on common operators like GEMM and convolution or more advanced features like flash attention, TileLang aims to provide a more streamlined development experience while maintaining performance on par with hand-optimized implementations.
## How to Report Issues and Request Features
### Bug Reports and Feature Requests
We encourage you to use our GitHub Issues page to report any bugs or request new features:
1. Search Existing Issues: Before filing a new issue, please check if a similar one already exists.
2. File a New Issue: If you don’t find a matching entry, open a new issue and include as many details as possible, such as environment info, steps to reproduce, and output logs. This will help us quickly understand and address your problem.
### Getting Help and Asking Questions
If you have questions about using TileLang, best practices, or performance tuning, there are several ways to get support:
- GitHub Discussions: Join the community at TileLang Discussions to ask questions, share ideas, and discuss development strategies.
- Stack Overflow: Use the TileLang tag when asking questions. The project maintainers and community members regularly check the tag and can offer assistance.
## Microsoft Support Policy
This project is open-source and community-driven. Primary support channels are the community forums and issue tracker mentioned above. While maintainers and contributors strive to respond promptly, we rely on community engagement to help address questions and improve the codebase.
## Contributing to TileLang
We encourage contributions from anyone interested in improving TileLang. Contributions can range from code enhancements and feature implementations to documentation improvements and bug fixes. If you’re interested in contributing, please refer to our CONTRIBUTING.md file for guidelines, including the process for signing the Contributor License Agreement (CLA), which you only need to complete once.
Your involvement helps shape TileLang’s future, ensuring it remains a versatile and high-performance tool for GPU kernel development.
TileLang uses third-party material as listed below. The attached notices are
provided for informational purposes only.
Notice for apache/tvm
-------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
------------------------------------------------------------------------------------
Notice for IST-DASLab/marlin/
-------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
------------------------------------------------------------------------------------
0.0.1.dev
import argparse
import itertools
import logging
import tilelang as tl
import tilelang.language as T
from tilelang.autotuner import autotune, jit
# Configure logging; basicConfig attaches a stream handler so that
# debug output from this script is actually emitted.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
def ref_program(A, B):
"""
A reference matrix multiplication program, used to compare performance.
Parameters
----------
A : numpy.ndarray
The matrix with shape (M, K).
B : numpy.ndarray
The matrix with shape (N, K).
Returns
-------
np.ndarray
The result of A @ B.T, shape (M, N).
"""
return A @ B.T
def get_configs(M, N, K, with_roller=False):
"""
Generate a list of configuration dictionaries that will be used for tuning.

    Parameters
    ----------
    M, N, K : int
        The GEMM problem shape to generate configurations for.
    with_roller : bool
        Whether to use the BitBLAS roller to deduce the search space;
        otherwise a fixed cartesian sweep of tiling parameters is used.

Returns
-------
list of dict
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
if with_roller:
from bitblas.base.utils import get_roller_hints_from_func
from bitblas.ops.general_matmul.tirscript import matmul_select_implementation
from bitblas.base.arch import CUDA
from bitblas.base.roller.rasterization import NoRasterization
arch = CUDA("cuda")
topk = 20
# Simple TIR Compute Expression
ir_module = matmul_select_implementation(
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
)
roller_hints = get_roller_hints_from_func(
ir_module,
arch,
topk,
tensorcore_only=True,
allow_gemv=True,
)
if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")
configs = []
for hint in roller_hints:
config = {}
block_m, block_n = hint.block
warp_m, warp_n = hint.warp
config["block_M"] = block_m
config["block_N"] = block_n
config["block_K"] = hint.rstep[0]
config["num_stages"] = 0
config["thread_num"] = (block_m * block_n) // (warp_m * warp_n) * 32
config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
configs.append(config)
        for config in configs:
            logger.debug(config)
else:
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat
} for c in _configs
]
return configs
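

# For reference, the fixed sweep above enumerates
# 3 (block_M) * 3 (block_N) * 2 (block_K) * 4 (num_stages)
# * 2 (thread_num) * 2 (rasterization) = 288 candidate configs.
# A single entry (values illustrative) looks like:
#   {"block_M": 128, "block_N": 128, "block_K": 32,
#    "num_stages": 2, "thread_num": 256, "enable_rasteration": True}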
def matmul(M, N, K, with_roller):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
- B: (N, K)
- C: (M, N)
Parameters
----------
M : int
The dimension M of the matrix multiplication.
N : int
The dimension N of the matrix multiplication.
    K : int
        The dimension K of the matrix multiplication.
    with_roller : bool
        Whether to generate the search space with the BitBLAS roller.

Returns
-------
(best_latency, best_config, ref_latency)
best_latency : float
The best latency found among the tuned configurations.
best_config : dict
The parameter configuration that yielded best_latency.
ref_latency : float
The baseline latency of the reference program (for computing speedup).
"""
# Decorate the kernel with autotune & jit, specifying:
# - Tuning config list
# - Profiling keys
# - Warmup and repetition counts for better measurement
    # - A reference program used as the performance baseline
    #   (result checking is skipped here via skip_check=True)
    # - The "auto" profiler backend
    # - "auto" as the compilation target, resolved from the detected hardware
if with_roller:
        # Check that BitBLAS is installed before building the roller search space.
        try:
            import bitblas  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "BitBLAS is not installed. Please install it via 'pip install bitblas'.") from e
@autotune(
configs=get_configs(M, N, K, with_roller),
keys=[
"block_M",
"block_N",
"block_K",
"num_stages",
"thread_num",
"enable_rasteration",
],
warmup=3,
rep=5,
)
@jit(
out_idx=[2],
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=True,
profiler="auto",
target="auto",
)
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
enable_rasteration=None,
):
"""
The actual kernel to compute C = A @ B^T.
Parameters
----------
block_M : int
Block size in M dimension.
block_N : int
Block size in N dimension.
block_K : int
Block size in K dimension.
num_stages : int
Number of pipelined stages (for asynchronous load).
thread_num : int
Number of threads to use per block.
        enable_rasteration : bool
            Whether to enable the rasterization (swizzling) optimization.
Returns
-------
Function
A TVM Tensor Language function (T.prim_func) that computes matmul.
"""
        # Use half-precision inputs to reduce memory bandwidth;
        # accumulate in float32 for better numerical accuracy.
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
- We divide the entire (M, N) domain into blocks of shape
(block_M, block_N).
- Each block has its own allocated shared memory for sub-blocks
of A and B.
- The partial results go into C_local, and then we copy them back
to global memory C.
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
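                # For example, with M = N = 8192 and block_M = block_N = 128,
                # the launch grid is ceildiv(8192, 128) = 64 blocks per axis.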
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable (or disable) swizzling optimization
T.use_swizzle(panel_size=10, enable=enable_rasteration)
# Clear out the accumulation buffer
T.clear(C_local)
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(
A[by * block_M, k * block_K],
A_shared,
)
# Load a sub-block of B from global memory into B_shared
T.copy(
B[bx * block_N, k * block_K],
B_shared,
)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C[by * block_M, bx * block_N])
return main
return kernel()
if __name__ == "__main__":
# Parse command-line arguments for matrix dimensions
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=8192, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=8192, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=8192, help="Matrix dimension K")
parser.add_argument(
"--with_roller",
action="store_true",
help="Whether to enable BitBLAS roller for search space",
)
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
with_roller = args.with_roller
# Compute total floating-point operations to measure throughput
total_flops = 2 * M * N * K
# matmul(...) returns (best_latency, best_config, ref_latency)
best_latency, best_config, ref_latency = matmul(M, N, K, with_roller)
# Print out the benchmark results
print(f"Best latency (s): {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
print(f"Best config: {best_config}")
print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}")