<img src=./images/logo-row.svg />

<div align="center">

# Tile Language

</div>

Tile Language (**tile-lang**) is a concise domain-specific language designed to streamline the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention). By employing a Pythonic syntax with an underlying compiler infrastructure on top of [TVM](https://tvm.apache.org/), tile-lang allows developers to focus on productivity without sacrificing the low-level optimizations necessary for state-of-the-art performance.

<img src=./images/MatmulExample.png />

## Latest News
- 01/20/2025 ✨: We are excited to announce that tile-lang, a DSL for high-performance AI workloads, is now open source and available to the public!

## Tested Devices
Although tile-lang aims to be portable across a range of devices, it has been specifically tested and validated on the following:

- **NVIDIA GPUs**: H100 (with Auto TMA/WGMMA support), A100, V100, RTX 4090, RTX 3090, and RTX A6000
- **AMD GPUs**: MI250 (with Auto MatrixCore support) and MI300X (with Async Copy support)

## OP Implementation Examples
**tile-lang** provides the building blocks to implement a wide variety of operators. Some examples include:

- [Matrix Multiplication](./examples/gemm/)
- [Dequantization GEMM](./examples/dequantize_gemm/)
- [Flash Attention](./examples/flash_attention/)
- [Flash Linear Attention](./examples/linear_attention/)

Within the `examples` directory, you will also find additional, more complex kernels, such as convolutions and forward/backward passes for FlashAttention. More operators will be added continuously.


## Benchmark Summary

TileLang achieves exceptional performance across a variety of computational patterns. Comprehensive benchmark scripts and settings are available at [tilelang-benchmark](https://github.com/tile-ai/tilelang-benchmark). Below are selected results showcasing its capabilities:

- Flash Attention Performance on H100

  <div align="center">    <img src="./images/mha_performance_h100.png" alt="operator performance on H100" width=80% />
  </div>

- Matmul Performance on GPUs (RTX 4090, A100, H100, MI300X)

  <div>
    <img src="./images/op_benchmark_consistent_gemm_fp16.png" alt="gemm fp16 performance on GPUs" />
  </div>

- Dequantize Matmul Performance on A100

  <div>
    <img src="./images/op_benchmark_a100_wq_gemv.png" alt="dequantize gemv performance on A100" />
  </div>

## Installation
### Method 1: Install with Pip

The quickest way to get started is to install the latest release from PyPI:

```bash
pip install tilelang
```

Alternatively, you can install directly from the GitHub repository:

```bash
pip install git+https://github.com/tile-ai/tilelang
```

Or install from a local clone of the repository:

```bash
pip install .  # add -e to install in editable mode
```
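To confirm which version pip actually installed, you can query the package metadata from Python. This is a minimal standard-library sketch that assumes the distribution name is `tilelang`, matching the `pip install tilelang` command above; `installed_version` is a hypothetical helper, not part of tile-lang.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str = "tilelang") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_version()
    print(f"tilelang {version}" if version else "tilelang is not installed")
```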

### Method 2: Build from Source
We currently provide three ways to install **tile-lang** from source:
 - [Install from Source (using your own TVM installation)](./docs/Installation.md#install-from-source-with-your-own-tvm-installation)
 - [Install from Source (using the bundled TVM submodule)](./docs/Installation.md#install-from-source-with-our-tvm-submodule)
 - [Install Using the Provided Script](./docs/Installation.md#install-with-provided-script)


## Quick Start

In this section, you’ll learn how to write and execute a straightforward GEMM (matrix multiplication) kernel using tile-lang, followed by techniques for layout optimization, pipelining, and L2-cache-friendly swizzling.

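Before the TileLang version, it can help to see the tiled algorithm as plain NumPy. The sketch below is a CPU illustration written for this README, not how TileLang lowers the kernel. It assigns one output tile per "thread block" (`bx` walks N, `by` walks M, as in `T.Kernel` in the annotated example that follows) and accumulates in float32 the way `C_local` does. It also uses a ring buffer of `num_stages` slots to prefetch K-tiles ahead of the compute, mimicking the schedule that `T.Pipelined(..., num_stages=3)` arranges on the GPU. It runs sequentially, so it shows the schedule but not the overlap. `tiled_matmul` and `ceildiv` are hypothetical helpers.

```python
import numpy as np


def ceildiv(a, b):
    return (a + b - 1) // b


def tiled_matmul(A, B, block_M=128, block_N=128, block_K=32, num_stages=3):
    """CPU model of a tiled, multi-stage-pipelined GEMM: C = A @ B."""
    M, K = A.shape
    K_b, N = B.shape
    assert K == K_b, "inner dimensions must match"
    assert num_stages >= 1
    C = np.empty((M, N), dtype=A.dtype)
    num_k = ceildiv(K, block_K)
    # One "thread block" per output tile: bx walks N, by walks M.
    for by in range(ceildiv(M, block_M)):
        for bx in range(ceildiv(N, block_N)):
            rows = slice(by * block_M, (by + 1) * block_M)
            cols = slice(bx * block_N, (bx + 1) * block_N)
            # A ring of num_stages slots stands in for multi-buffered shared memory.
            ring = [None] * num_stages

            def load(ko):
                ks = slice(ko * block_K, (ko + 1) * block_K)
                ring[ko % num_stages] = (A[rows, ks], B[ks, cols])

            # Prologue: prefetch the first num_stages - 1 K-tiles.
            for ko in range(min(num_stages - 1, num_k)):
                load(ko)
            # Accumulate in float32, like C_local with accum_dtype="float".
            acc = np.zeros(C[rows, cols].shape, dtype=np.float32)
            for ko in range(num_k):
                # Prefetch tile ko + num_stages - 1 into the slot freed by tile ko - 1.
                if ko + num_stages - 1 < num_k:
                    load(ko + num_stages - 1)
                a_tile, b_tile = ring[ko % num_stages]
                acc += a_tile.astype(np.float32) @ b_tile.astype(np.float32)
            # Write the tile back, casting to the output dtype.
            C[rows, cols] = acc.astype(C.dtype)
    return C
```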
### GEMM Example with Annotations (Layout, L2 Cache Swizzling, and Pipelining, etc.)

Below is an example that demonstrates more advanced features: layout annotation, parallelized copy, software pipelining, and swizzling for improved L2 cache locality. This snippet shows how to adapt your kernel to maximize performance on complex hardware.

```python
import tilelang
import tilelang.language as T
# `make_mma_swizzle_layout` is a Python-defined layout function designed
# specifically for MMA operations. It is consistent with the layouts used by
# NVIDIA's CUTLASS library, which avoids shared-memory bank conflicts and
# maximizes performance.
from tilelang.intrinsics import (
    make_mma_swizzle_layout as make_swizzle_layout,)

def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    # Add the @tilelang.jit decorator if you want to return a torch function
    @T.prim_func
    def main(
        A: T.Buffer((M, K), dtype),
        B: T.Buffer((K, N), dtype),
        C: T.Buffer((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local  = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Apply layout optimizations or define your own layout (Optional)
            # If not specified, the layout is deduced automatically
            # T.annotate_layout({
            #     A_shared: make_swizzle_layout(A_shared),
            #     B_shared: make_swizzle_layout(B_shared),
            # })

            # Enable rasterization for better L2 cache locality (Optional)
            # T.use_swizzle(panel_size=10, enable=True)

            # Clear local accumulation
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy tile of A
                # This is syntactic sugar for a parallelized copy
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Demonstrate parallelized copy from global to shared for B
                for k, j in T.Parallel(block_K, block_N):
                    B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]

                # Perform a tile-level GEMM on the shared buffers
                # Currently dispatched to CuTe on NVIDIA GPUs and HIP on AMD GPUs
                T.gemm(A_shared, B_shared, C_local)

            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


# 1. Define the kernel (matmul) with the desired dimensions
func = matmul(1024, 1024, 1024, 128, 128, 32)

# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list;
# if out_idx is specified, the output tensor is created at runtime.
# target can currently be "cuda", "hip", or "cpu".
jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")

# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

# Run the kernel through the JIT-compiled function
c = jit_kernel(a, b)

# Reference multiplication using PyTorch
ref_c = a @ b

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# 4. Retrieve and inspect the generated CUDA source (optional)
cuda_source = jit_kernel.get_kernel_source()
print("Generated CUDA kernel:\n", cuda_source)

# 5. Profile latency with the profiler
profiler = jit_kernel.get_profiler()

latency = profiler.do_bench()

print(f"Latency: {latency} ms")
```
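The optional `T.use_swizzle(panel_size=10, enable=True)` line changes the order in which thread blocks are mapped to output tiles, so that blocks running at the same time share rows of A and columns of B in L2. The pure-Python sketch below shows one common form of this idea, the grouped ordering from Triton's matmul tutorial, where `group_m` plays a role analogous to a panel size. TileLang's actual block mapping may differ; `grouped_block_order` and `panels_touched` are hypothetical helpers for illustration.

```python
def grouped_block_order(pid, num_blocks_m, num_blocks_n, group_m):
    """Map a linear block id to an output tile (block_m, block_n).

    Blocks are visited column by column inside groups of `group_m` tile rows,
    so consecutive ids share a small set of A row-panels and B column-panels.
    """
    blocks_per_group = group_m * num_blocks_n
    group_id = pid // blocks_per_group
    first_m = group_id * group_m
    # The last group is shorter when num_blocks_m is not a multiple of group_m.
    group_size = min(num_blocks_m - first_m, group_m)
    local = pid % blocks_per_group
    return first_m + local % group_size, local // group_size


def panels_touched(tiles):
    """Distinct A row-panels plus B column-panels read by a set of output tiles."""
    return len({m for m, _ in tiles}) + len({n for _, n in tiles})


if __name__ == "__main__":
    num_m, num_n, group = 8, 8, 2
    window = 8  # e.g. the blocks resident on the GPU at the same time
    row_major = [(p // num_n, p % num_n) for p in range(window)]
    grouped = [grouped_block_order(p, num_m, num_n, group) for p in range(window)]
    print(panels_touched(row_major), panels_touched(grouped))  # prints: 9 6
```

With 8 blocks in flight, row-major order reads 1 row-panel of A and 8 column-panels of B, while grouping by 2 reads 2 and 4, so fewer distinct tiles compete for L2.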

### Dive Deep into TileLang Beyond GEMM

In addition to GEMM, we provide a variety of examples to showcase the versatility and power of TileLang, including:

- [Dequantize GEMM](./examples/dequantize_gemm/): Achieve high-performance dequantization through **fine-grained control over per-thread operations**. Many of these features are now default behaviors in [BitBLAS](https://github.com/microsoft/BitBLAS), which uses specialized layout transformations and intrinsics to accelerate dequantize GEMM.
- [FlashAttention](./examples/flash_attention/): Enable cross-operator fusion with simple, intuitive syntax; an auto-tuning example is also provided.
- [LinearAttention](./examples/linear_attention/): Examples include RetNet and Mamba implementations.
- [Convolution](./examples/convolution/): Convolution implemented with IM2Col.

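The convolution example builds on the IM2Col idea: unfold every receptive field into one row of a matrix so that the whole convolution becomes a single GEMM, which can then reuse a tiled matmul kernel. Below is a minimal NumPy sketch, assuming NHWC inputs, HWIO weights, and symmetric zero padding. `im2col` and `conv2d_im2col` are hypothetical helpers, and the TileLang example may use different layouts or gather patches on the fly instead of materializing the full patch matrix.

```python
import numpy as np


def im2col(x, kh, kw, stride=1, pad=0):
    """Unfold an NHWC tensor into a (N*OH*OW, KH*KW*C) patch matrix."""
    n, h, w, c = x.shape
    x = np.pad(x, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
    oh = (h + 2 * pad - kh) // stride + 1
    ow = (w + 2 * pad - kw) // stride + 1
    cols = np.empty((n, oh, ow, kh, kw, c), dtype=x.dtype)
    for i in range(kh):
        for j in range(kw):
            # Kernel offset (i, j) of every receptive field, for all output positions.
            cols[:, :, :, i, j, :] = x[:, i:i + stride * oh:stride, j:j + stride * ow:stride, :]
    return cols.reshape(n * oh * ow, kh * kw * c), (n, oh, ow)


def conv2d_im2col(x, weight, stride=1, pad=0):
    """2-D convolution as one GEMM: (N*OH*OW, KH*KW*C) @ (KH*KW*C, F)."""
    kh, kw, c, f = weight.shape
    assert x.shape[-1] == c, "input channels must match the weights"
    patches, (n, oh, ow) = im2col(x, kh, kw, stride, pad)
    out = patches @ weight.reshape(kh * kw * c, f)
    return out.reshape(n, oh, ow, f)
```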
---

TileLang is now used in the [BitBLAS](https://github.com/microsoft/BitBLAS) project.

## Acknowledgements

We learned a lot from the [TVM](https://github.com/apache/tvm) community and would like to thank them for their contributions. The initial version of this project was mainly contributed by [LeiWang1999](https://github.com/LeiWang1999), [chengyupku](https://github.com/chengyupku) and [nox-410](https://github.com/nox-410). Part of this work was done during an internship at Microsoft Research, under the supervision of Dr. Lingxiao Ma, Dr. Yuqing Xia, Dr. Jilong Xue, and Dr. Fan Yang.