<img src="./images/logo-row.svg" />

<div align="center">

# Tile Language
[![PyPI version](https://badge.fury.io/py/tilelang.svg)](https://badge.fury.io/py/tilelang)
[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/tile-ai/tilelang) [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?logo=discord&logoColor=white)](https://discord.gg/TUrHyJnKPG)

</div>

Tile Language (**tile-lang**) is a concise domain-specific language designed to streamline the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention). By employing a Pythonic syntax with an underlying compiler infrastructure on top of [TVM](https://tvm.apache.org/), tile-lang allows developers to focus on productivity without sacrificing the low-level optimizations necessary for state-of-the-art performance.

<img src="./images/MatmulExample.png" />

## Latest News
- 12/18/2025 🚀: Added [CuTeDSL backend](https://github.com/tile-ai/tilelang/pull/1421) support, enabling compilation to NVIDIA CUTLASS CuTe DSL! Join us in building and optimizing this exciting new backend: [Issue #1454](https://github.com/tile-ai/tilelang/issues/1454).
- 12/17/2025 🔬: Integrated [Z3 theorem prover](https://github.com/tile-ai/tilelang/pull/1367) into TVM Arith Analyzer, bringing SMT-based symbolic reasoning for enhanced optimizations and automatic correctness verification!
- 10/31/2025 🔧: Migrated to [apache-tvm-ffi](https://github.com/tile-ai/tilelang/pull/1108), significantly reducing CPU overhead!
- 10/30/2025 📦: We have released v0.1.6.post2, which is the last version compatible with Python 3.8.
- 10/07/2025 🍎: Added Apple Metal Device support, check out [Pull Request #799](https://github.com/tile-ai/tilelang/pull/799) for details.
- 09/29/2025 🎉: Thrilled to announce that AscendC and AscendNPU IR backends targeting Huawei Ascend chips are now supported! Check out the preview at [tilelang-ascend](https://github.com/tile-ai/tilelang-ascend), which includes implementations across two branches: [ascendc_pto](https://github.com/tile-ai/tilelang-ascend) and [npuir](https://github.com/tile-ai/tilelang-ascend/tree/npuir). Feel free to explore and share your feedback!
- 07/04/2025 🚀: Introduced `T.gemm_sp` for 2:4 sparse tensor core support, check out [Pull Request #526](https://github.com/tile-ai/tilelang/pull/526) for details.
- 06/05/2025 ✨: Added [NVRTC Backend](https://github.com/tile-ai/tilelang/pull/461) to significantly reduce compilation time for cute templates!
- 04/14/2025 🚀: Added a high-performance FlashMLA implementation for AMD MI300X, achieving performance parity with Aiter's hand-optimized assembly kernels! See [example_mla_amd](./examples/deepseek_mla/amd/README.md) for details.
- 03/03/2025 🚀: Added high-performance MLA Decoding support using only 80 lines of Python code, achieving performance on par with FlashMLA on H100 (see [example_mla_decode.py](./examples/deepseek_mla/example_mla_decode.py))! We also provide [documentation](./examples/deepseek_mla/README.md) explaining how TileLang achieves this.
- 02/15/2025 ✨: Added WebGPU Codegen support, see [Pull Request #86](https://github.com/tile-ai/tilelang/pull/86)!
- 02/12/2025 ✨: Excited to announce the release of [v0.1.0](https://github.com/tile-ai/tilelang/releases/tag/v0.1.0)!
- 02/10/2025 🚀: Added debug tools for TileLang—`T.print` for printing variables/buffers ([docs](https://tilelang.com/tutorials/debug_tools_for_tilelang.html)) and a memory layout plotter ([examples/plot_layout](./examples/plot_layout)).
- 01/20/2025 ✨: We are excited to announce that tile-lang, a DSL for high-performance AI workloads, is now open source and available to the public!

## Tested Devices
Although tile-lang aims to be portable across a range of devices, it has been specifically tested and validated on the following: for NVIDIA GPUs, the H100 (with Auto TMA/WGMMA support), A100, V100, RTX 4090, RTX 3090, and RTX A6000; for AMD GPUs, the MI250 (with Auto MatrixCore support) and the MI300X (with Async Copy support).

## OP Implementation Examples
**tile-lang** provides the building blocks to implement a wide variety of operators. Some examples include:

- [Matrix Multiplication](./examples/gemm/)
- [Dequantization GEMM](./examples/dequantize_gemm/)
- [Flash Attention](./examples/flash_attention/)
- [Flash Linear Attention](./examples/linear_attention/)
- [Flash MLA Decoding](./examples/deepseek_mla/)
- [Native Sparse Attention](./examples/deepseek_nsa/)

Within the `examples` directory, you will also find additional complex kernels, such as convolutions and forward/backward passes for FlashAttention; more operators will be added continuously.


## Benchmark Summary

TileLang achieves exceptional performance across a variety of computational patterns. Comprehensive benchmark scripts and settings are available at [tilelang-benchmark](https://github.com/tile-ai/tilelang-benchmark). Below are selected results showcasing its capabilities:

- MLA Decoding Performance on H100

  <div style="display: flex; gap: 10px; justify-content: center;">
    <div style="flex: 1;">
      <img src="./examples/deepseek_mla/figures/bs64_float16.png" alt="mla decode performance bs64 on H100" width="100%" />
    </div>
    <div style="flex: 1;">
      <img src="./examples/deepseek_mla/figures/bs128_float16.png" alt="mla decode performance bs128 on H100" width="100%" />
    </div>
  </div>
  
- Flash Attention Performance on H100

  <div align="center">    <img src="./images/mha_performance_h100.png" alt="operator performance on H100" width=80% />
  </div>

- Matmul Performance on GPUs (RTX 4090, A100, H100, MI300X)

  <div>
    <img src="./images/op_benchmark_consistent_gemm_fp16.png" alt="gemm fp16 performance on Gpus" />
  </div>

- Dequantize Matmul Performance on A100

  <div>
    <img src="./images/op_benchmark_a100_wq_gemv.png" alt="dequantize gemv performance on A100" />
  </div>

## Installation
### Method 1: Install with Pip

The quickest way to get started is to install the latest release from PyPI:

```bash
pip install tilelang
```

Alternatively, you can install directly from the GitHub repository:

```bash
pip install git+https://github.com/tile-ai/tilelang
```

Or install locally:

```bash
# clone the repository (including the bundled TVM submodule)
git clone --recursive https://github.com/tile-ai/tilelang
cd tilelang

# install required system dependencies
sudo apt-get update
sudo apt-get install -y python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

pip install -e . -v # remove -e option if you don't want to install in editable mode, -v for verbose output
```

### Method 2: Build from Source
We currently provide three ways to install **tile-lang** from source:
 - [Install from Source (using your own TVM installation)](./docs/get_started/Installation.md#method-1-install-from-source-using-your-own-tvm-installation)
 - [Install from Source (using the bundled TVM submodule)](./docs/get_started/Installation.md#method-2-install-from-source-using-the-bundled-tvm-submodule)
 - [Install Using the Provided Script](./docs/get_started/Installation.md#method-3-install-using-the-provided-script)

### Method 3: Install with Nightly Version

For users who want access to the latest features and improvements before official releases, we provide nightly builds of **tile-lang**.

```bash
pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu121/
# or pip install tilelang --find-links https://tile-ai.github.io/whl/nightly/cu121/
```

> **Note:** Nightly builds contain the most recent code changes but may be less stable than official releases. They're ideal for testing new features or if you need a specific bugfix that hasn't been released yet.
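
Whichever method you choose, a quick smoke test confirms that the package imports (this assumes the distribution exposes `__version__`, as published releases typically do):

```python
import tilelang
print(tilelang.__version__)
```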

## Quick Start

In this section, you'll learn how to write and execute a straightforward GEMM (matrix multiplication) kernel using tile-lang, followed by techniques for layout optimizations, pipelining, and L2-cache–friendly swizzling.

### GEMM Example with Annotations (Layout, L2 Cache Swizzling, Pipelining, etc.)

Below is an example that demonstrates more advanced features: layout annotation, parallelized copy, and swizzle for improved L2 cache locality. This snippet shows how to adapt your kernel to maximize performance on complex hardware.

```python
import tilelang
import tilelang.language as T

# @tilelang.jit(target="cuda")
# target can currently be "cuda", "hip", or "cpu";
# if not specified, it will be inferred from the input tensors at compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float):

    @T.prim_func
    def matmul_relu_kernel(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Enable rasterization for better L2 cache locality (Optional)
            # T.use_swizzle(panel_size=10, enable=True)

            # Clear local accumulation
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
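                # num_stages=3 software-pipelines this loop: the copies below are
                # prefetched ahead so they overlap with compute from earlier iterations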
                # Copy tile of A
                # This is syntactic sugar for a parallelized copy
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Copy tile of B
                T.copy(B[ko * block_K, bx * block_N], B_shared)

                # Perform a tile-level GEMM on the shared buffers
                # Currently this dispatches to cute on NVIDIA GPUs and hip on AMD GPUs
                T.gemm(A_shared, B_shared, C_local)
            
            # Apply ReLU elementwise on the accumulator
            for i, j in T.Parallel(block_M, block_N):
                C_local[i, j] = T.max(C_local[i, j], 0)

            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return matmul_relu_kernel


M = 1024  # M = T.dynamic("m") if you want to use dynamic shape
N = 1024
K = 1024
block_M = 128
block_N = 128
block_K = 32

# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)

# 2. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)

# Run the compiled kernel directly on the tensors
matmul_relu_kernel(a, b, c)

print(c)
# Reference result (matmul + ReLU) using PyTorch
ref_c = torch.relu(a @ b)

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# 3. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = matmul_relu_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)

# 4. Profile kernel latency
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

latency = profiler.do_bench()

print(f"Latency: {latency} ms")
```
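
The example above compiles for fixed 1024x1024x1024 shapes, so each new `M` would otherwise require a separate compilation. As the comment on `M` in the snippet hints, a dimension can instead be declared dynamic so that one compiled kernel serves many sizes. A minimal sketch, assuming `T.dynamic` accepts a symbol name as that comment suggests (the tile sizes remain static):

```python
# Sketch only: compile once with a symbolic M, then reuse across runtime sizes.
matmul_dyn = matmul(T.dynamic("m"), N, K, block_M, block_N, block_K)

for m in (256, 512, 1024):
    a = torch.randn(m, K, device="cuda", dtype=torch.float16)
    c = torch.empty(m, N, device="cuda", dtype=torch.float16)
    matmul_dyn(a, b, c)  # b (K x N) is reused from the example above
    torch.testing.assert_close(c, torch.relu(a @ b), rtol=1e-2, atol=1e-2)
```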

### Dive Deep into TileLang Beyond GEMM

In addition to GEMM, we provide a variety of examples to showcase the versatility and power of TileLang, including:

- [Dequantize GEMM](./examples/dequantize_gemm/): Achieve high-performance dequantization through **fine-grained control over per-thread operations**. Many of these techniques are now default behaviors in [BitBLAS](https://github.com/microsoft/BitBLAS), which uses layout transformations and intrinsics to accelerate dequantized GEMM (a simplified sketch follows this list).
- [FlashAttention](./examples/flash_attention/): Enable cross-operator fusion with simple and intuitive syntax; an auto-tuning example is also provided.
- [LinearAttention](./examples/linear_attention/): Examples include RetNet and Mamba implementations.
- [Convolution](./examples/convolution/): Convolution implementations using im2col.
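
For a flavor of what the dequantization example builds on, below is a deliberately simplified sketch (not the optimized kernel linked above): it upcasts int8 weights with one fp16 scale per output column, using only primitives already shown in the Quick Start. The kernel name, dtype choices, and the assumption that `N`/`K` divide evenly by the block sizes are illustrative only.

```python
import tilelang
import tilelang.language as T


# Simplified illustration only: per-column int8 -> fp16 dequantization.
@tilelang.jit
def dequantize(N, K, block_N, block_K):

    @T.prim_func
    def dequant_kernel(
            W_q: T.Tensor((K, N), "int8"),      # quantized weights
            scales: T.Tensor((N,), "float16"),  # one scale per output column
            W: T.Tensor((K, N), "float16"),     # dequantized output
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(K, block_K), threads=128) as (bx, by):
            # Each thread block handles one (block_K, block_N) tile in parallel.
            for i, j in T.Parallel(block_K, block_N):
                k = by * block_K + i
                n = bx * block_N + j
                W[k, n] = W_q[k, n].astype("float16") * scales[n]

    return dequant_kernel
```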

## Upcoming Features

Check our [tilelang v0.2.0 release plan](https://github.com/tile-ai/tilelang/issues/79) for upcoming features.

---

TileLang is now used in the [BitBLAS](https://github.com/microsoft/BitBLAS) and [AttentionEngine](https://github.com/microsoft/AttentionEngine) projects.

## Join the Discussion

Welcome to join our Discord community for discussions, support, and collaboration!

[![Join our Discord](https://img.shields.io/badge/Discord-Join%20Us-blue?logo=discord&style=for-the-badge)](https://discord.gg/TUrHyJnKPG)

## Acknowledgments

We would like to express our gratitude to the [TVM](https://github.com/apache/tvm) community for their invaluable contributions. The initial version of this project was mainly developed by [LeiWang1999](https://github.com/LeiWang1999), [chengyupku](https://github.com/chengyupku) and [nox-410](https://github.com/nox-410) with supervision from Prof. [Zhi Yang](https://yangzhihome.github.io) at Peking University. Part of this work was carried out during an internship at Microsoft Research, where Dr. Lingxiao Ma, Dr. Yuqing Xia, Dr. Jilong Xue, and Dr. Fan Yang offered valuable advice and support. We deeply appreciate their mentorship and contributions.