"docs/vscode:/vscode.git/clone" did not exist on "ccfa0841253d6da17d0dcb765c5e849abf95d69c"
Commit 57ab687c authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Initialization] Migration of Codebase from Dev Branch into Main (#10)



* Add format.sh script for code formatting and linting

* docs update

* center align the title

* lint fix

* add ignore

* Add .gitignore for 3rdparty directory

* Add requirements-dev.txt, requirements-test.txt, and requirements.txt

* 3rdparty

* Add gemm.h, CMakeLists.txt, _ffi_api.py, __init__.py, runtime.h, reduce.h, loop_partition.h, utils.h, and loop_vectorize.h

* Refactor CMakeLists.txt and include statements

- Update CMakeLists.txt to use a newer version of CMake and add project name
- Remove unnecessary include directories

Fix include paths in layout.cc, codegen.cc, codegen.h, rt_mod.cc, frontend_legalize.cc, inject_pipeline.cc, layout_inference.cc, loop_vectorize.cc, and lower_tile_op.cc

- Update include paths to use relative paths instead of absolute paths

* Update submodule for 3rdparty/tvm

* update

* load dll first

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* git keep update

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* refactor code structure

* Update Readme

* CMakeLists Customized

* update readme

* update README

* update readme

* update usage

* with TVM_IMPORT_PYTHON_PATH to handle own tvm build python import

* annotate lower transform global func with `transform` prefix

* Migrate Simplify Pass from tilelang tvm branch

* enhance system environment handling with __init__ and CMake

* Initial commit

* CODE_OF_CONDUCT.md committed

* LICENSE committed

* README.md committed

* SECURITY.md committed

* SUPPORT.md committed

* CODE_OF_CONDUCT Commit

* LICENSE Commit

* SECURITY Commit

* SUPPORT Commit

* Modify Support

* Update README.md

* security ci update

* remove examples

* Update and implement clang-format

* add composable kernel components

* Migrate from latest update

* submodule update

* Test update

* Update License

* Spell check

* lint fix

* add clang-tidy to apply static analysis for c source

* update tilelang examples

* Update Install Docs

* Refactor filetree

* Enhance Install

* conflict resolved

* annotate_version

* Initial Update

* test fix

* install

* Implement setup.py

* lint fix

* Separate Init

* Separate test

* docker file commit

* add logo

* Update Readme and Examples

* update readme

* update logo

* Implement AMD Installation

* Add License

* Update AMD MI300x Benchmark

* update README

* update mi300 benchmark scripts

* update ignore

* enhance build script

* update image

* enhance setup.py to remove duplicated libraries

* remove debug files

* update readme

* update image

* update gemm examples

* update flashattention README

* readme update

* add cmake into requirements

* libinfo fix

* auto update submodule

* lint fix

* Fix AMD Build and Test

* Update check for transpose attribute for CDNA Arch

* typo fix for amd

* Implement Matmul Benchmark

* Refactor Code

* [TypoFix] Fix GEMM Example

* [Docs] Init Linear Attention README

* [TYPO] Typo fix

* [Lint] Lint Fix

* enhance example with intrinsics

* [Enhancement] Improve Buffer Collection during IR Parser

* [Dev] Introduce Current classmethod to get current frame

* submodule update

* fake test pass update

* support thread_extent_api

* code optimize

* Add GEMM function implementation for matrix multiplication

* Update logging format to reflect TileLang in logger messages

* Refactor CMakeLists.txt for improved readability and set default build type to Release

* Support Gemm SS Primitives Implementation

* [README] Upload Tile Language Logo (#5)

* update logo

* Update README.md to enhance formatting and center the title

---------
Co-authored-by: microsoft-github-operations[bot] <55726097+microsoft-github-operations[bot]@users.noreply.github.com>
Co-authored-by: Microsoft Open Source <microsoftopensource@users.noreply.github.com>
Co-authored-by: Yu Cheng <yu.cheng@pku.edu.cn>
FROM nvcr.io/nvidia/pytorch:23.01-py3
WORKDIR /root
RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential git wget \
libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \
&& apt-get clean autoclean && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} /tmp/* /var/tmp/*
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-Linux-x86_64.sh -O install_miniconda.sh && \
bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh
ENV PATH="/opt/conda/bin:${PATH}"
ENV LIBGL_ALWAYS_INDIRECT=1
RUN conda install pip cmake && conda clean --all
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN git clone https://github.com/microsoft/TileLang.git --recursive -b main TileLang \
&& cd TileLang && ./install.sh
CMD bash
To ease the process of installing all the dependencies, we provide a Dockerfile and a short guide for building a Docker image with all of the above installed. The image is built on top of Ubuntu 20.04 and contains every dependency required to run the experiments. We currently provide a Dockerfile only for NVIDIA GPUs; a Dockerfile for AMD GPUs is available upon request.
```bash
git clone --recursive https://github.com/microsoft/TileLang TileLang
cd TileLang/docker
# build the image, this may take a while (around 10+ minutes on our test machine)
docker build -t tilelang_cuda -f Dockerfile.cu120 .
# run the container
docker run -it --cap-add=SYS_ADMIN --network=host --gpus all --cap-add=SYS_PTRACE --shm-size=4G --security-opt seccomp=unconfined --security-opt apparmor=unconfined --name tilelang_test tilelang_cuda bash
```
# Installation Guide
## Installing with pip
**Prerequisites for installation via wheel or PyPI:**
- **Operating System**: Ubuntu 20.04 or later
- **Python Version**: >= 3.8
- **CUDA Version**: >= 11.0
The easiest way to install TileLang is directly from PyPI using pip. To install the latest version, run the following command in your terminal.
**Note**: The TileLang wheel is currently only supported on Ubuntu 20.04 or later, as we build the wheel files on this platform, and we only provide wheels for CUDA >= 11.0 and Python >= 3.8. **If you are using a different platform or environment, you may need to [build TileLang from source](https://github.com/microsoft/TileLang/blob/main/docs/Installation.md#building-from-source).**
```bash
pip install tilelang
```
Alternatively, you may choose to install TileLang using prebuilt packages available on the Release Page:
```bash
pip install tilelang-0.0.0.dev0+ubuntu.20.4.cu120-py3-none-any.whl
```
To install the latest version of TileLang from the github repository, you can run the following command:
```bash
pip install git+https://github.com/microsoft/TileLang.git
```
After installing TileLang, you can verify the installation by running:
```bash
python -c "import tilelang; print(tilelang.__version__)"
```
## Building from Source
**Prerequisites for building from source:**
- **Operating System**: Linux
- **Python Version**: >= 3.7
- **CUDA Version**: >= 10.0
We recommend using a docker container with the necessary dependencies to build TileLang from source. You can use the following command to run a docker container with the necessary dependencies:
```bash
docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.01-py3
```
To build and install TileLang directly from source, follow the steps below. This process requires certain prerequisites from Apache TVM, which can be installed on Ubuntu/Debian-based systems using the following commands:
```bash
sudo apt-get update
sudo apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
```
After installing the prerequisites, you can clone the TileLang repository and install it using pip:
```bash
git clone --recursive https://github.com/Microsoft/TileLang.git
cd TileLang
pip install . # Please be patient, this may take some time.
```
If you want to install TileLang in development (editable) mode, run the following command:
```bash
pip install -e .
```
We currently provide three ways to install **tile-lang**:
- [Install from Source (using your own TVM installation)](#install-from-source-with-your-own-tvm-installation)
- [Install from Source (using the bundled TVM submodule)](#install-from-source-with-our-tvm-submodule)
- [Install Using the Provided Script](#install-with-provided-script)
### Method 1: Install from Source (using your own TVM installation)
If you already have a compatible TVM installation, follow these steps:
1. **Clone the Repository:**
```bash
git clone --recursive https://github.com/Microsoft/TileLang
cd TileLang
```
> **Note**: Use the `--recursive` flag to include necessary submodules.
2. **Configure Build Options:**
Create a build directory and specify your existing TVM path:
```bash
mkdir build
cd build
cmake .. -DTVM_PREBUILD_PATH=/your/path/to/tvm/build # e.g., /workspace/tvm/build
make -j 16
```
3. **Set Environment Variables:**
Update `PYTHONPATH` to include the `tile-lang` Python module:
```bash
export PYTHONPATH=/your/path/to/tile-lang/python:$PYTHONPATH
# TVM_IMPORT_PYTHON_PATH is used by the 3rdparty framework to import TVM
export TVM_IMPORT_PYTHON_PATH=/your/path/to/tvm/python
```
### Method 2: Install from Source (using the bundled TVM submodule)
If you prefer to use the built-in TVM version, follow these instructions:
1. **Clone the Repository:**
```bash
git clone --recursive https://github.com/Microsoft/TileLang
cd TileLang
```
> **Note**: Ensure the `--recursive` flag is included to fetch submodules.
2. **Configure Build Options:**
Copy the configuration file and enable the desired backends (e.g., LLVM and CUDA):
```bash
mkdir build
cp 3rdparty/tvm/cmake/config.cmake build
cd build
echo "set(USE_LLVM ON)" >> config.cmake
echo "set(USE_CUDA ON)" >> config.cmake
# or echo "set(USE_ROCM ON)" >> config.cmake if want to enable rocm runtime
cmake ..
make -j 16
```
The build outputs (e.g., `libtilelang.so`, `libtvm.so`, `libtvm_runtime.so`) will be generated in the `build` directory.
3. **Set Environment Variables:**
Ensure the `tile-lang` Python package is in your `PYTHONPATH`:
```bash
export PYTHONPATH=/your/path/to/TileLang/python:$PYTHONPATH
```
### Method 3: Install Using the Provided Script
For a simplified installation, use the provided script:
1. **Clone the Repository:**
```bash
git clone --recursive https://github.com/Microsoft/TileLang
cd TileLang
```
2. **Run the Installation Script:**
```bash
bash install.sh
# or run bash install_amd.sh if you want to enable the ROCm runtime
```
This script automates the setup, including submodule initialization and configuration.
FlashAttention performance on an RTX 4090 GPU with CUDA Toolkit 12.2.
SEQ_LEN is fixed to 2k; all matmuls use fp16->fp32 MMA. Values are in TFLOPS, higher is better.
Flash-Forward
| CAUSAL, DIM | Flash_attn | Tvm.tl |
| --------- | ---------- | ------ |
| False, 32 | 159.79 | 156.82 |
| False, 64 | 168.91 | 166.84 |
| False, 128 | 169.28 | 166.51 |
| False, 256 | 156.15 | 166.77 |
| True, 32 | 126.78 | 142.59 |
| True, 64 | 142.23 | 152.43 |
| True, 128 | 151.19 | 156.30 |
| True, 256 | 144.12 | 151.54 |
Flash-backward
| CAUSAL, DIM | Flash_attn | Tvm.tl |
| --------- | ---------- | ------ |
| False, 32 | 115.12 | 120.03 |
| False, 64 | 124.81 | 130.94 |
| False, 128 | 124.57 | 122.99 |
| True, 32 | 86.48 | 95.66 |
| True, 64 | 96.53 | 106.03 |
| True, 128 | 99.23 | 100.24 |
# TVM.TL language reference
## T.Kernel
args: the grid size (0 to 3 dimensions) and num_threads.
returns: the blockIdx variables.
Launches a kernel; it must be used in a with statement. Multiple kernels can be launched sequentially inside a prim function.
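A minimal sketch of the launch pattern (the shapes, dtypes, and block sizes below are illustrative placeholders, and the import path follows the tilelang examples elsewhere in this repository):
```python
import tilelang.language as T

M, N = 1024, 1024
block_M, block_N = 128, 128

@T.prim_func
def copy_kernel(
        A: T.Buffer((M, N), "float16"),
        B: T.Buffer((M, N), "float16"),
):
    # Launch a 2D grid; bx and by are the blockIdx variables returned by T.Kernel.
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_N), "float16")
        T.copy(A[by * block_M, bx * block_N], A_shared)
        T.copy(A_shared, B[by * block_M, bx * block_N])
```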
## T.alloc_shared
args: shape, dtype
returns: Buffer
Allocates a buffer in shared memory. It must be used within the T.Kernel scope and should be allocated at the top of the scope.
Dynamic shared memory is used.
## T.alloc_fragment
args: shape, dtype
returns: Buffer
Allocates a buffer in register memory. It must be used within the T.Kernel scope and should be allocated at the top of the scope.
The shape represents the whole logical shape of the buffer. Its elements are distributed across threads; this storage partition is inferred by the compiler.
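For illustration, the allocations excerpted from the simple GEMM example elsewhere in this repository sit at the top of the T.Kernel scope (M, N and the block sizes are assumed to be defined):
```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
    A_shared = T.alloc_shared((block_M, block_K), "float16")  # dynamic shared memory tile
    B_shared = T.alloc_shared((block_K, block_N), "float16")  # dynamic shared memory tile
    C_local = T.alloc_fragment((block_M, block_N), "float")   # register fragment, per-thread partition inferred
```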
## T.copy
args: src, dst
Copies data from src to dst. src and dst can each be a Buffer, BufferLoad, or BufferRegion. If one of them is a BufferLoad (which represents only a starting point), the other must not also be a BufferLoad, since the copy region needs to be known.
Out-of-bounds loads are padded with zero.
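Typical forms, excerpted from the GEMM example (the copy region is taken from the shape of the operand that is not a BufferLoad):
```python
# Global -> shared: copy the (block_M, block_K) region starting at (by*block_M, k*block_K).
T.copy(A[by * block_M, k * block_K], A_shared)
# Fragment -> global: write the accumulated tile back.
T.copy(C_local, C[by * block_M, bx * block_N])
```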
## T.gemm
args: A, B, C, transpose_A, transpose_B, policy
Performs a GEMM operation on A, B, and C. C must be a fragment, B must be in shared memory, and A can be either a fragment or in shared memory.
Note that the current implementation has some shape and dtype constraints; for example, the length of the reduction axis must be a multiple of 32 in the fp16 multiplicand case. We will relax this later.
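Two usages taken from the examples in this repository:
```python
# C_local (fragment) += A_shared @ B_shared, both operands in shared memory.
T.gemm(A_shared, B_shared, C_local)
# With a transposed B operand and an explicit warp policy, as in the FlashAttention example.
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
```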
## T.reduce_max T.reduce_sum
args: src, dst, dim
Performs a reduction from src to dst along dimension dim. Currently, both src and dst must be fragments.
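As used in the FlashAttention example (acc_s is a [block_M, block_N] fragment; the destinations are [block_M] fragments):
```python
T.reduce_max(acc_s, scores_max, dim=1, clear=False)  # row-wise max, keeping previous values
T.reduce_sum(acc_s, scores_sum, dim=1)               # row-wise sum
```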
## T.Parallel
You can use T.Parallel to write a loop. The loop is partitioned across all threads by the compiler (which accounts for the vectorization width, the fragment's thread mapping, and so on). Note that this is the only way to perform arbitrary operations on fragments.
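For example, an element-wise update of a fragment, excerpted from the FlashAttention example; the compiler decides the thread partition and vector width:
```python
for i, j in T.Parallel(block_M, block_N):
    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
```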
## T.Pipelined
args: start, stop, num_stages
Pipelines the loop: copies from global memory are converted to asynchronous operations and reordered so that they overlap with the computation that consumes them. num_stages is the number of buffers between producer and consumer (e.g., double buffering when num_stages=2).
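A sketch of a double-buffered reduction loop over K (buffers and block sizes assumed defined as in the GEMM example):
```python
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
    T.copy(A[by * block_M, k * block_K], A_shared)   # becomes an async copy
    T.copy(B[k * block_K, bx * block_N], B_shared)   # becomes an async copy
    T.gemm(A_shared, B_shared, C_local)              # overlaps with the next stage's copies
```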
## T.clear T.fill
Nothing special; they are lowered to T.Parallel loops.
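For example (excerpted from the examples in this repository):
```python
T.clear(C_local)                               # zero an accumulator fragment
T.fill(scores_max, -T.infinity(accum_dtype))   # fill a fragment with a constant
```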
## T.use_swizzle
An optimization for the L2 cache: the launch order over blockIdx.x and blockIdx.y is serpentined.
Add it in a kernel after all buffers have been allocated.
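Excerpt from the enhanced GEMM example in this repository: the call is placed inside the kernel, after the buffer allocations.
```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
    A_shared = T.alloc_shared((block_M, block_K), dtype)
    B_shared = T.alloc_shared((block_K, block_N), dtype)
    C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
    # Serpentine the blockIdx.x / blockIdx.y launch order for better L2 reuse.
    T.use_swizzle(panel_size=10, enable=True)
```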
import torch
import tilelang
from tilelang import Profiler
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
def check_hopper():
if not torch.cuda.is_available():
return None
props = torch.cuda.get_device_properties(0)
compute_capability = props.major, props.minor
return compute_capability == (9, 0)
def get_configs():
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [1, 2, 3, 4]
threads = [128, 256]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads))
configs = [{
'block_M': c[0],
'block_N': c[1],
'block_K': c[2],
'num_stages': c[3],
'threads': c[4]
} for c in _configs]
return configs
def convolution(N, C, H, W, F, K, S, D, P, tune=False):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
dtype = "float16"
accum_dtype = "float"
is_hopper = check_hopper()
def kernel_func(block_M, block_N, block_K, num_stages, threads):
@T.prim_func
def main(
data: T.Buffer((N, H, W, C), dtype),
kernel: T.Buffer((KH, KW, C, F), dtype),
out: T.Buffer((N, OH, OW, F), dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
out_shared = T.alloc_shared((block_M, block_N), dtype)
kernel_flat = T.Buffer((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Buffer((N * OH * OW, F), dtype, out.data)
T.annotate_layout({
out_shared: tilelang.layout.make_swizzled_layout(out_shared),
data_shared: tilelang.layout.make_swizzled_layout(data_shared),
kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
})
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
if is_hopper:
T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
else:
for i, j in T.Parallel(block_M, block_K):
k = k_iter * block_K + j
m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
(access_w < W))
data_shared[i, j] = T.if_then_else(
in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
T.copy(out_local, out_shared)
T.copy(out_shared, out_flat[by * block_M, bx * block_N])
return main
if tune:
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "block_K", "num_stages", "threads"],
warmup=10,
rep=10)
@jit(
out_idx=[2],
supply_type=tilelang.TensorSupplyType.Integer,
ref_prog=None,
profiler="auto")
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, block_K, num_stages, threads):
return kernel_func(block_M, block_N, block_K, num_stages, threads)
return kernel
def ref_program(A, B, stride, padding, dilation):
A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W
C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation)
C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C
return C
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--n', type=int, default=128, help='n')
parser.add_argument('--c', type=int, default=128, help='c')
parser.add_argument('--h', type=int, default=64, help='h')
parser.add_argument('--w', type=int, default=64, help='w')
parser.add_argument('--f', type=int, default=128, help='f')
parser.add_argument('--k', type=int, default=3, help='k')
parser.add_argument('--s', type=int, default=1, help='s')
parser.add_argument('--d', type=int, default=1, help='d')
parser.add_argument('--p', type=int, default=1, help='p')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
total_flops = 2 * N * C * OH * OW * F * K * K
if (not args.tune):
program = convolution(
N, C, H, W, F, K, S, D, P, tune=args.tune)(
block_M=256, block_N=128, block_K=64, num_stages=4, threads=256)
ref_program = partial(ref_program, stride=S, padding=P, dilation=D)
mod, params = tilelang.lower(program)
mod = Profiler(mod, params, [2], tilelang.TensorSupplyType.Normal)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = mod.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = mod.do_bench(mod.func, warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_latency, best_config, ref_latency = convolution(
N, C, H, W, F, K, S, D, P, tune=args.tune)
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
### Dequantization GEMM
An example of implementing a dequantization GEMM:
```python
@T.prim_func
def dequant_matmul(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
Ct: T.Buffer((N, M), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
T.clear(Ct_local)
for k in T.Pipelined(
T.ceildiv(K, block_K),
num_stages=num_stages
):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_packed_to_unsigned_convert("int", 8)(
num_bits,
B_local[i, j // 2],
j % 2,
dtype=in_dtype,
)
T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct[bx * block_N, by * block_M])
```
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang
from tilelang import Profiler
import tilelang.language as T
from tilelang.autotuner import *
from tilelang import tvm
from tvm import tir
import itertools
import torch
import argparse
from functools import partial
def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "float16"
assert val.dtype == "uint8"
# e_f4 == 0 -> e_f16 = 0
# e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
# s1e2n1
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = f4 & tir.const(7, "uint16")
e_f16 = e_f4 | tir.const(8, "uint16")
val_f16 = tir.reinterpret("float16",
((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
# return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
return val_f16
def torch_convert(tensor):
def print_bit(name, val):
val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}'
print(name, binary_repr)
def _convert(val, pos):
assert val.dtype == torch.uint8
val = val.view(torch.int8)
mask = (1 << 4) - 1
f4 = ((val >> (pos * 4)) & mask).to(torch.int16)
s = f4 >> 3
e_f4 = f4 & 7
e_f16 = e_f4 | 8
val_f16 = ((e_f16 | (s << 5)) << 10) & 0xFFFF
lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
return lower_16_bits.view(torch.float16)
N = tensor.shape[0]
K = tensor.shape[1]
new_tensor = torch.empty(N, K * 2, dtype=torch.float16, device=tensor.device)
for i in range(new_tensor.shape[0]):
for j in range(new_tensor.shape[1]):
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
return new_tensor
def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
@T.prim_func
def main(
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
for k in T.Pipelined(
T.ceildiv(K, block_K),
num_stages=1
):
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
return main
def test_fp4_fp16_convert_close():
N, K = 256, 256
block_N, block_K = 64, 64
program = test_convert(
N,
K,
block_N,
block_K,
"float16",
)
mod, params = tilelang.lower(program)
mod = Profiler(mod, params, [1], tilelang.TensorSupplyType.Integer)
B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
tl_out = mod.func(B)
ref_out = torch_convert(B)
assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
print("Pass")
def get_configs():
block_M = [128]
block_N = [128, 256]
block_K = [128]
num_stages = [2]
threads = [256]
splits = [1]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits))
configs = [
{'block_M': c[0], 'block_N': c[1], 'block_K': c[2], 'num_stages': c[3], 'threads': c[4], 'split': c[5]}
for c in _configs
]
return configs
def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0
KK = K // split
@T.prim_func
def main_split(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
Ct: T.Buffer((N, M), out_dtype),
):
SplitC = T.alloc_buffer(
[split, (N + block_N - 1) // block_N * block_N, (M + block_M - 1) // block_M * block_M], out_dtype
)
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, threads=threads) as (bx, by, bz):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout(
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
}
)
T.clear(Ct_local)
for k in T.Pipelined(K // (block_K * split), num_stages=num_stages):
T.copy(A[by * block_M, KK * bz + k * block_K], A_shared)
T.copy(B[bx * block_N, (KK * bz + k * block_K) // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, SplitC[bz, bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by):
acc = T.alloc_fragment((block_N, block_M), out_dtype)
T.clear(acc)
for k in range(split):
for i, j in T.Parallel(block_N, block_M):
acc[i, j] += SplitC[k, bx * block_N + i, by * block_M + j]
T.copy(acc, Ct[bx * block_N, by * block_M])
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
Ct: T.Buffer((N, M), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout(
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
}
)
T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
if split == 1:
return main
else:
return main_split
if tune:
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"],
warmup=10,
rep=10
)
@jit(out_idx=[2], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None, profiler="auto")
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None, split=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
return kernel()
else:
def kernel(block_M, block_N, block_K, num_stages, threads, split=1):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
return kernel
def ref_program(A, qB):
dtypeC = "float16"
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C.transpose(0, 1)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--m', type=int, default=256, help='M')
parser.add_argument('--n', type=int, default=256, help='N')
parser.add_argument('--k', type=int, default=256, help='K')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
total_flops = 2 * M * N * K
if (not args.tune):
program = matmul(M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)(block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
mod, params = tilelang.lower(program)
mod = Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = mod.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = mod.do_bench(mod.func, warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_latency, best_config, ref_latency = matmul(M, N, K, "float16", "float16", "float32", num_bits=4, tune=args.tune)
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
# FlashAttention
Using tile-lang, we can define buffers at different memory layers. For instance, `Q_shared`, `K_shared`, and `V_shared` can be defined in shared memory, while `acc_s` and `acc_o` can be placed in registers. This flexibility allows us to represent a complex fusion pattern like FlashAttention in a simple way.
```python
@T.prim_func
def flash_attention(
Q: T.Buffer(shape, dtype),
K: T.Buffer(shape, dtype),
V: T.Buffer(shape, dtype),
Output: T.Buffer(shape, dtype),
):
# Launch a specialized T.Kernel with 3D mapping: (bx, by, bz)
# bx: block index in sequence dimension
# by: block index in "heads" dimension
# bz: block index in "batch" dimension
# threads=thread_num means how many threads per block
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):
# Allocate shared memory for Q, K, V to reduce global memory accesses
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
# Allocate buffers on register
# acc_s: buffer to hold intermediate attention scores
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
# acc_s_cast: buffer for storing casted/adjusted scores
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
# acc_o: partial accumulation of output
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
# Buffers to track per-row maximum score and related stats
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
# Annotate layout for Q_shared, e.g., use a swizzled layout to optimize memory access
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
# Copy a block of Q from global memory to Q_shared
T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
# Initialize accumulators
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.ceildiv((bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N)
)
# Pipeline the loop to overlap copies/gemm stages
for k in T.Pipelined(loop_range, num_stages=num_stages):
# Copy K block into shared memory
T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
if is_casual:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)
)
else:
T.clear(acc_s)
# Perform the Q*K^T multiplication, Here, transpose_B=True indicates that K_shared is transposed,
# policy=T.GemmWarpPolicy.FullRow means each warp is responsible for computing an entire row
# of acc_s, and the resulting acc_s is retained in registers.
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
# Copy V block into shared memory
T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
for i, j in T.Parallel(block_M, dim):
acc_s[i, j] *= scale
# Save old scores_max, then reset scores_max
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
# Compute the maximum value per row on dimension 1 (block_N)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# Compute the factor by which we need to rescale previous partial sums
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])
# Rescale the partial output accumulation to keep exponents consistent
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
# Exponentiate (scores - max) for the new block
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])
# Make a cast of acc_s to fp16 for the next GEMM
T.copy(acc_s, acc_s_cast)
# Multiply the attention acc_s_cast by V and add to partial output (acc_o)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
# Update the "logsum" tracker with the newly accumulated sum
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
# Final step: divide each partial output by logsum (completing the softmax)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
# Write back the final output block from acc_o to the Output buffer
T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
```
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn.functional as F
import tilelang
from tilelang import Profiler
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
def get_configs():
block_M = [128]
block_N = [128]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [{
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2],
'threads': c[3]
} for c in _configs]
return configs
def flashattn(batch, heads, seq_len, dim, is_casual, tune=False):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Buffer(shape, dtype),
Q_shared: T.Buffer([block_M, dim], dtype),
K_shared: T.Buffer([block_N, dim], dtype),
acc_s: T.Buffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
if is_casual:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Buffer(shape, dtype),
V_shared: T.Buffer([block_M, dim], dtype),
acc_s_cast: T.Buffer([block_M, block_N], dtype),
acc_o: T.Buffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.Buffer([block_M, block_N], accum_dtype),
acc_s_cast: T.Buffer([block_M, block_N], dtype),
scores_max: T.Buffer([block_M], accum_dtype),
scores_max_prev: T.Buffer([block_M], accum_dtype),
scores_scale: T.Buffer([block_M], accum_dtype),
scores_sum: T.Buffer([block_M], accum_dtype),
logsum: T.Buffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in the FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.Buffer([block_M, dim], accum_dtype),
scores_scale: T.Buffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Buffer(shape, dtype),
K: T.Buffer(shape, dtype),
V: T.Buffer(shape, dtype),
Output: T.Buffer(shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_casual else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return main
if tune:
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "num_stages", "threads"],
warmup=10,
rep=10)
@jit(
out_idx=[3],
supply_type=tilelang.TensorSupplyType.Integer,
ref_prog=None,
profiler="auto")
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_casual):
dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_casual:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--is_casual', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
batch, heads, seq_len, dim, is_casual = args.batch, args.heads, args.seq_len, args.dim, args.is_casual
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_casual:
total_flops *= 0.5
if (not args.tune):
program = flashattn(
batch, heads, seq_len, dim, is_casual, tune=args.tune)(
block_M=128, block_N=128, num_stages=2, threads=256)
ref_program = partial(ref_program, is_casual=is_casual)
mod, params = tilelang.lower(program)
mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = mod.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = mod.do_bench(mod.func, warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_latency, best_config, _ = flashattn(
batch, heads, seq_len, dim, is_casual, tune=args.tune)
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
# TileLang GEMM (Matrix Multiplication) Examples
TileLang is a domain-specific language designed to simplify the process of writing GPU kernels. It provides high-level abstractions for memory allocation, scheduling, and tiling, which are critical for achieving maximum performance on modern hardware architectures like NVIDIA GPUs. This README demonstrates how to write and optimize a matrix multiplication (GEMM) kernel using TileLang.
## Table of Contents
1. [Getting Started](#getting-started)
2. [Simple GEMM Example](#simple-gemm-example)
- [Code Walkthrough](#code-walkthrough)
- [Compiling and Profiling](#compiling-and-profiling)
3. [Advanced GEMM Features](#advanced-gemm-features)
- [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling)
- [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining)
- [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality)
4. [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
5. [Verifying Correctness](#verifying-correctness)
6. [Fine-grained MMA Computations](#fine-grained-mma-computations)
- [Example Workflow](#example-workflow)
- [Summary](#summary)
7. [References](#references)
---
## Getting Started
### Prerequisites
- **Python 3.8+**
- **NVIDIA GPU** with a recent CUDA toolkit installed
- **PyTorch** (optional, for easy correctness verification)
- **tilelang**
- **bitblas** (optional; used for swizzle layout utilities in the advanced examples)
### Installation
```bash
pip install tilelang bitblas
```
*(Adjust accordingly if you are installing from source or using a different environment.)*
---
## Simple GEMM Example
Below is a basic matrix multiplication (GEMM) example demonstrating how TileLang handles buffer allocation, tiling, and kernel dispatch. For simplicity, we'll multiply two 1024×1024 matrices using 128 threads/block.
```python
import tilelang
from tilelang import Profiler
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((K, N), dtype),
C: T.Buffer((M, N), dtype),
):
# Define a grid with enough blocks to cover M×N
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
# Allocate shared memory for the current tile of A and B
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
# Allocate a local (register) fragment for partial accumulations
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Initialize the local accumulation buffer to zero
T.clear(C_local)
# Loop over the K dimension in block_K chunks, using a 3-stage pipeline
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy from global memory to shared memory
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
# Perform a matrix multiply-accumulate on the tile
T.gemm(A_shared, B_shared, C_local)
# Copy the accumulated result from local memory (C_local) to global memory (C)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
```
### Code Walkthrough
1. **Define the Kernel Launch Configuration:**
```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
```
This creates a grid of blocks (ceildiv(N, block_N) in x-dimension, ceildiv(M, block_M) in y-dimension), each with 128 threads.
2. **Shared Memory Allocation:**
```python
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
```
Tiles of \(A\) and \(B\) are loaded into these shared memory buffers for faster access.
3. **Local Fragment Accumulation:**
```python
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
```
Partial results are stored in registers (or local memory) to reduce writes to global memory.
4. **Pipelined Loading and GEMM:**
```python
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(...)
T.gemm(...)
```
Loads blocks of \(A\) and \(B\) in a pipelined fashion (up to 3 stages). This exploits overlap of data transfer and computation.
5. **Copy Out the Results:**
```python
T.copy(C_local, C[by * block_M, bx * block_N])
```
Writes the final computed tile from registers/shared memory to global memory.
### Compiling and Profiling
```python
func = matmul(1024, 1024, 1024, 128, 128, 32)
print(func) # Prints an IR-like representation of the TileLang kernel
rt_mod, params = tilelang.lower(func)
profiler = Profiler(rt_mod, params, result_idx=[2])
import torch
a = torch.randn(1024, 1024).cuda().half()
b = torch.randn(1024, 1024).cuda().half()
c = profiler(a, b)
ref_c = a @ b
# Validate results
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
# Get CUDA Kernel Source
print(rt_mod.imported_modules[0].get_source())
```
---
## Advanced GEMM Features
### Custom Memory Layout / Swizzling
**Swizzling** rearranges data in shared memory or global memory to mitigate bank conflicts, improve cache utilization, and better match the GPU’s warp execution pattern. TileLang provides helper functions like `make_swizzle_layout` to annotate how buffers should be laid out in memory.
### Parallel Copy and Auto-Pipelining
- **Parallel Copy** allows you to distribute the copy of a block tile across all threads in a block, speeding up the transfer from global memory to shared memory.
- **Auto-Pipelining** uses multiple stages to overlap copying with computation, reducing idle cycles.
### Rasterization for L2 Cache Locality
Enabling **swizzle (rasterization)** at the kernel level can improve data reuse and reduce cache thrashing in L2. This is especially important when matrices are large.
---
## Enhanced GEMM Example with Annotations
Below is a more advanced snippet that showcases how to apply memory layouts, enable swizzling, and parallelize the copy operations to maximize performance:
```python
import tilelang.language as T
# `make_mma_swizzle_layout` is a python-defined layout function
# that helps align data for MMA (Matrix Multiply-Accumulate) operations.
from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((K, N), dtype),
C: T.Buffer((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
# Allocate shared and local fragments
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Annotate memory layout
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Enable swizzle-based rasterization for better L2 locality
T.use_swizzle(panel_size=10, enable=True)
# Clear the local accumulation buffer
T.clear(C_local)
# Pipelined iteration over K dimension
for idx in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy tile of A
T.copy(A[by * block_M, idx * block_K], A_shared)
# Parallel copy tile of B
for ko, j in T.Parallel(block_K, block_N):
B_shared[ko, j] = B[idx * block_K + ko, bx * block_N + j]
# Perform local GEMM on the shared-memory tiles
T.gemm(A_shared, B_shared, C_local)
# Copy the result tile back
T.copy(C_local, C[by * block_M, bx * block_N])
return main
```
**Key Differences vs. Basic Example**
1. **`T.annotate_layout(...)`**: Annotates how data should be organized in shared memory (swizzling).
2. **`T.use_swizzle(...)`**: Enables swizzle-based rasterization.
3. **Parallel Copy Loop** with `T.Parallel(...)`: Distributes global-to-shared copy across all threads, potentially vectorizing load/store instructions.
---
## Verifying Correctness
Once you compile and load your kernel into a runtime module (`rt_mod`), you can use tools like **PyTorch** to easily create random matrices on the GPU, run your TileLang kernel, and compare the results to a reference implementation (e.g., `torch.matmul` or `@` operator).
```python
import torch
# Suppose your compiled kernel is in rt_mod
profiler = Profiler(rt_mod, params, result_idx=[2])
A = torch.randn(1024, 1024).cuda().half()
B = torch.randn(1024, 1024).cuda().half()
C_tilelang = profiler(A, B)
C_ref = A @ B
torch.testing.assert_close(C_tilelang, C_ref, rtol=1e-2, atol=1e-2)
print("Results match!")
```
---
## Fine-grained MMA Computations
For advanced users who require full control over warp-level matrix multiplication, TileLang allows you to specify fine-grained MMA (Matrix Multiply-Accumulate) computations in a manner similar to writing raw CUDA. While higher-level abstractions like `T.gemm(...)` or the automatic MMA emitters are sufficient for many use cases, specialized workloads (for example, dequantized GEMM, which may require a fine-grained layout transformation at the shared-to-register stage) can benefit from explicitly controlling each MMA instruction, the data layout, and the synchronization points.
### Example Workflow
```python
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
chunk = 32
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
# Perform STMatrix
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
```
1. **Set Up Tile Sizes and Thread Bindings**
Just like in CUDA, you will typically start by defining how many warps or threads per block you want and how your matrix is subdivided. In TileLang, this is done via `T.Kernel(...)` and `T.thread_binding(...)`, which ensure that the correct number of threads is active and that each thread is bound to a specific role (e.g., warp ID or lane ID), as shown in the excerpt below.
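For reference, the corresponding lines from the example above:
```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
    # Bind the flat thread index; the MMA emitter derives warp and lane IDs from it.
    thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
```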
2. **Allocate Warp-local Fragments**
Instead of using a single shared buffer for partial sums, you allocate local buffers (register fragments) to hold sub-blocks of matrices \(A\) and \(B\). In TileLang, this is done with something like:
```python
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
```
Each of these `local` allocations represents a region of per-thread storage, which collectively forms the warp’s register tiles.
3. **Load Data via `ldmatrix`**
Fine-grained loading instructions allow you to specify exactly how data moves from shared memory to the warp-level fragments. In the example below, `mma_emitter.ldmatrix_a()` and `.ldmatrix_b()` are higher-level wrappers around warp-synchronous intrinsics. You can write your own load logic as well:
```python
for ki in T.serial(0, (block_K // micro_size_k)):
# Warp-synchronous load for A
mma_emitter.ldmatrix_a(A_local, A_shared, ki, thread_bindings=thread_bindings)
# Warp-synchronous load for B
mma_emitter.ldmatrix_b(B_local, B_shared, ki, thread_bindings=thread_bindings)
```
Internally, these calls orchestrate how each thread in the warp issues the correct load instructions, performs address calculations, and stores the data into registers.
4. **Perform the MMA Instruction**
After loading sub-tiles (fragments), the warp executes the `mma` instruction. This operation is essentially:
\[
C_{\text{local}} \;+=\; A_{\text{local}} \;\times\; B_{\text{local}}
\]
where each thread in the warp calculates a small portion of the final tile. For instance:
```python
mma_emitter.mma(A_local, B_local, C_local)
```
Under the hood, this translates into Tensor Core instructions (e.g., `mma.sync` in PTX), which process multiple data elements per warp in parallel.
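Functionally, one such warp-level MMA on a 16×16×16 tile is equivalent to the small reference computation below (plain NumPy, for intuition only; the real instruction spreads the operands and accumulators across the 32 lanes of a warp, and `B` is stored transposed in the example above):
```python
import numpy as np

m = n = k = 16
A_tile = np.random.rand(m, k).astype(np.float16)
B_tile = np.random.rand(n, k).astype(np.float16)   # row-major (N, K), i.e. B transposed
C_tile = np.zeros((m, n), dtype=np.float32)        # accumulate in a wider type

# One MMA step of the inner reduction loop: C += A @ B^T
C_tile += A_tile.astype(np.float32) @ B_tile.astype(np.float32).T
```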
5. **Store Results via `stmatrix`**
Finally, you write the results from the warp-level fragments back to shared memory or global memory. This step might happen multiple times in a loop or just once at the end. The code snippet:
```python
mma_emitter.stmatrix(C_local, C_shared, thread_bindings=thread_bindings)
```
orchestrates the warp-synchronous stores, ensuring each thread places the correct fragment element into the correct location of the shared or global buffer.
### Summary
By combining warp-synchronous intrinsics (`ldmatrix`, `mma`, `stmatrix`) with manual thread bindings and memory allocations, you can replicate the control and performance of raw CUDA at the TileLang level. This approach is best suited for expert users who are comfortable with GPU warp-level programming, since it does require a deep understanding of hardware concurrency, memory hierarchies, and scheduling. However, the payoff can be significant for performance-critical paths, where every byte of bandwidth and every cycle of latency must be carefully orchestrated.
---
## References
- [NVIDIA CUTLASS Library](https://github.com/NVIDIA/cutlass): A collection of high-performance CUDA C++ template abstractions for GEMM.
- [NVIDIA CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html): Official documentation for CUDA.
- [PyTorch Documentation](https://pytorch.org/docs): For verifying correctness via CPU or GPU-based matmul.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang
from tilelang import Profiler
import tilelang.language as T
def matmul(
M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"
):
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((K, N), dtype),
C: T.Buffer((M, N), dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128
) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
# Build the TileLang program and print its TIR for inspection
func = matmul(1024, 1024, 1024, 128, 128, 32)
print(func)
# Lower to a runtime module; result_idx=[2] marks C as the output tensor
rt_mod, params = tilelang.lower(func)
profiler = Profiler(rt_mod, params, result_idx=[2])
import torch
a = torch.randn(1024, 1024).cuda().half()
b = torch.randn(1024, 1024).cuda().half()
c = profiler(a, b)
ref_c = a @ b
print(c)
print(ref_c)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
# Get CUDA Source
print(rt_mod.imported_modules[0].get_source())
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
from tilelang import tvm as tvm
from tvm import DataType
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
torch.manual_seed(0)
def make_swizzle_layout(shared_buf):
dtype = shared_buf.dtype
shape = shared_buf.shape
can_swizzle = shape[-1] * DataType(dtype).bits == 512
if not can_swizzle:
return T.Layout(shape, lambda *args: args)
def transform_func(i, j):
new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
return [new_warp_i, new_warp_j]
return T.Layout(shape, transform_func)
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
micro_size_k = 32
# This is a debug config
block_row_warps = 1
block_col_warps = 1
warp_row_tiles = 16
warp_col_tiles = 16
# chunk = 32 if in_dtype == "float16" else 64
chunk = 32
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
# Perform STMatrix
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
return main
M, N, K = 128, 128, 128
in_dtype, out_dtype, accum_dtype = "float16", "float16", "float16"
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
mod, params = TL.lower(matmul)
src_code = mod.imported_modules[0].get_source()
# src_code is the generated cuda source
assert src_code is not None
if in_dtype == "int8":
A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8)
else:
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(A, B, C)
latency = mod.do_bench(mod.func, warmup=25)
# Ensure that the latency is not None
assert latency is not None
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype))
print(C)
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang
from tilelang import Profiler
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((K, N), dtype),
C: T.Buffer((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable rasterization for better L2 Cache Locality
T.use_swizzle(panel_size=10)
# Clear the local buffer
T.clear(C_local)
# Auto pipeline the computation
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[by * block_M, ko * block_K], A_shared)
# Instead of using
# T.copy(B[ko * block_K, bx * block_N], B_shared)
# we can also use T.Parallel to auto-map the thread
# bindings and vectorize the copy operation.
for k, j in T.Parallel(block_K, block_N):
B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
func = matmul(1024, 1024, 1024, 128, 128, 32)
print(func)
rt_mod, params = tilelang.lower(func)
profiler = Profiler(rt_mod, params, result_idx=[2])
import torch
a = torch.randn(1024, 1024).cuda().half()
b = torch.randn(1024, 1024).cuda().half()
c = profiler(a, b)
ref_c = a @ b
print(c)
print(ref_c)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
# Get CUDA Source
print(rt_mod.imported_modules[0].get_source())
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import torch
import tilelang
from tilelang import Profiler
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, repeat
import itertools
def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd
out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D)
return out
def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
"""
Argument:
cb: (batch, nchunks, ngroups, chunk_size, chunk_size)
x: (batch, seqlen, nheads, headdim)
dt: (batch, nheads, nchunks, chunk_size)
dA_cumsum: (batch, nheads, nchunks, chunk_size)
C: (batch, seqlen, ngroups, dstate)
prev_states: (batch, nchunks, nheads, headdim, dstate)
D: (nheads, headdim) or (nheads,)
z: (batch, seqlen, nheads, headdim)
Return:
out: (batch, seqlen, nheads, headdim)
"""
_, _, ngroups, _, _ = cb.shape
batch, seqlen, nheads, headdim = x.shape
# _, _, ngroups, dstate = B.shape
# assert B.shape == (batch, seqlen, ngroups, dstate)
_, _, nchunks, chunk_size = dt.shape
assert seqlen == nchunks * chunk_size
# assert C.shape == B.shape
# B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups)
cb = repeat(cb, "b c g l s -> b c (g h) l s", h=nheads // ngroups)
# CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
# rearrange(B, "b (c s) h n -> b c s h n", c=nchunks))
# (batch, nheads, nchunks, chunksize, chunksize)
dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
decay = torch.exp(dt_segment_sum)
scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s")
causal_mask = torch.tril(
torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
scores_decay = scores_decay.masked_fill(~causal_mask, 0)
out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(
C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
out = out + out_prev
out = rearrange(out, "b c l h p -> b (c l) h p")
if D is not None:
if D.dim() == 1:
D = rearrange(D, "h -> h 1")
out = out + x * D
return out
def get_configs():
block_M = [64, 128, 256]
block_N = [32, 64]
block_K = [64, 128, 256]
block_Dstate = [128]
num_stages = [1, 2, 3, 4, 5]
_configs = list(itertools.product(block_M, block_N, block_K, block_Dstate, num_stages))
configs = [{
'block_M': c[0],
'block_N': c[1],
'block_K': c[2],
'block_Dstate': c[3],
'num_stages': c[4],
'threads': c[0] * 2
} for c in _configs]
return configs
def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, tune=False):
dtype = "float16"
accum_dtype = "float"
nchunks = T.ceildiv(seqlen, chunk_size)
p = 1.44269504
def kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads):
@T.prim_func
def main(cb: T.Buffer((batch, nchunks, ngroups, chunk_size, chunk_size), dtype),
x: T.Buffer((batch, seqlen, nheads, headdim), dtype), dt: T.Buffer(
(batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Buffer(
(batch, nheads, nchunks, chunk_size), dtype), C: T.Buffer(
(batch, seqlen, ngroups, dstate), dtype), prev_states: T.Buffer(
(batch, nchunks, nheads, headdim, dstate), dtype), D: T.Buffer(
(nheads), dtype), Output: T.Buffer(
(batch, seqlen, nheads, headdim), dtype)):
with T.Kernel(
nheads,
T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
batch * nchunks,
threads=threads) as (bz, bx, by):
acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")
cb_local = T.alloc_fragment((block_M, block_K), dtype)
dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared")
dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype)
dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype)
dt_shared = T.alloc_shared((block_K), dtype, scope="shared")
dt_local = T.alloc_fragment((block_K), accum_dtype)
x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn")
dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared")
scale_m_local = T.alloc_fragment((block_M), accum_dtype)
C_shared = T.alloc_shared((block_M, block_Dstate), dtype)
prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype)
D_local = T.alloc_fragment((1), accum_dtype)
x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn")
x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype)
batch_idx = by % batch
chunk_idx = by // batch
# m: chunk_size
# n: headdim
m_idx = bx // T.ceildiv(headdim, block_N)
n_idx = bx % T.ceildiv(headdim, block_N)
T.annotate_layout({
acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared)
})
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M],
dA_cs_m_shared)
T.copy(dA_cs_m_shared, dA_cs_m_local)
T.clear(acc_o)
for i in T.Parallel(block_M):
scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)
T.copy(
C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared)
T.copy(
prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N,
0:block_Dstate], prev_state_shared)
T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] *= scale_m_local[i]
loop_range = T.ceildiv((m_idx + 1) * block_M, block_K)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
cb[batch_idx, chunk_idx, bz // (nheads // ngroups),
m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K],
cb_shared)
T.copy(cb_shared, cb_local)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
dA_cs_k_shared)
T.copy(dA_cs_k_shared, dA_cs_k_local)
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p -
dA_cs_k_local[j] * p)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
T.copy(dt_shared, dt_local)
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] *= dt_local[j]
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j,
cb_local[i, j], 0)
T.copy(
x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared)
T.gemm(cb_local, x_shared, acc_o)
D_local[0] = D[bz]
T.copy(
x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N],
x_residual_shared)
T.copy(x_residual_shared, x_residual_local)
for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] += x_residual_local[i, j] * D_local[0]
T.copy(acc_o, acc_o_shared)
T.copy(
acc_o_shared,
Output[batch_idx, chunk_idx * chunk_size +
m_idx * block_M:chunk_idx * chunk_size + (m_idx + 1) * block_M, bz,
n_idx * block_N:(n_idx + 1) * block_N])
return main
if tune:
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "block_K", "block_Dstate", "num_stages", "threads"],
warmup=10,
rep=10)
@jit(
out_idx=[7],
supply_type=tilelang.TensorSupplyType.Normal,
ref_prog=None,
profiler="auto")
def kernel(block_M=None,
block_N=None,
block_K=None,
block_Dstate=None,
num_stages=None,
threads=None):
return kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, block_K, block_Dstate, num_stages, threads):
return kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads)
return kernel
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=80, help='heads')
parser.add_argument('--groups', type=int, default=1, help='groups')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--chunk_size', type=int, default=256, help='chunk size')
parser.add_argument('--dim', type=int, default=64, help='dim')
parser.add_argument('--dstate', type=int, default=128, help='dstate')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate
if (not args.tune):
program = chunk_scan_fwd(
batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)(
block_M=64, block_N=64, block_K=64, block_Dstate=128, num_stages=2, threads=128)
mod, params = tilelang.lower(program)
mod = Profiler(mod, params, [7], tilelang.TensorSupplyType.Normal)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = mod.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = mod.do_bench(mod.func, warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_latency, best_config, _ = chunk_scan_fwd(
batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import torch
import torch.nn.functional as F
import tilelang
from tilelang import Profiler
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, repeat
import itertools
def chunk_state_triton(B, x, dt, dA_cumsum):
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd
return _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=False)
def ref_program(B, x, dt, dA_cumsum):
"""
Argument:
B: (batch, seqlen, ngroups, headdim)
x: (batch, seqlen, nheads, headdim)
dt: (batch, nheads, nchunks, chunk_size)
dA_cumsum: (batch, nheads, nchunks, chunk_size)
Return:
states: (batch, nchunks, nheads, headdim, dstate)
"""
# Check constraints.
batch, seqlen, nheads, headdim = x.shape
dstate = B.shape[-1]
_, _, nchunks, chunk_size = dt.shape
assert seqlen <= nchunks * chunk_size
assert x.shape == (batch, seqlen, nheads, headdim)
assert dt.shape == (batch, nheads, nchunks, chunk_size)
ngroups = B.shape[2]
assert nheads % ngroups == 0
assert B.shape == (batch, seqlen, ngroups, dstate)
B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
if seqlen < nchunks * chunk_size:
x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype),
dt.to(x.dtype), x)
def get_configs():
block_M = [64, 128]
block_N = [32, 64, 128]
block_K = [32, 64]
num_stages = [1, 2, 3, 4, 5]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages))
configs = [{
'block_M': c[0],
'block_N': c[1],
'block_K': c[2],
'num_stages': c[3],
'threads': c[0] * 2
} for c in _configs]
return configs
def chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, tune=False):
dtype = "float16"
accum_dtype = "float"
nchunks = T.ceildiv(seqlen, chunk_size)
p = 1.44269504
def kernel_func(block_M, block_N, block_K, num_stages, threads):
@T.prim_func
def main(B: T.Buffer((batch, seqlen, ngroups, dstate), dtype), x: T.Buffer(
(batch, seqlen, nheads, headdim), dtype), dt: T.Buffer(
(batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Buffer(
(batch, nheads, nchunks, chunk_size), dtype), Output: T.Buffer(
(batch, nchunks, nheads, headdim, dstate), dtype)):
with T.Kernel(
nheads,
T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N),
batch * nchunks,
threads=threads) as (bz, bx, by):
x_shared = T.alloc_shared((block_K, block_M), dtype)
x_local = T.alloc_fragment((block_K, block_M), dtype)
xt_local = T.alloc_fragment((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
dt_shared = T.alloc_shared((block_K), dtype)
dA_cumsum_shared = T.alloc_shared((block_K), dtype)
acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
scale = T.alloc_fragment((block_K), accum_dtype)
dA_cs_last = T.alloc_fragment((1), accum_dtype)
dA_cumsum_local = T.alloc_fragment((block_K), accum_dtype)
dt_local = T.alloc_fragment((block_K), accum_dtype)
loop_range = T.ceildiv(chunk_size, block_K)
batch_idx = by % batch
chunk_idx = by // batch
m_idx = bx // T.ceildiv(dstate, block_N)
n_idx = bx % T.ceildiv(dstate, block_N)
T.annotate_layout({
x_shared: tilelang.layout.make_swizzled_layout(x_shared),
acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared)
})
dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1]
T.clear(acc_o)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz, m_idx * block_M:(m_idx + 1) * block_M], x_shared)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
dA_cumsum_shared)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
T.copy(dA_cumsum_shared, dA_cumsum_local)
T.copy(dt_shared, dt_local)
for i in T.Parallel(block_K):
scale[i] = T.exp2(dA_cs_last[0] * p - dA_cumsum_local[i] * p) * dt_local[i]
T.copy(x_shared, x_local)
for i, j in T.Parallel(block_M, block_K):
xt_local[i, j] = x_local[j, i] * scale[j]
T.copy(
B[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz // (nheads // ngroups),
n_idx * block_N:(n_idx + 1) * block_N], B_shared)
T.gemm(xt_local, B_shared, acc_o)
T.copy(acc_o, acc_o_shared)
T.copy(
acc_o_shared,
Output[batch_idx, chunk_idx, bz, m_idx * block_M:(m_idx + 1) * block_M,
n_idx * block_N:(n_idx + 1) * block_N])
return main
if tune:
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "block_K", "num_stages", "threads"],
warmup=10,
rep=10)
@jit(
out_idx=[4],
supply_type=tilelang.TensorSupplyType.Normal,
ref_prog=None,
profiler="auto")
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, block_K, num_stages, threads):
return kernel_func(block_M, block_N, block_K, num_stages, threads)
return kernel
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=80, help='heads')
parser.add_argument('--groups', type=int, default=1, help='groups')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--chunk_size', type=int, default=256, help='chunk size')
parser.add_argument('--dim', type=int, default=64, help='dim')
parser.add_argument('--dstate', type=int, default=128, help='dstate')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
total_flops = 2 * batch * seq_len * heads * dim * dstate
if (not args.tune):
program = chunk_state_fwd(
batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)(
block_M=64, block_N=128, block_K=64, num_stages=4, threads=128)
mod, params = tilelang.lower(program)
mod = Profiler(mod, params, [4], tilelang.TensorSupplyType.Normal)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = mod.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = mod.do_bench(mod.func, warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_latency, best_config, _ = chunk_state_fwd(
batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")