Commit cf6e11c9 authored by qisan's avatar qisan
Browse files

feat: merge dcu branch features

parents 3f27f85a d0436b7b
Pipeline #3369 failed with stages
in 0 seconds
# 🚀 Write High Performance FlashMLA with TileLang on Hopper
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/chengyupku">Yu Cheng</a>
<em>Author:</em> <a href="https://github.com/LeiWang1999">Lei Wang</a>
</div>
TileLang is a user-friendly AI programming language that significantly lowers the barrier to kernel programming, helping users quickly build customized operators. However, users still need to master certain programming techniques to better leverage TileLang's powerful capabilities. Here, we'll use MLA as an example to demonstrate how to write high-performance kernels with TileLang.
## Introduction to MLA
DeepSeek's MLA (Multi-Head Latent Attention) is a novel attention mechanism known for its hardware efficiency and significant improvements in model inference speed. Several deep learning compilers (such as [Triton](https://github.com/triton-lang/triton)) and libraries (such as [FlashInfer](https://github.com/flashinfer-ai/flashinfer)) have developed their own implementations of MLA. In February 2025, [FlashMLA](https://github.com/deepseek-ai/FlashMLA) was open-sourced on GitHub. FlashMLA utilizes [CUTLASS](https://github.com/NVIDIA/cutlass) templates and incorporates optimization techniques from [FlashAttention](https://github.com/Dao-AILab/flash-attention), achieving impressive performance.
## Benchmark Results
We benchmarked the performance of FlashMLA, TileLang, Torch, Triton, and FlashInfer under batch sizes of 64 and 128, with float16 data type, as shown in the figures below.
```{figure} ../_static/img/mla_hopper/bs64_float16.png
:width: 50%
:alt: Overview
:align: center
Figure 1: Performance under batch size=64
```
```{figure} ../_static/img/mla_hopper/bs128_float16.png
:width: 50%
:alt: Overview
:align: center
Figure 2: Performance under batch size=128
```
As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton.
Notably, **TileLang accomplishes this with just around 80 lines of Python code**, demonstrating its exceptional ease of use and efficiency. Let's dive in and see how TileLang achieves this.
## Implementation
First, let's review the core computation logic of traditional FlashAttention:
```python
# acc_s: [block_M, block_N]
# scores_max: [block_M]
# scores_scale: [block_M]
# acc_o: [block_M, dim]
for i in range(loop_range):
acc_s = Q @ K[i]
scores_max_prev = scores_max
scores_max = max(acc_s, dim=1)
scores_scale = exp(scores_max_prev - scores_max)
acc_o *= scores_scale
acc_s = exp(acc_s - scores_max)
acc_o = acc_s @ V[i]
...
```
Here, `acc_s` represents the `Q @ K` result in each iteration with dimensions `[block_M, block_N]`, while `acc_o` represents the current iteration's output with dimensions `[block_M, dim]`. Both `acc_s` and `acc_o` need to be stored in registers to reduce latency.
Compared to traditional attention operators like MHA (Multi-Headed Attention) or GQA (Grouped Query Attention), a major challenge in optimizing MLA is its large head dimensions - `query` and `key` have head dimensions of 576 (512 + 64), while `value` has a head dimension of 512. This raises a significant issue: `acc_o` becomes too large, and with insufficient threads (e.g., 128 threads), register spilling occurs, severely impacting performance.
This raises the question of how to partition the matrix multiplication operation. On the Hopper architecture, most computation kernels use [`wgmma.mma_async`](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) instructions for optimal performance. The `wgmma.mma_async` instruction organizes 4 warps (128 threads) into a warpgroup for collective MMA operations. However, `wgmma.mma_async` instructions require a minimum M dimension of 64. This means each warpgroup's minimum M dimension can only be reduced to 64, but a tile size of 64*512 is too large for a single warpgroup, leading to register spilling.
Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right part of `acc_o` respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input.
Our solution is to have each warpgroup compute half of `acc_s` during `Q @ K` computation, then obtain the other half computed by the other warpgroup through shared memory.
### Layout Inference
While the above process may seem complex, but don't worry - TileLang will handle all these intricacies for you.
Figure 3 and Figure 4 illustrate the frontend TileLang script and its corresponding execution plan for MLA. Here, `T.gemm` represents matrix multiplication operations, `transpose_B=True` indicates transposition of matrix B, and `policy=FullCol` specifies that each warpgroup computes one column (e.g., split the result matrix in vertical dimension). `T.copy` represents buffer-to-buffer copying operations.
```{figure} ../_static/img/mla_hopper/qk_layout.jpg
:width: 50%
:alt: Overview
:align: center
Figure 3: Buffer shapes in Q @ K
```
```{figure} ../_static/img/mla_hopper/pv_layout.jpg
:width: 50%
:alt: Overview
:align: center
Figure 4: Buffer shapes in acc_s @ V
```
The mapping from TileLang frontend code to execution plan is accomplished through Layout Inference. Layout inference is a core optimization technique in TileLang. It automatically deduces the required buffer shapes and optimal layouts based on Tile-Operators (like `T.gemm`, `T.copy`, etc.), then generates the corresponding code. Here, we demonstrate a concrete example of buffer shape inference in MLA.
For instance, when computing `Q @ K`, TileLang infers that each warpgroup's `acc_s_0` shape should be `[blockM, blockN / 2]` based on the `policy=FullCol` annotation in `T.gemm`. Since this is followed by an `acc_s @ V` operation with `policy=FullCol`, which requires each warpgroup to have the complete `acc_s` result, TileLang deduces that `acc_s`'s shape at this point should be `[blockM, blockN]`. Consequently, TileLang can continue the inference process forward, determining that both `S_shared` and `acc_s` in `T.copy(S_shared, acc_s)` should have shapes of `[blockM, blockN]`.
It's worth noting that our scheduling approach differs from FlashMLA's implementation strategy. In FlashMLA, `Q @ K` is assigned to a single warpgroup, while the `acc_o` partitioning scheme remains consistent with ours. Nevertheless, our scheduling approach still achieves comparable performance.
### Threadblock Swizzling
Threadblock swizzling is a common performance optimization technique in GPU kernel optimization. In GPU architecture, the L2 cache is a high-speed cache shared among multiple SMs (Streaming Multiprocessors). Threadblock swizzling optimizes data access patterns by remapping the scheduling order of threadblocks, thereby improving L2 cache hit rates. Traditional scheduling typically executes threadblocks in the natural order of the grid, which can lead to non-contiguous data access patterns between adjacent threadblocks, resulting in inefficient utilization of cached data. The swizzle technique employs mathematical mapping methods (such as diagonal or interleaved mapping) to adjust the execution order of threadblocks, ensuring that consecutively scheduled threadblocks access adjacent or overlapping data regions.
In TileLang, threadblock swizzling optimization can be implemented with just a single line of Python code:
```python
T.use_swizzle(panel_size: int, order: str = "row")
```
Here, `panel_size` specifies the width of the swizzled threadblock group, and `order` determines the swizzling pattern, which can be either "row" or "col".
### Shared Memory Swizzling
In CUDA programming, shared memory is divided into multiple memory banks, with each bank capable of servicing one thread request per clock cycle in parallel. Bank conflicts occur when multiple threads simultaneously access different addresses mapped to the same bank, forcing these accesses to be serialized and degrading performance.
One common strategy to address bank conflicts is shared memory swizzling. This technique rearranges how data is stored in shared memory by remapping addresses that would originally fall into the same bank to different banks, thereby reducing conflicts. For example, XOR operations or other bit manipulations can be incorporated into address calculations to alter the data layout, resulting in more evenly distributed memory accesses across consecutive threads. This approach is particularly crucial for implementing high-performance computing tasks like matrix multiplication and convolution, as it can significantly improve memory access parallelism and overall execution efficiency.
Similarly, TileLang also supports shared memory swizzling. Users only need to add a single line of Python code:
```python
T.annotate_layout({
S_shared: TileLang.layout.make_swizzled_layout(S_shared),
})
```
Here, `T.annotate_layout` allows users to specify any desired layout for a buffer. For convenience, TileLang provides the `make_swizzled_layout` primitive to automatically generate a swizzled layout.
### Warp-Specialization
The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (Tensor Memory Accelerator), while the remaining warpgroups serve as consumers performing computations. However, this programming pattern is complex, requiring developers to manually manage the execution logic for producers and consumers, including synchronization through the `mbarrier` objects.
In TileLang, users are completely shielded from these implementation details. The frontend script is automatically transformed into a warp-specialized form, where TileLang handles all producer-consumer synchronization automatically, enabling efficient computation.
### Pipeline
Pipeline is a technique used to improve memory access efficiency by overlapping memory access and computation. In TileLang, pipeline can be implemented through the `T.pipelined` annotation:
```python
T.pipelined(range: int, stage: int)
```
Here, `range` specifies the range of the pipeline, and `stage` specifies the stage of the pipeline. Multi-stage pipelining enables overlapping of computation and memory access, which can significantly improve performance for memory-intensive operators. However, setting a higher number of stages consumes more shared memory resources, so the optimal configuration needs to be determined based on specific use cases.
### Split-KV
We have also implemented Split-KV optimization similar to [FlashDecoding](https://pytorch.org/blog/flash-decoding/). Specifically, when the batch size is small, parallel SM resources cannot be fully utilized due to low parallelism. In such cases, we can split the kv_ctx dimension across multiple SMs for parallel computation and then merge the results.
In our implementation, we have developed both split and combine kernels, allowing users to control the split size through a `num_split` parameter.
## 🚀 On AMD MI300X Accelerators
Following our previous demonstration of [high-performance FlashMLA implementation on NVIDIA Hopper architectures using TileLang](https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_mla/README.md), this work presents an optimized implementation for AMD MI300X accelerators. We examine architectural differences and corresponding optimization strategies between these platforms.
### Architectural Considerations and Optimization Strategies
Key implementation differences between Hopper and MI300X architectures include:
1. **Instruction Set Variations**: The MI300X architecture eliminates the need for explicit Tensor Memory Access (TMA) instructions and warp specialization, which are automatically handled by the compiler on Hopper architectures, resulting in identical source code manifestations.
2. **Shared Memory Constraints**: With 64KB of shared memory compared to Hopper's 228KB, MI300X implementations require careful memory management. Our optimization strategy includes:
- Reducing software pipeline stages
- Register-based caching of Q matrices instead of shared memory utilization:
```python
# Original shared memory allocation
Q_shared = T.alloc_shared([block_H, dim], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
# Optimized register allocation
Q_local = T.alloc_fragment([block_H, dim], dtype)
Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype)
```
3. **Tile Size Flexibility**: The absence of WGMMA instructions on MI300X permits more flexible tile size selection, removing the requirement for block_m to be multiples of 64.
4. **Memory Bank Conflict Swizzling**: MI300x has different memory bank conflict rules compared to NVIDIA, so we need to use different swizzling strategies. This is also automatically handled by TileLang, resulting in no visible differences in the code.
### Performance Evaluation
We conducted comparative performance analysis across multiple frameworks using float16 precision with batch sizes 64 and 128. The experimental results demonstrate:
<figure style="text-align: center">
<a href="../figures/flashmla-amd.png">
<img src="../figures/flashmla-amd.png" alt="AMD FlashMLA Performance Comparison">
</a>
<figcaption style="text-align: center;">Figure 1: Computational throughput comparison across frameworks (Batch sizes 64 and 128)</figcaption>
</figure>
Notably, TileLang achieves performance parity with hand-optimized assembly kernels (aiter-asm) in most test cases, while significantly outperforming both Triton (1.98×) and PyTorch (3.76×) implementations. This performance is achieved through a concise 80-line Python implementation, demonstrating TileLang's efficiency and programmability advantages.
### Future Optimization Opportunities
1. **Memory Bank Conflict Mitigation**: Current implementations primarily address bank conflicts in NT layouts through TileLang's automatic optimization. Further investigation of swizzling techniques for alternative memory layouts remains an open research direction.
2. **Dimension Parallelization**: For large MLA dimensions (e.g., 576 elements), we propose investigating head dimension partitioning strategies to:
- Reduce shared memory pressure
- Improve compute-to-memory access ratios
- Enhance parallelism through dimension-wise task distribution
# ElementWise Operators
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/chenghuaWang">Chenghua Wang</a>
</div>
:::{warning}
:class: myclass1 myclass2
:name: a-tip-reference
This document is still **experimental** and may be incomplete.
Suggestions and improvements are highly encouraged—please submit a PR!
:::
Elementwise operators are widely used in deep learning and often serve as the first example encountered by those beginning to explore parallel programming. This tutorial will analyze several implementations of the elementwise addition operator using TileLang and compare them with the corresponding CUDA implementation. By the end of this tutorial, you will learn:
1. How to implement an elementwise operator using TileLang.
2. How to compile operators with dynamic shapes.
3. How TileLang addresses boundary-related issues.
4. The similarities and differences between operators implemented in TileLang and those implemented in CUDA/CuTe.
Please note that this tutorial does not delve deeply into the design principles of TileLang. For a broader understanding of TileLang, we recommend consulting the [Overview](../get_started/overview.md).
## Elementwise add in TileLang
```python
def elementwise_add(N, threads=256, dtype=T.bfloat16):
@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
with T.Kernel(T.ceildiv(N, threads), threads=threads) as (b_x):
# vector add.
for i in T.Parallel(threads):
C[b_x * threads + i] = A[b_x * threads + i] + B[b_x * threads + i]
return main
```
All logic for TileLang kernels must be implemented within the `T.Kernel(...)` scope. In this example, initializing `T.kernel(...)` requires specifying both the grid size and the number of threads per block. The returned value `bx` corresponds to `blockIdx.x` in CUDA. In the provided implementation, `T.Parallel` is used to process the data tile (of size `1 x threads`) assigned to the block for computation.
Those familiar with CUDA programming might wonder where `threadIdx` fits into this. Note that the code inside `T.Kernel` operates at the **block level**, not the **thread level**. In this example, your focus is solely on defining the block-level logic. During compilation, TileLang automatically maps computations to the corresponding threads and applies further optimizations. The optimized code generated by TileLang may closely align with carefully handcrafted computational logic, as demonstrated in Section 2 with a concrete example. While TileLang also supports thread-level programming semantics, this will be covered in subsequent discussions.
The program can be compiled using the following code:
```python
program = elementwise_add(1024, threads=256, dtype=T.bfloat16)
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
```
Launching the kernel is straightforward, just call it directly like a function:
```python
C = kernel(A, B)
```
The vector add operation can also be extended to two-dimensional cases, where both implementations demonstrate comparable efficiency in practice. Below is an example from the test section that readers can refer to: [example](https://github.com/tile-ai/tilelang/blob/main/testing/python/kernel/test_tilelang_kernel_element_wise_add.py). The code for this kernel is provided below:
```python
import tilelang.language as T
def elementwise_add(
M,
N,
block_M,
block_N,
in_dtype,
out_dtype,
threads,
):
@T.prim_func
def main(
A: T.Tensor((M, N), in_dtype),
B: T.Tensor((M, N), in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
start_x = bx * block_N
start_y = by * block_M
for (local_y, local_x) in T.Parallel(block_M, block_N):
y = start_y + local_y
x = start_x + local_x
C[y, x] = A[y, x] + B[y, x]
return main
```
### How to compile operators with dynamic shapes?
In the compilation process above, a fixed shape was used. However, in practical usage, we often want the kernel to support dynamic shapes. So, how can we compile a kernel in TileLang to handle dynamic shapes? In TileLang, we can replace the target size with a dynamic symbolic value, making the dimension dynamic. The following example illustrates this:
```python
program = elementwise_add(T.dynamic("N"), threads=256, dtype=T.bfloat16)
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
```
The resulting CUDA code for the kernel will include an additional `int N` parameter after the `bfloat16_t* __restrict__ A`, `bfloat16_t* __restrict__ B`, and `bfloat16_t* __restrict__ C` parameters.
### How TileLang addresses boundary-related issues.
TileLang automatically incorporates boundary-checking conditions; however, this comes at a cost. These boundary conditions may prevent TileLang from performing more advanced optimizations. I will introduce an example from the next section in advance. The corresponding code is also provided below, but note that it involves the associated CUDA code. Readers are encouraged to first review the next section before returning to this paragraph for a clearer understanding.
When compiling the example below, let's set `N` to 2047:
```python
def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16):
@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
with T.Kernel(T.ceildiv(N, threads * num_per_thread), threads=threads) as (b_x):
# vector add.
for i, j in T.Parallel(threads, num_per_thread):
offsets = (b_x * threads + i) * num_per_thread
C[offsets + j] = A[offsets + j] + B[offsets + j]
return main
```
TileLang will generate the following CUDA code:
```c++
extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) {
#pragma unroll
for (int i = 0; i < 8; ++i) {
if (((i * 256) + ((int)threadIdx.x)) < 2047) {
C[((i * 256) + ((int)threadIdx.x))] = (A[((i * 256) + ((int)threadIdx.x))] + B[((i * 256) + ((int)threadIdx.x))]);
}
}
}
```
We can observe that TileLang did not apply optimizations such as vectorization or coalesced memory access. In fact, except for the tail group of data, all other threads could have executed more optimized code.
## Comparison of TileLang, CUDA, and CuTe
For the subsequent examples, this tutorial will use the vector add operation for simplicity and brevity.
Typically, those new to CUDA programming often write CUDA code in a style similar to this:
```c++
// vector add
__global__ void elementwise_add(float* a, float* b, float* c, int N) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < N) {
c[idx] = a[idx] + b[idx];
}
}
```
The code above assigns each thread to compute a single element, which is evidently inefficient since common acceleration techniques like coalesced memory access and vectorization are not utilized. However, TileLang code written with similar logic (e.g., loop-based traversal) can be optimized by the compiler into highly efficient implementations, making it more accessible for beginners. Additionally, the final generated code from the compiler remains observable, providing transparency into the optimization process.
The CUDA code generated by TileLang for the compiled kernel can be retrieved using the `kernel.get_kernel_source()` method. Below is the CUDA code produced for the vector addition example from Section 1:
```cu
extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) {
if (((int)threadIdx.x) < 32) {
uint4 __1;
uint4 v_ = *(uint4*)(A + ((((int)blockIdx.x) * 256) + (((int)threadIdx.x) * 8)));
uint4 v__1 = *(uint4*)(B + ((((int)blockIdx.x) * 256) + (((int)threadIdx.x) * 8)));
((nv_bfloat162*)(&(__1.x)))->x = (((nv_bfloat162*)(&(v_.x)))->x+((nv_bfloat162*)(&(v__1.x)))->x);
((nv_bfloat162*)(&(__1.x)))->y = (((nv_bfloat162*)(&(v_.x)))->y+((nv_bfloat162*)(&(v__1.x)))->y);
((nv_bfloat162*)(&(__1.y)))->x = (((nv_bfloat162*)(&(v_.y)))->x+((nv_bfloat162*)(&(v__1.y)))->x);
((nv_bfloat162*)(&(__1.y)))->y = (((nv_bfloat162*)(&(v_.y)))->y+((nv_bfloat162*)(&(v__1.y)))->y);
((nv_bfloat162*)(&(__1.z)))->x = (((nv_bfloat162*)(&(v_.z)))->x+((nv_bfloat162*)(&(v__1.z)))->x);
((nv_bfloat162*)(&(__1.z)))->y = (((nv_bfloat162*)(&(v_.z)))->y+((nv_bfloat162*)(&(v__1.z)))->y);
((nv_bfloat162*)(&(__1.w)))->x = (((nv_bfloat162*)(&(v_.w)))->x+((nv_bfloat162*)(&(v__1.w)))->x);
((nv_bfloat162*)(&(__1.w)))->y = (((nv_bfloat162*)(&(v_.w)))->y+((nv_bfloat162*)(&(v__1.w)))->y);
*(uint4*)(C + ((((int)blockIdx.x) * 256) + (((int)threadIdx.x) * 8))) = __1;
}
}
```
In the code above, TileLang not only automatically maps block-level parallelism to threads but also applies optimizations such as vectorization and coalesced memory access.
While TileLang incorporates various optimizations for the aforementioned case, its behavior may sometimes appear counterintuitive. For example, when targeting 256 threads for task processing, applying vectorization can result in each thread computing 8 data elements—effectively utilizing only 32 active threads. Interestingly, the kernel launch configuration still retains the original allocation of 256 threads.
In such scenarios, explicitly specifying the number of elements computed per thread can help "guide" TileLang's code generation process, leading to implementations that are more closely aligned with the intended design.
```python
def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16):
@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
with T.Kernel(T.ceildiv(N, threads * num_per_thread), threads=threads) as (b_x):
# vector add.
for i, j in T.Parallel(threads, num_per_thread):
offsets = (b_x * threads + i) * num_per_thread
C[offsets + j] = A[offsets + j] + B[offsets + j]
return main
```
The corresponding CUDA code generated for the above example is presented below:
```c++
extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) {
uint4 __1;
uint4 v_ = *(uint4*)(A + (((int)threadIdx.x) * 8));
uint4 v__1 = *(uint4*)(B + (((int)threadIdx.x) * 8));
((nv_bfloat162*)(&(__1.x)))->x = (((nv_bfloat162*)(&(v_.x)))->x+((nv_bfloat162*)(&(v__1.x)))->x);
((nv_bfloat162*)(&(__1.x)))->y = (((nv_bfloat162*)(&(v_.x)))->y+((nv_bfloat162*)(&(v__1.x)))->y);
((nv_bfloat162*)(&(__1.y)))->x = (((nv_bfloat162*)(&(v_.y)))->x+((nv_bfloat162*)(&(v__1.y)))->x);
((nv_bfloat162*)(&(__1.y)))->y = (((nv_bfloat162*)(&(v_.y)))->y+((nv_bfloat162*)(&(v__1.y)))->y);
((nv_bfloat162*)(&(__1.z)))->x = (((nv_bfloat162*)(&(v_.z)))->x+((nv_bfloat162*)(&(v__1.z)))->x);
((nv_bfloat162*)(&(__1.z)))->y = (((nv_bfloat162*)(&(v_.z)))->y+((nv_bfloat162*)(&(v__1.z)))->y);
((nv_bfloat162*)(&(__1.w)))->x = (((nv_bfloat162*)(&(v_.w)))->x+((nv_bfloat162*)(&(v__1.w)))->x);
((nv_bfloat162*)(&(__1.w)))->y = (((nv_bfloat162*)(&(v_.w)))->y+((nv_bfloat162*)(&(v__1.w)))->y);
*(uint4*)(C + (((int)threadIdx.x) * 8)) = __1;
}
```
Aha, this CUDA code aligns closely with conventional programming practices, making it more familiar and intuitive.
But what happens if we provide additional hints to TileLang? For instance, by explicitly specifying register copies using the `T.copy(...)` operation. The example below demonstrates a vector addition implementation. Unlike the previous examples, this code explicitly loads data into registers before performing computations.
```python
def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype=T.bfloat16):
@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
with T.Kernel(T.ceildiv(N, threads * NUM_ELE_PER_THREAD), threads=threads) as (b_x):
A_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype)
B_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype)
C_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype)
s_start = b_x * threads * NUM_ELE_PER_THREAD
s_end = (b_x + 1) * threads * NUM_ELE_PER_THREAD
# LDG. 128
T.copy(
A[s_start:s_end],
A_register,
)
T.copy(
B[s_start:s_end],
B_register,
)
# vector add.
for tid, i in T.Parallel(threads, NUM_ELE_PER_THREAD):
C_register[tid * NUM_ELE_PER_THREAD + i] = (
A_register[tid * NUM_ELE_PER_THREAD + i] +
B_register[tid * NUM_ELE_PER_THREAD + i])
# STG. 128
T.copy(
C_register,
C[s_start:s_end],
)
return main
```
In the example above, each thread is responsible for computing 8 elements. The `T.copy(...)` method functions at the block level, and TileLang automatically maps data movement operations to individual threads. This design may resonate more intuitively with CUDA developers. Let us now analyze the CUDA code generated from this implementation.
```c++
// N is set to 8192 * 8192 when compiling
extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) {
bfloat16_t A_register[8];
bfloat16_t B_register[8];
*(uint4*)(A_register + 0) = *(uint4*)(A + ((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 8)));
*(uint4*)(B_register + 0) = *(uint4*)(B + ((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 8)));
uint4 __1;
uint4 v_ = *(uint4*)(A_register + 0);
uint4 v__1 = *(uint4*)(B_register + 0);
((nv_bfloat162*)(&(__1.x)))->x = (((nv_bfloat162*)(&(v_.x)))->x+((nv_bfloat162*)(&(v__1.x)))->x);
((nv_bfloat162*)(&(__1.x)))->y = (((nv_bfloat162*)(&(v_.x)))->y+((nv_bfloat162*)(&(v__1.x)))->y);
((nv_bfloat162*)(&(__1.y)))->x = (((nv_bfloat162*)(&(v_.y)))->x+((nv_bfloat162*)(&(v__1.y)))->x);
((nv_bfloat162*)(&(__1.y)))->y = (((nv_bfloat162*)(&(v_.y)))->y+((nv_bfloat162*)(&(v__1.y)))->y);
((nv_bfloat162*)(&(__1.z)))->x = (((nv_bfloat162*)(&(v_.z)))->x+((nv_bfloat162*)(&(v__1.z)))->x);
((nv_bfloat162*)(&(__1.z)))->y = (((nv_bfloat162*)(&(v_.z)))->y+((nv_bfloat162*)(&(v__1.z)))->y);
((nv_bfloat162*)(&(__1.w)))->x = (((nv_bfloat162*)(&(v_.w)))->x+((nv_bfloat162*)(&(v__1.w)))->x);
((nv_bfloat162*)(&(__1.w)))->y = (((nv_bfloat162*)(&(v_.w)))->y+((nv_bfloat162*)(&(v__1.w)))->y);
*(uint4*)(A_register + 0) = __1;
*(uint4*)(C + ((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 8))) = *(uint4*)(A_register + 0);
}
```
We observed the emergence of two additional registers, `A_register` and `B_register`. However, during the actual computation, these registers are simply reassigned to `v_` and `v__1`, respectively.
To evaluate complexity, one could implement the same elementwise addition operator using CuTe and compare it with the TileLang version. The corresponding CuTe code is provided below:
```c++
template<int NUM_ELE_PER_THREAD=8>
__global__ void elementwise_add(nv_bfloat16* C,
const nv_bfloat16* A,
const nv_bfloat16* B,
int N) {
using namespace cute;
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
Tensor t_C = make_tensor(make_gmem_ptr(C), make_shape(N));
Tensor t_A = make_tensor(make_gmem_ptr(A), make_shape(N));
Tensor t_B = make_tensor(make_gmem_ptr(B), make_shape(N));
Tensor t_C_tile = local_tile(t_C, make_shape(Int<NUM_ELE_PER_THREAD>{}), make_coord(idx));
Tensor t_A_tile = local_tile(t_A, make_shape(Int<NUM_ELE_PER_THREAD>{}), make_coord(idx));
Tensor t_B_tile = local_tile(t_B, make_shape(Int<NUM_ELE_PER_THREAD>{}), make_coord(idx));
Tensor reg_buffer_A = make_tensor_like(t_A_tile);
Tensor reg_buffer_B = make_tensor_like(t_B_tile);
Tensor reg_buffer_C = make_tensor_like(t_C_tile);
// LDG. 128
copy(t_A_tile, reg_buffer_A);
copy(t_B_tile, reg_buffer_B);
auto reg_C_vector = recast<nv_bfloat162>(reg_buffer_C);
auto reg_A_vector = recast<nv_bfloat162>(reg_buffer_A);
auto reg_B_vector = recast<nv_bfloat162>(reg_buffer_B);
// Perform vectorized addition
#pragma unroll
for (int vec_idx = 0; vec_idx < size(reg_C_vector); ++vec_idx) {
reg_C_vector(vec_idx) = reg_A_vector(vec_idx) + reg_B_vector(vec_idx);
}
auto reg_C_flat = recast<nv_bfloat16>(reg_C_vector);
// STG. 128
copy(reg_C_flat, t_C_tile);
}
```
## Conclusion
This tutorial showcases the implementation of the elementwise addition operator using TileLang, while also comparing various design approaches. TileLang significantly reduces the complexity of CUDA programming, enabling high performance with minimal code. Nevertheless, working with TileLang demands careful attention to specific implementation details. To ensure computational efficiency, it is essential to thoroughly examine the generated CUDA kernels.
---
**Reference:**
[1] The CuTe code implementation draws inspiration from the techniques discussed in this blog: https://zhuanlan.zhihu.com/p/690703999
# General Matrix-Vector Multiplication (GEMV)
===========================================
<div style="text-align: left;">
<em>Contributor: </em> <a href="https://github.com/botbw">@botbw</a>
</div>
:::{warning}
This document is still **experimental** and may be incomplete.
Suggestions and improvements are highly encouraged—please submit a PR!
:::
:::{tip}
Example code can be found at `examples/gemv/example_gemv.py`.
:::
General matrix-vector multiplication (GEMV) can be viewed as a specialized case of general matrix-matrix multiplication (GEMM). It plays a critical role in deep learning, especially during the inference phase of large language models. In this tutorial, we will optimize GEMV from a thread-level perspective step by step using `TileLang`.
## Triton Implementation
When implementing a GEMV kernel, you might start with a high-level approach using a tool like `Triton`.
A simple Triton kernel for GEMV might look like this:
```python
@triton.jit
def _gemv_naive(
x_ptr, A_ptr, y_ptr,
N, K,
BLOCK_SIZE_K: tl.constexpr,
):
n = tl.program_id(0)
offs_k = tl.arange(0, BLOCK_SIZE_K)
mask = offs_k < K
a_ptrs = A_ptr + n * K + offs_k
a_vals = tl.load(a_ptrs, mask=mask, other=0.0)
x_vals = tl.load(x_ptr + offs_k, mask=mask, other=0.0)
dot = tl.sum(a_vals * x_vals, axis=0)
tl.store(y_ptr + n, dot)
```
`Triton` is straightforward to use, as it operates at the block level. However, this approach may not allow for fine-grained thread-level optimization. In this tutorial, we will demonstrate how to write an optimized GEMV kernel in `TileLang` that exposes more low-level control.
## Naive Implementation in TileLang
If you have a basic understanding of CUDA C, it is natural to start with a naive GEMV kernel by adapting a GEMM tiling strategy. You can think of GEMV as a `(1, k) * (k, n)` GEMM. Below is a simple example:
```python
def naive_gemv(
N: int,
K: int,
BLOCK_N: int,
BLOCK_K: int,
dtype: str = "float16",
accum_dtype: str = "float",
):
@T.prim_func
def main(
A: T.Buffer((K,), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((N,), dtype),
):
with T.Kernel(T.ceildiv(N, BLOCK_N)) as bn:
tn = T.get_thread_binding(0) # tn = threadIdx.x
A_shared = T.alloc_shared((BLOCK_K,), dtype)
B_shared = T.alloc_shared((BLOCK_N, BLOCK_K), dtype)
C_reg = T.alloc_local((1,), accum_dtype)
T.clear(C_reg)
for bk in T.serial(T.ceildiv(K, BLOCK_K)):
for tk in T.serial(BLOCK_K):
A_shared[tk] = A[bk * BLOCK_K + tk]
B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk]
for tk in T.serial(BLOCK_K):
C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn,
tk].astype(accum_dtype)
C[bn * BLOCK_N + tn] = C_reg[0]
return main
```
And your kernel will be compiled into CUDA by `TileLang` (in `~/.tilelang/cache`):
```C++
extern "C" __global__ void __launch_bounds__(256, 1) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) {
extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
float C_reg[1];
__shared__ uint64_t _mbarrier[2];
if (((int)threadIdx.x) == 0) {
tl::mbarrier_init(_mbarrier[0], 128);
tl::mbarrier_init(_mbarrier[1], 128);
}
__syncthreads();
if (128 <= ((int)threadIdx.x)) {
tl::warpgroup_reg_dealloc<24>();
for (int bk = 0; bk < 8; ++bk) {
tl::mbarrier_wait(_mbarrier[1], ((bk & 1) ^ 1));
for (int tk = 0; tk < 128; ++tk) {
((half_t*)buf_dyn_shmem)[tk] = A[((bk * 128) + tk)];
((half_t*)buf_dyn_shmem)[(((((int)threadIdx.x) * 128) + tk) - 16256)] = B[(((((((int)blockIdx.x) * 131072) + (((int)threadIdx.x) * 1024)) + (bk * 128)) + tk) - 131072)];
}
tl::fence_proxy_async();
tl::mbarrier_cp_async_arrive(_mbarrier[0]);
tl::mbarrier_arrive(_mbarrier[0]);
}
} else {
tl::warpgroup_reg_alloc<240>();
C_reg[0] = 0.000000e+00f;
for (int bk_1 = 0; bk_1 < 8; ++bk_1) {
tl::mbarrier_wait(_mbarrier[0], (bk_1 & 1));
for (int tk_1 = 0; tk_1 < 128; ++tk_1) {
C_reg[0] = (C_reg[0] + (((float)((half_t*)buf_dyn_shmem)[tk_1]) * ((float)((half_t*)buf_dyn_shmem)[(((((int)threadIdx.x) * 128) + tk_1) + 128)])));
}
tl::fence_proxy_async();
tl::mbarrier_arrive(_mbarrier[1]);
}
C[((((int)blockIdx.x) * 128) + ((int)threadIdx.x))] = ((half_t)C_reg[0]);
}
}
```
In this design, the first 128 threads act as the data producer and the last 128 threads as the consumer within a block (assuming a 1D block).
At this level, we only gain very little computation power from our GPU with around **~0.17 ms** compared to torch/cuBLAS's **~0.008 ms**, which is around 20x slower.
## More Concurrency
To further increase the concurrency of our kernel, we can exploit finer thread-level parallelism. Instead of assigning each thread to compute a single output element in C, you can introduce parallelism along the K dimension. Each thread computes a partial accumulation, and you then combine these partial results. This approach requires primitives like `atomicAdd` in CUDA.
Here’s a simplified version:
```python
def naive_splitk_gemv(
N: int,
K: int,
BLOCK_N: int,
BLOCK_K: int,
dtype: str = "float16",
accum_dtype: str = "float",
):
@T.prim_func
def main(
A: T.Buffer((K,), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((N,), dtype),
):
with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn:
tn = T.get_thread_binding(0)
tk = T.get_thread_binding(1)
A_local = T.alloc_local((1,), dtype)
B_local = T.alloc_local((1,), dtype)
C_accum = T.alloc_local((1,), accum_dtype)
C_shared = T.alloc_shared((BLOCK_N,), accum_dtype)
if tk == 0:
C_shared[tn] = 0
T.clear(C_accum)
for bk in T.serial(T.ceildiv(K, BLOCK_K)):
A_local[0] = A[bk * BLOCK_K + tk]
B_local[0] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk]
C_accum[0] += A_local[0].astype(accum_dtype) * B_local[0].astype(accum_dtype)
T.atomic_add(C_shared[tn], C_accum[0])
C[bn * BLOCK_N + tn] = C_shared[tn]
return main
```
By introducing parallelism along K dimension, our kernel now achieves **~0.024 ms**, an improvement, but still not on par with torch/cuBLAS.
### Customizing Parallelism in K Dimension
If your K dimension is large, you can further customize how many elements each thread processes by introducing a `reduce_threads` parameter. This way, each thread handles multiple elements per iteration:
```python
def splitk_gemv(
N: int,
K: int,
BLOCK_N: int,
BLOCK_K: int,
reduce_threads: int,
dtype: str = "float16",
accum_dtype: str = "float",
):
TILE_K = T.ceildiv(BLOCK_K, reduce_threads)
@T.prim_func
def main(
A: T.Buffer((K,), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((N,), dtype),
):
with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
tn = T.get_thread_binding(0)
tk = T.get_thread_binding(1)
A_local = T.alloc_local((TILE_K,), dtype)
B_local = T.alloc_local((TILE_K,), dtype)
C_shared = T.alloc_shared((BLOCK_N,), accum_dtype)
C_accum = T.alloc_local((1,), accum_dtype)
if tk == 0:
C_shared[tn] = 0
T.clear(C_accum)
for bk in T.serial(T.ceildiv(K, BLOCK_K)):
for k in T.serial(TILE_K):
A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
for k in T.serial(TILE_K):
C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
T.atomic_add(C_shared[tn], C_accum[0])
C[bn * BLOCK_N + tn] = C_shared[tn]
return main
```
## Vectorized Reads
GEMV is less computation intensive than GEMM as the computation intensity and memory throughput will be the optimization bottleneck. One effective strategy is to use vectorized load/store operations (e.g., `float2`, `float4`). In `TileLang`, you can specify vectorized operations via `T.vectorized`:
```python
def splitk_gemv_vectorized(
N: int,
K: int,
BLOCK_N: int,
reduce_threads: int,
dtype: str = "float16",
accum_dtype: str = "float",
):
MAX_TRANSACTION_SIZE_IN_BITS = 128
TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
BLOCK_K = reduce_threads * TILE_K
@T.prim_func
def main(
A: T.Buffer((K,), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((N,), dtype),
):
with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
tn = T.get_thread_binding(0)
tk = T.get_thread_binding(1)
A_local = T.alloc_local((TILE_K,), dtype)
B_local = T.alloc_local((TILE_K,), dtype)
C_shared = T.alloc_shared((BLOCK_N,), accum_dtype)
C_accum = T.alloc_local((1,), accum_dtype)
if tk == 0:
C_shared[tn] = 0
T.clear(C_accum)
for bk in T.serial(T.ceildiv(K, BLOCK_K)):
for k in T.vectorized(TILE_K):
A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
for k in T.serial(TILE_K):
C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
T.atomic_add(C_shared[tn], C_accum[0])
C[bn * BLOCK_N + tn] = C_shared[tn]
return main
```
With vectorized read, now the kernel finishes in **~0.0084 ms**, which is getting close to cuBLAS performance.
## `tvm_thread_allreduce` Instead of `atomicAdd`
[`tvm_thread_allreduce`](https://tvm.apache.org/docs/reference/api/python/tir/tir.html#tvm.tir.tvm_thread_allreduce) has implemented optimization when making an all-reduce across a number of threads, which should outperfrom out plain smem + `atomidAdd`:
```python
def splitk_gemv_vectorized_tvm(
N: int,
K: int,
BLOCK_N: int,
reduce_threads: int,
dtype: str = "float16",
accum_dtype: str = "float",
):
MAX_TRANSACTION_SIZE_IN_BITS = 128
TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
BLOCK_K = reduce_threads * TILE_K
@T.prim_func
def main(
A: T.Buffer((K,), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((N,), dtype),
):
with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
tn = T.get_thread_binding(0)
tk = T.get_thread_binding(1)
A_local = T.alloc_local((TILE_K,), dtype)
B_local = T.alloc_local((TILE_K,), dtype)
C_accum = T.alloc_local((1,), accum_dtype)
T.clear(C_accum)
for bk in T.serial(T.ceildiv(K, BLOCK_K)):
for k in T.vectorized(TILE_K):
A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
for k in T.serial(TILE_K):
C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
C_reduced = T.alloc_local((1,), accum_dtype)
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
C_accum[0],
True,
C_reduced[0],
tk,
dtype="handle",
))
C[bn * BLOCK_N + tn] = C_reduced[0]
return main
```
With this optimization, the kernel latency now reduces from **~0.0084 ms** to **~0.0069 ms**, which is faster than torch/cuBLAS!
## Autotune
`BLOCK_N`, `BLOCK_K`, `reduce_threads` are hyperparameters in our kernel, which can be tuned to improve performance. We can use the `tilelang.autotune` feature to automatically search for optimal configurations:
```python
def get_best_config(N, K):
def get_configs():
BLOCK_N = [2, 4, 8, 32, 64, 128]
reduce_threads = [4, 8, 32]
_configs = list(itertools.product(
BLOCK_N,
reduce_threads,
))
configs = [{
"BLOCK_N": c[0],
"reduce_threads": c[1],
} for c in _configs]
return configs
@autotune(
configs=get_configs(),
warmup=3,
rep=20,
)
@jit(
out_idx=[-1],
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=False,
target="auto",
)
def kernel(
BLOCK_N=None,
reduce_threads=None,
):
dtype = "float16"
accum_dtype = "float"
MAX_TRANSACTION_SIZE_IN_BITS = 128
TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
BLOCK_K = reduce_threads * TILE_K
@T.prim_func
def main(
A: T.Buffer((K,), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((N,), dtype),
):
with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
tn = T.get_thread_binding(0)
tk = T.get_thread_binding(1)
A_local = T.alloc_local((TILE_K,), dtype)
B_local = T.alloc_local((TILE_K,), dtype)
C_accum = T.alloc_local((1,), accum_dtype)
T.clear(C_accum)
for bk in T.serial(T.ceildiv(K, BLOCK_K)):
for k in T.vectorized(TILE_K):
A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
for k in T.serial(TILE_K):
C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
C_reduced = T.alloc_local((1,), accum_dtype)
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
C_accum[0],
True,
C_reduced[0],
tk,
dtype="handle",
))
C[bn * BLOCK_N + tn] = C_reduced[0]
return main
return kernel()
```
After autotuning, now our kernel gets **~0.0067 ms**, the final generated CUDA kernel might like this:
```C++
extern "C" __global__ void __launch_bounds__(64, 1) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) {
float C_accum[1];
half_t A_local[8];
half_t B_local[8];
__shared__ float red_buf0[64];
C_accum[0] = 0.000000e+00f;
for (int bk = 0; bk < 4; ++bk) {
*(uint4*)(A_local + 0) = *(uint4*)(A + ((bk * 256) + (((int)threadIdx.y) * 8)));
*(uint4*)(B_local + 0) = *(uint4*)(B + ((((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 1024)) + (bk * 256)) + (((int)threadIdx.y) * 8)));
for (int k = 0; k < 8; ++k) {
C_accum[0] = (C_accum[0] + (((float)A_local[k]) * ((float)B_local[k])));
}
}
tl::fence_proxy_async();
__syncthreads();
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = C_accum[0];
__syncthreads();
if (((int)threadIdx.y) < 16) {
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 16)]);
}
__syncthreads();
if (((int)threadIdx.y) < 8) {
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 8)]);
}
__syncthreads();
if (((int)threadIdx.y) < 4) {
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 4)]);
}
__syncthreads();
if (((int)threadIdx.y) < 2) {
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 2)]);
}
__syncthreads();
if (((int)threadIdx.y) < 1) {
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 1)]);
}
__syncthreads();
C[((((int)blockIdx.x) * 2) + ((int)threadIdx.x))] = ((half_t)red_buf0[(((int)threadIdx.x) * 32)]);
}
```
This corresponds closely to our `TileLang` program, with necessary synchronization and low-level optimizations inserted automatically.
## Conclusion
### Benchmark Table on Hopper GPU
| Kernel Name | Latency |
|------------|------------|
| torch/cuBLAS | 0.00784 ms |
| Triton | 0.00773 ms |
| naive_gemv | 0.16607 ms |
| splitk_gemv | 0.02419 ms |
| splitk_gemv_vectorized | 0.00809 ms |
| splitk_gemv_vectorized_tvm | 0.00675 ms |
Triton Time: 0.0077344514429569244
In this tutorial, we implemented a simple GEMV kernel and learn that `TileLang` exposes low level control to user such as thread-level programming and CUDA primitives.
\ No newline at end of file
# General Matrix-Matrix Multiplication with Tile Library
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/LeiWang1999">Lei Wang</a>
</div>
:::{warning}
:class: myclass1 myclass2
:name: a-tip-reference
This document is still **experimental** and may be incomplete.
Suggestions and improvements are highly encouraged—please submit a PR!
:::
TileLang is a domain-specific language (DSL) designed for writing high-performance GPU kernels. It provides three main levels of abstraction:
* **Level 1:** A user writes pure compute logic without knowledge of or concern for hardware details (e.g., GPU caches, tiling, etc.). The compiler or runtime performs automatic scheduling and optimization. This level is conceptually similar to the idea behind TVM.
* **Level 2:** A user is aware of GPU architecture concepts—such as shared memory, tiling, and thread blocks—but does not necessarily want to drop down to the lowest level of explicit thread control. This mode is somewhat comparable to Triton's programming model, where you can write tile-level operations and let the compiler do layout inference, pipelining, etc.
* **Level 3:** A user takes full control of thread-level primitives and can write code that is almost as explicit as a hand-written CUDA/HIP kernel. This is useful for performance experts who need to manage every detail, such as PTX inline assembly, explicit thread behavior, etc.
```{figure} ../_static/img/overview.png
:width: 50%
:alt: Overview
:align: center
Figure 1: High-level overview of the TileLang compilation flow.
```
In this tutorial, we introduce Level 2 with a matrix multiplication example in TileLang. We will walk through how to allocate shared memory, set up thread blocks, perform parallel copying, pipeline the computation, and invoke the tile-level GEMM intrinsic. We will then show how to compile and run the kernel in Python, comparing results and measuring performance.
## Why Another GPU DSL?
TileLang emerged from the need for a DSL that:
1. Balances high-level expressiveness (like TVM or Triton) with enough flexibility to control finer details when needed.
2. Supports efficient code generation and scheduling for diverse hardware backends (NVIDIA GPUs, AMD GPUs, CPU, etc.).
3. Simplifies scheduling and memory pipelines with built-in primitives (such as `T.Pipelined`, `T.Parallel`, `T.gemm`), yet retains options for expert-level tuning.
While Level 1 in TileLang can be very comfortable for general users—since it requires no scheduling or hardware-specific knowledge—it can incur longer auto-tuning times and may not handle some complex kernel fusion patterns (e.g., Flash Attention) as easily. Level 3 gives you full control but demands more effort, similar to writing raw CUDA/HIP kernels. Level 2 thus strikes a balance for users who want to write portable and reasonably concise code while expressing important architectural hints.
## Matrix Multiplication Example
```{figure} ../_static/img/MatmulExample.png
:alt: Matmul Example
:align: center
```
### Basic Structure
Below is a simplified code snippet for a 1024 x 1024 x 1024 matrix multiplication. It uses:
* **`T.Kernel(...)`** to initialize the thread block configuration (grid dimensions, block size, etc.).
* **`T.alloc_shared(...)`** to allocate GPU shared memory.
* **`T.alloc_fragment(...)`** to allocate a register fragment for accumulation.
* **`T.Pipelined(...)`** to express software pipelining across the K dimension.
* **`T.Parallel(...)`** to parallelize data copy loops.
* **`T.gemm(...)`** to perform tile-level GEMM operations (which map to the appropriate backends, such as MMA instructions on NVIDIA GPUs).
```python
import tilelang
import tilelang.language as T
from tilelang.intrinsics import make_mma_swizzle_layout
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Optional layout hints (commented out by default)
# T.annotate_layout({
# A_shared: make_mma_swizzle_layout(A_shared),
# B_shared: make_mma_swizzle_layout(B_shared),
# })
# Optional: Enabling swizzle-based rasterization
# T.use_swizzle(panel_size=10, enable=True)
# Clear local accumulation
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy tile of A from global to shared memory
T.copy(A[by * block_M, ko * block_K], A_shared)
# Parallel copy tile of B from global to shared memory
for k, j in T.Parallel(block_K, block_N):
B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
# Perform a tile-level GEMM
T.gemm(A_shared, B_shared, C_local)
# Copy result from local (register fragment) to global memory
T.copy(C_local, C[by * block_M, bx * block_N])
return main
# 1. Create the TileLang function
func = matmul(1024, 1024, 1024, 128, 128, 32)
# 2. JIT-compile the kernel for NVIDIA GPU
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
import torch
# 3. Prepare input tensors in PyTorch
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
# 4. Invoke the JIT-compiled kernel
c = jit_kernel(a, b)
ref_c = a @ b
# 5. Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 6. Inspect generated CUDA code (optional)
cuda_source = jit_kernel.get_kernel_source()
print("Generated CUDA kernel:\n", cuda_source)
# 7. Profile performance
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
```
### Key Concepts
1. **Kernel Context**:
```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
...
```
- This sets up the block grid dimensions based on N/block_N and M/block_M.
- `threads=128` specifies that each thread block uses 128 threads. The compiler will infer how loops map to these threads.
```{figure} ../_static/img/Parallel.png
:alt: Parallel
:align: center
```
2. **Shared & Fragment Memory**:
```python
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
```
- `T.alloc_shared` allocates shared memory across the entire thread block.
- `T.alloc_fragment` allocates register space for local accumulation. Though it is written as `(block_M, block_N)`, the compiler’s layout inference assigns slices of this space to each thread.
3. **Software Pipelining**:
```python
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
...
```
- `T.Pipelined` automatically arranges asynchronous copy and compute instructions to overlap memory operations with arithmetic.
- The argument `num_stages=3` indicates the pipeline depth.
```{figure} ../_static/img/software_pipeline_inference.png
:alt: Software Pipeline Inference
:align: center
```
4. **Parallel Copy**:
```python
for k, j in T.Parallel(block_K, block_N):
B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
```
- `T.Parallel` marks the loop for thread-level parallelization.
- The compiler will map these loops to the available threads in the block.
5. **Tile-Level GEMM**:
```python
T.gemm(A_shared, B_shared, C_local)
```
- A single call that performs a tile-level matrix multiplication using the specified buffers.
- Under the hood, for NVIDIA targets, it can use CUTLASS/Cute or WMMA instructions. On AMD GPUs, TileLang uses a separate HIP or composable kernel approach.
6. **Copying Back Results**:
```python
T.copy(C_local, C[by * block_M, bx * block_N])
```
- After computation, data in the local register fragment is written back to global memory.
## Comparison with Other DSLs
TileLang Level 2 is conceptually similar to Triton in that the user can control tiling and parallelization, while letting the compiler handle many low-level details. However, TileLang also:
- Allows explicit memory layout annotations (e.g. `make_mma_swizzle_layout`).
- Supports a flexible pipeline pass (`T.Pipelined`) that can be automatically inferred or manually defined.
- Enables mixing different levels in a single program—for example, you can write some parts of your kernel in Level 3 (thread primitives) for fine-grained PTX/inline-assembly and keep the rest in Level 2.
## Performance on Different Platforms
```{figure} ../_static/img/op_benchmark_consistent_gemm_fp16.png
:alt: Performance on Different Platforms
:align: center
```
When appropriately tuned (e.g., by using an auto-tuner), TileLang achieves performance comparable to or better than vendor libraries and Triton on various GPUs. In internal benchmarks, for an FP16 matrix multiply (e.g., 4090, A100, H100, MI300X), TileLang has shown:
- ~1.1x speedup over cuBLAS on RTX 4090
- ~0.97x on A100 (on par with cuBLAS)
- ~1.0x on H100
- ~1.04x on MI300X
- Compared to Triton, speedups range from 1.08x to 1.25x depending on the hardware.
These measurements will vary based on tile sizes, pipeline stages, and the hardware’s capabilities.
## Conclusion
This tutorial demonstrated a Level 2 TileLang kernel for matrix multiplication. With just a few lines of code:
1. We allocated shared memory and register fragments.
2. We pipelined the loading and computation along the K dimension.
3. We used parallel copying to efficiently load tiles from global memory.
4. We invoked `T.gemm` to dispatch a tile-level matrix multiply.
5. We verified correctness against PyTorch and examined performance.
By balancing high-level abstractions (like `T.copy`, `T.Pipelined`, `T.gemm`) with the ability to annotate layouts or drop to thread primitives (Level 3) when needed, TileLang can be both user-friendly and highly tunable. We encourage you to experiment with tile sizes, pipeline depths, or explicit scheduling to see how performance scales across different GPUs.
For more advanced usage—including partial lowering, explicitly controlling thread primitives, or using inline assembly—you can explore Level 3. Meanwhile, for purely functional expressions and high-level scheduling auto-tuning, consider Level 1.
## Further Resources
* [TileLang GitHub](https://github.com/tile-ai/tilelang)
* [BitBLAS](https://github.com/tile-ai/bitblas)
* [Triton](https://github.com/openai/triton)
* [Cutlass](https://github.com/NVIDIA/cutlass)
* [PyCUDA](https://documen.tician.de/pycuda/) <!-- codespell:ignore -->
# Sparse Matrix-Matrix Multiplication with Tile Library
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/botbw">botbw</a>
</div>
:::{warning}
This document is still **experimental** and may be incomplete.
This feature is still **experimental** and need further optimization.
Suggestions and improvements are highly encouraged—please submit a PR!
:::
:::{tip}
It's suggested to go through `docs/deeplearning_operators/matmul.md` first.
Example code can be found at `examples/gemm_sp`.
:::
## Structured sparsity in the NVIDIA Ampere architecture
Since the Ampere architecture (sm80 and above), sparsity support has been integrated into Tensor Cores. This allows a 2:4 (or 1:2 for 32-bit data types) semi-structured matrix to be compressed into its non-zero values along with associated metadata, which can then be fed into the Tensor Core. This enables up to **2x throughput** compared to the equivalent dense computation.
:::{warning}
This tutorial primarily focuses on CUDA, as this feature is not yet supported on ROCm. However, AMD provides a similar capability in the matrix cores of GPUs such as the MI300X.
:::
```{figure} ../_static/img/sparse_mma_storage_example.png
:align: center
Figure: Sparse MMA storage example (from PTX doc)
```
## Compress a dense tensor
To utilize sparse Tensor Cores, a dense tensor must first be **compressed** into its non-zero values along with the corresponding metadata.
Both `PyTorch` and `vLLM` use `CUTLASS` as their computation backend (see references [here](https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu#L47) and [here](https://github.com/vllm-project/vllm/blob/a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh#L116)), leveraging `CUTLASS`’s built-in compressor (or reimplementing it in `PyTorch`).
A set of **CUTLASS-compatible** compressors is provided in `tilelang.utils.sparse`, where a dense tensor—along with other required arguments (e.g., block_K for sm90, transpose options)—can be passed in to perform the compression.
```python
from tilelang.utils.sparse import compress
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K)
```
Here, `A_sparse` contains all the non-zero elements of `A`, while `E` stores the corresponding metadata (indexing information) required to reconstruct the original sparse pattern.
> NOTE: When using CUTLASS compressor, there is no naive position correspondence between the positions in `A_sparse`/`A` and `E`. (i.e. the 4-element group at [n, k] doesn't match the 4-bit metadata at [n, k] if you consider metadata as int4 tensor)
The metadata is reordered internally to optimize memory access patterns (e.g., for ldsm instructions and vectorized loads).
For more information, see **A note on `gemm_sp` and `gemm_sp_v2`**.
## `T.gemm_sp` with CUTLASS's compressor
:::{warning}
It is strongly recommended to use T.gemm_sp_v2 due to its greater flexibility and faster compilation time.
:::
A 2:4 sparse GEMM kernel is similar to its dense counterpart, except that it also requires handling the associated metadata.
Check comments in below kernel code for required modification.
```python
def matmul_sp_sm80(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
trans_A,
trans_B,
):
is_8_bit = "8" in in_dtype
metadata_dtype = 'int32' if is_8_bit else 'int16'
E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] # Calculate shape for given datatypes
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (K, N) if not trans_B else (N, K)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)
import tilelang.language as T
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) # Allocate smem for metadata
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({ # Annotate reordered cutlass metadata layout
E:
make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared:
make_cutlass_metadata_layout(
E_shared, mma_dtype=in_dtype, arch="8.0"),
})
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
if trans_A:
T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
else:
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) # Call gemm_sp with non-zero values and metadata
T.copy(C_frag, C[by * block_M, bx * block_N])
return main
```
Under the hood, `gemm_sp` invokes templates adapted from `CUTLASS`, and a compatible metadata layout must be specified using `T.annotate_layout`.
## `T.gemm_sp_v2` with a custom compressor
To migrate to `gemm_sp_v2`, simply replace occurrences of `gemm_sp`.
Unlike `gemm_sp`, `gemm_sp_v2` can operate without `T.annotate_layout`, and it also supports user-defined layouts and compressors.
The metadata is stored in a `(u)int8`/`(u)int16`/`(u)int32` tensor, where **each 4-bit chunk represents two 2-bit indices** of non-zero elements within four consecutive elements. Here, we start with an `int16` example, which is the **default dtype** for `bf16` and `fp16` on Ampere GPUs.
Suppose we have the following row vector:
```python
t = tensor([[0, 7, 0, 3], [1, 5, 0, 0], [0, 0, 2, 4], [9, 0, 9, 0]], dtype=torch.float16).flatten()
```
The non-zero elements and their corresponding indices are:
```python
t_sp = tensor([[7, 3], [1, 5], [2, 4], [9, 9]], dtype=torch.float16).flatten()
indices = tensor([[1, 3], [0, 1], [2, 3], [0, 2]], dtype=torch.float16).flatten()
```
The corresponding uint16 metadata is:
```python
# metadata_bits = tensor([0b1101, 0b0100, 0b1110, 0b1000])
# Note: storage uses little-endian order: tensor(0b1000111001001101, dtype=torch.int16)
# Note: the above code is not runnable in python as the interpreter won't take the binary
# as 2's complement
metadata_int16 = tensor(-29107)
```
You can decode an int16 metadata tensor using the following utility:
```python
def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
assert meta.dtype is torch.int16
groups_per_meta = 16 // 4
out = []
for g in range(groups_per_meta):
group_bits = (meta >> (g * 4)) & 0xF
idx0 = group_bits & 0x3
idx1 = (group_bits >> 2) & 0x3
out.append(torch.stack([idx0, idx1], dim=-1))
return torch.concat(out, dim=-1).view(meta.shape[0], -1)
```
The compressor can be implement at either `PyTorch`/`NumPy` level or kernel level.
For example, `PyTorch` provides an Ampere compressor [here](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L47-L179). Note that in this implementation, a [permutation](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L173-L175) is applied to match CUTLASS’s metadata layout. If you do not annotate a metadata layout when using `gemm_sp_v2`, your compressor should replicate the same behavior as the PyTorch example—but without using the `_calculate_meta_reordering_scatter_offsets` function.
If you want to use a custom metadata layout in your kernel, one approach is to define the layout in `TileLang` and then apply the same layout to both your compressor kernel and the matmul_sp kernel.
```python
@tilelang.jit(out_idx=[1, 2], pass_configs={
tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
})
def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
e_factor, e_dtype = ARCH_INFO["8.0"]
e_K = K // e_factor
elem, group = 2, 4
assert M % block_M == 0, "M must be divisible by block_M"
assert K % block_K == 0, "K must be divisible by block_K"
assert K % e_factor == 0, "K must be divisible by e_factor"
assert block_K % e_factor == 0, "block_K must be divisible by e_factor"
@T.prim_func
def kernel(
A: T.Tensor((M, K), dtype),
A_sp: T.Tensor((M, K // 2), dtype),
E: T.Tensor((M, e_K), e_dtype),
):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
if use_cutlass_layout: # NOTE: Make sure compressor metadata layout
T.annotate_layout({ # is same with your computation kernel
E:
make_cutlass_metadata_layout(
E, mma_dtype="float16", arch="8.0", block_k=block_K),
E_shared:
make_cutlass_metadata_layout(
E_shared,
mma_dtype="float16",
arch="8.0",
block_k=block_K),
})
T.clear(A_sp_shared)
T.clear(E_shared)
non_zero_cnt = T.alloc_local((1, ), dtype="uint8")
non_zero_elt_log_idx = T.alloc_local((elem, ), dtype="uint8")
T.copy(A[bx * block_M, by * block_K], A_shared)
for tm in T.Parallel(block_M):
for g_i in range(0, block_K // group):
a_k = g_i * group
T.clear(non_zero_cnt)
T.clear(non_zero_elt_log_idx)
for i in range(group):
val = A_shared[tm, a_k + i]
if val != 0.0:
non_zero_elt_log_idx[non_zero_cnt[0]] = i
A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val
non_zero_cnt[0] += 1
if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3:
non_zero_elt_log_idx[0] = 0
non_zero_elt_log_idx[1] = 3
A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2]
A_sp_shared[tm, a_k // 2] = 0.0
elif non_zero_cnt[0] == 1:
A_sp_shared[tm, a_k // 2 + 1] = 0
non_zero_elt_log_idx[1] = 3
for i in T.serial(elem):
val = non_zero_elt_log_idx[i]
E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i)
T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])
return kernel
```
## A note on `gemm_sp` and `gemm_sp_v2`
Initially, `T.gemm_sp` followed the same design as `T.gemm`, lowering to a `CUTLASS` template. This inherently requires metadata to be reordered offline following a predetermined layout.
However, fixing a specific layout introduces several potential issues:
1. Painful debugging experience: Debugging a failed kernel becomes difficult due to the reordered indexing, including permutations and swizzling.
2. Limited flexibility: For example, concatenating two compressed tensors, such as `A_sparse_0` and `A_sparse_1`, into a new `A_sparse` makes sense. However, concatenating their metadata `E_0` and `E_1` may not be valid unless the layout allows it mathematically.
3. Alignment requirements: `CUTLASS` enforces strict alignment checks, and many hyperparameter configurations can lead to compilation errors. (For reference, sm8x was implemented in `CUTLASS 2`.)
`T.gemm_sp_v2` was designed to address these limitations, following the approach of `T.gemm_v2`. It lowers directly to PTX, removing the need for a fixed metadata layout.
\ No newline at end of file
# Installation Guide
## Installing with pip
**Prerequisites for installation via wheel or PyPI:**
- **glibc**: 2.28 (Ubuntu 20.04 or later)
- **Python Version**: >= 3.8
- **CUDA Version**: 12.0 <= CUDA < 13
The easiest way to install tilelang is directly from PyPI using pip. To install the latest version, run the following command in your terminal:
```bash
pip install tilelang
```
Alternatively, you may choose to install tilelang using prebuilt packages available on the Release Page:
```bash
pip install tilelang-0.0.0.dev0+ubuntu.20.4.cu120-py3-none-any.whl
```
To install the latest version of tilelang from the GitHub repository, you can run the following command:
```bash
pip install git+https://github.com/tile-ai/tilelang.git
```
After installing tilelang, you can verify the installation by running:
```bash
python -c "import tilelang; print(tilelang.__version__)"
```
## Building from Source
**Prerequisites for building from source:**
- **Operating System**: Linux
- **Python Version**: >= 3.8
- **CUDA Version**: >= 10.0
If you prefer Docker, please skip to the [Install Using Docker](#install-using-docker) section. This section focuses on building from source on a native Linux environment.
First, install the OS-level prerequisites on Ubuntu/Debian-based systems using the following commands:
```bash
apt-get update
apt-get install -y python3 python3-dev python3-setuptools gcc zlib1g-dev build-essential cmake libedit-dev
```
Then, clone the tilelang repository and install it using pip. The `-v` flag enables verbose output during the build process.
> **Note**: Use the `--recursive` flag to include necessary submodules. Tilelang currently depends on a customized version of TVM, which is included as a submodule. If you prefer [Building with Existing TVM Installation](#using-existing-tvm), you can skip cloning the TVM submodule (but still need other dependencies).
```bash
git clone --recursive https://github.com/tile-ai/tilelang.git
cd tilelang
pip install . -v
```
If you want to install tilelang in development mode, you can use the `-e` flag so that any changes to the Python files will be reflected immediately without reinstallation.
```bash
pip install -e . -v
```
> **Note**: changes to C++ files require rebuilding the tilelang C++ library. See [Faster Rebuild for Developers](#faster-rebuild-for-developers) below. A default `build` directory will be created if you use `pip install`, so you can also directly run `make` in the `build` directory to rebuild it as [Working from Source via PYTHONPATH](#working-from-source-via-pythonpath) suggested below.
(working-from-source-via-pythonpath)=
### Working from Source via `PYTHONPATH` (Recommended for Developers)
If you prefer to work directly from the source tree via `PYTHONPATH` instead of using pip, make sure the native extension (`libtilelang.so`) is built first:
```bash
mkdir -p build
cd build
cmake .. -DUSE_CUDA=ON
make -j
```
We also recommend using `ninja` to speed up compilation:
```bash
cmake .. -DUSE_CUDA=ON -G Ninja
ninja
```
Then add the repository root to `PYTHONPATH` before importing `tilelang`, for example:
```bash
export PYTHONPATH=/path/to/tilelang:$PYTHONPATH
python -c "import tilelang; print(tilelang.__version__)"
```
Some useful CMake options you can toggle while configuring:
- `-DUSE_CUDA=ON|OFF` builds against NVIDIA CUDA (default ON when CUDA headers are found).
- `-DUSE_ROCM=ON` selects ROCm support when building on AMD GPUs.
- `-DNO_VERSION_LABEL=ON` disables the backend/git suffix in `tilelang.__version__`.
(using-existing-tvm)=
### Building with Customized TVM Path
If you already have a TVM codebase, use the `TVM_ROOT` environment variable to specify the location of your existing TVM repository when building tilelang:
```bash
TVM_ROOT=<your-tvm-repo> pip install . -v
```
> **Note**: This will still rebuild the TVM-related libraries (stored in `TL_LIBS`). And this method often leads to some path issues. Check `env.py` to see some environment variables which are not set properly.
(install-using-docker)=
## Install Using Docker
For users who prefer a containerized environment with all dependencies pre-configured, tilelang provides Docker images for different CUDA versions. This method is particularly useful for ensuring consistent environments across different systems.
**Prerequisites:**
- Docker installed on your system
- NVIDIA Docker runtime or GPU is not necessary for building tilelang, you can build on a host without GPU and use that built image on other machine.
1. **Clone the Repository**:
```bash
git clone --recursive https://github.com/tile-ai/tilelang
cd tilelang
```
2. **Build Docker Image**:
Navigate to the docker directory and build the image for your desired CUDA version:
```bash
cd docker
docker build -f Dockerfile.cu120 -t tilelang-cu120 .
```
Available Dockerfiles:
- `Dockerfile.cu120` - For CUDA 12.0
- Other CUDA versions may be available in the docker directory
3. **Run Docker Container**:
Start the container with GPU access and volume mounting:
```bash
docker run -itd \
--shm-size 32g \
--gpus all \
-v /home/tilelang:/home/tilelang \
--name tilelang_b200 \
tilelang-cu120 \
/bin/zsh
```
**Command Parameters Explanation:**
- `--shm-size 32g`: Increases shared memory size for better performance
- `--gpus all`: Enables access to all available GPUs
- `-v /home/tilelang:/home/tilelang`: Mounts host directory to container (adjust path as needed)
- `--name tilelang_b200`: Assigns a name to the container for easy management
- `/bin/zsh`: Uses zsh as the default shell
4. **Access the Container and Verify Installation**:
```bash
docker exec -it tilelang_b200 /bin/zsh
# Inside the container:
python -c "import tilelang; print(tilelang.__version__)"
```
## Install with Nightly Version
For users who want access to the latest features and improvements before official releases, we provide nightly builds of tilelang.
```bash
pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu121/
# or pip install tilelang --find-links https://tile-ai.github.io/whl/nightly/cu121/
```
> **Note:** Nightly builds contain the most recent code changes but may be less stable than official releases. They're ideal for testing new features or if you need a specific bugfix that hasn't been released yet.
## Install Configs
### Build-time environment variables
`USE_CUDA`: If to enable CUDA support, default: `ON` on Linux, set to `OFF` to build a CPU version. By default, we'll use `/usr/local/cuda` for building tilelang. Set `CUDAToolkit_ROOT` to use different cuda toolkit.
`USE_ROCM`: If to enable ROCm support, default: `OFF`. If your ROCm SDK does not located in `/opt/rocm`, set `USE_ROCM=<rocm_sdk>` to enable build ROCm against custom sdk path.
`USE_METAL`: If to enable Metal support, default: `ON` on Darwin.
`TVM_ROOT`: TVM source root to use.
`NO_VERSION_LABEL` and `NO_TOOLCHAIN_VERSION`:
When building tilelang, we'll try to embed SDK and version information into package version as below,
where local version label could look like `<sdk>.git<git_hash>`. Set `NO_VERSION_LABEL=ON` to disable this behavior.
```
$ python -mbuild -w
...
Successfully built tilelang-0.1.6.post1+cu116.git0d4a74be-cp38-abi3-linux_x86_64.whl
```
where `<sdk>={cuda,rocm,metal}`. Specifically, when `<sdk>=cuda` and `CUDA_VERSION` is provided via env,
`<sdk>=cu<cuda_major><cuda_minor>`, similar with this part in pytorch.
Set `NO_TOOLCHAIN_VERSION=ON` to disable this.
### Run-time environment variables
Please refer to the `env.py` file for a full list of supported run-time environment variables.
## Other Tips
### IDE Configs
Building tilelang locally will automatically generate a `compile_commands.json` file in `build` dir.
VSCode with clangd and [clangd extension](https://marketplace.visualstudio.com/items?itemName=llvm-vs-code-extensions.vscode-clangd) should be able to index that without extra configuration.
### Compile Cache
The default path of the compile cache is `~/.tilelang/cache`. `ccache` will be automatically used if found.
### Repairing Wheels
If you plan to use your wheel in other environment,
it's recommended to use auditwheel (on Linux) or delocate (on Darwin)
to repair them.
(faster-rebuild-for-developers)=
### Faster Rebuild for Developers
`pip install` introduces extra [un]packaging and takes ~30 sec to complete,
even if no source change.
Developers who needs to recompile frequently could use:
```bash
pip install -r requirements-dev.txt
# For first time compilation
pip install -e . -v --no-build-isolation
# Or manually compile with cmake/ninja. Remember to set PYTHONPATH properly.
mkdir build
cd build
cmake .. -G Ninja
ninja
# Rebuild when you change the cpp code
cd build; ninja
```
When running in editable/developer mode,
you'll see logs like below:
```console
$ python -c 'import tilelang'
2025-10-14 11:11:29 [TileLang:tilelang.env:WARNING]: Loading tilelang libs from dev root: /Users/yyc/repo/tilelang/build
```
# The Tile Language: A Brief Introduction
## Programming Interface
The figure below depicts how **TileLang** programs are progressively lowered from a high-level description to hardware-specific executables. We provide three different programming interfaces—targeted at **Beginner**, **Developer**, and **Expert** users—that each reside at different levels in this lowering pipeline. The **Tile Language** also allows mixing these interfaces within the same kernel, enabling users to work at whichever level of abstraction best suits their needs.
```{figure} ../_static/img/overview.png
:width: 50%
:alt: Overview
:align: center
Figure 1: High-level overview of the TileLang compilation flow.
```
## Programming Interfaces
1. **Beginner Level (Hardware-Unaware)**
- Intended for users who need to write code that is independent of specific hardware details.
- The goal is to let developers focus on the basic logic without worrying about memory hierarchies or hardware-specific optimizations.
- *Note:* This interface is not yet fully implemented.
2. **Developer Level (Hardware-Aware with Tile Library)**
- Designed for developers who have a basic understanding of GPU memory hierarchies and performance considerations.
- Provides a **Tile Library**, containing predefined operations and patterns optimized for various hardware architectures.
- Users at this level can leverage these ready-made primitives without diving into low-level threading details.
3. **Expert Level (Hardware-Aware with Thread Primitives)**
- For highly experienced users who have an in-depth understanding of low-level hardware characteristics (e.g., threading models, memory coalescing).
- Offers direct access to **thread primitives** and other low-level constructs, allowing for fine-grained control of performance-critical kernels.
- This level grants maximum flexibility for specialized optimizations tailored to specific GPU or multi-core architectures.
## Compilation Flow
1. **Tile Program**
A high-level specification of the computation. Depending on the user’s expertise, they may write a purely hardware-unaware tile program or incorporate constructs from the Tile Library or thread primitives.
2. **Tile Program with Tile Library**
When developers choose from the Tile Library, the original Tile Program is expanded with specialized library calls. These calls encapsulate efficient implementation patterns for different operations.
3. **Tile Program with Thread Primitives**
Expert-level developers can explicitly use low-level threading constructs to hand-optimize data layout, synchronization, and memory usage.
4. **IRModule**
After the program is composed with libraries or thread primitives, it is lowered to an intermediate representation (IR) that captures the necessary hardware details.
5. **Source Code Generation (C/CUDA/HIP/LLVM/…)**
From the IR, the system generates target-specific source code. This source code is tuned for the desired backends or GPU architectures (e.g., NVIDIA, AMD).
6. **Hardware-Specific Executable/Runtime**
Finally, the generated source is compiled into hardware-specific executables, ready to run on the corresponding devices. The pipeline supports multiple GPU backends and can be extended to additional architectures.
## Tile-based Programming Model
[Figure 2](#fig-overview-gemm) provides a concise matrix multiplication (GEMM) example in ``TileLang``,
illustrating how developers can employ high-level constructs such as tiles, memory placement, pipelining,
and operator calls to manage data movement and computation with fine-grained control.
In particular, this snippet ([Figure 2](#fig-overview-gemm) (a)) demonstrates how multi-level tiling
leverages different memory hierarchies (global, shared, and registers) to optimize bandwidth utilization
and reduce latency.
Overall, [Figure 2](#fig-overview-gemm) (b) showcases how the Python-like syntax of ``TileLang``
allows developers to reason about performance-critical optimizations within a user-friendly programming model.
```{figure} ../_static/img/MatmulExample.png
:align: center
:width: 100%
:alt: GEMM with Multi-Level Tiling on GPUs
:name: fig-overview-gemm
Figure 2: Optimizing GEMM with Multi-Level Tiling on GPUs via ``TileLang``.
```
### Tile declarations
At the heart of our approach is the notion of *tiles* as first-class objects in the programming model. A tile represents a shaped portion of data, which can be owned and manipulated by a warp, thread block, or equivalent parallel unit. In the `Matmul` example, the `A` and `B` buffers are read in tiled chunks (determined by `block_M`, `block_N`, `block_K`) inside the kernel loop. With `T.Kernel`, `TileLang` defines the execution context, which includes the thread block index (`bx` and `by`) and the number of threads. These contexts can help compute the index for each thread block and make it easier for `TileLang` to automatically infer and optimize memory access and computation. Additionally, these contexts allow users to manually control the behavior of each independent thread within a thread block.
### Explicit Hardware Memory Allocation
A hallmark of `TileLang` is the ability to explicitly place these tile buffers in the hardware memory hierarchy. Rather than leaving it to a compiler's opaque optimization passes, `TileLang` exposes user-facing intrinsics that map directly to physical memory spaces or accelerator-specific constructs. In particular:
- `T.alloc_shared`: Allocates memory in a fast, on-chip storage space, which corresponds to shared memory on NVIDIA GPUs. Shared memory is ideal for caching intermediate data during computations, as it is significantly faster than global memory and allows for efficient data sharing between threads in the same thread block. For example, in matrix multiplication, tiles of matrices can be loaded into shared memory to reduce global memory bandwidth demands and improve performance.
- `T.alloc_fragment`: Allocates accumulators in fragment memory, which corresponds to register files on NVIDIA GPUs. By keeping inputs and partial sums in registers or hardware-level caches, latency is further minimized. Note that in this tile program, each tile allocates the same local buffers as shared memory, which might seem counterintuitive, as shared memory is generally faster but more abundant, whereas register file space is limited. This is because the allocation here refers to the register files for an entire thread block. `TileLang` uses a Layout Inference Pass during compilation to derive a Layout object `T.Fragment`, which determines how to allocate the corresponding register files for each thread. This process will be discussed in detail in subsequent sections.
Data transfer between global memory and hardware-specific memory can be managed using `T.copy`. Furthermore, hardware-specific buffers can be initialized using `T.clear` or `T.fill`. For data assignments, operations can also be performed in parallel using `T.Parallel`, as demonstrated in Layout Inference Pass in the following sections.
```{figure} ../_static/img/LayoutInference.png
:align: center
:width: 100%
:alt: GEMM with Multi-Level Tiling on GPUs
```
# Understanding Targets
TileLang is built on top of TVM, which relies on **targets** to describe the device you want to compile for.
The target determines which code generator is used (CUDA, HIP, Metal, LLVM, …) and allows you to pass
device-specific options such as GPU architecture flags. This page summarises how to pick and customise a target
when compiling TileLang programs.
## Common target strings
TileLang ships with a small set of common targets; each accepts the full range of TVM options so you can fine-tune
the generated code. The most frequent choices are listed below:
| Base name | Description |
| --------- | ----------- |
| `auto` | Detects CUDA → HIP → Metal in that order. Useful when running the same script across machines. |
| `cuda` | NVIDIA GPUs. Supports options such as `-arch=sm_80`, `-max_num_threads=1024`, etc. |
| `hip` | AMD GPUs via ROCm. Options like `-mcpu=gfx90a` can be appended. |
| `metal` | Apple Silicon GPUs (arm64 Macs). |
| `llvm` | CPU execution; accepts the standard TVM LLVM switches. |
| `webgpu` | Browser / WebGPU runtimes. |
| `c` | Emit plain C source for inspection or custom toolchains. |
To add options, append them after the base name, separated by spaces. For example:
```python
target = "cuda -arch=sm_90"
kernel = tilelang.compile(func, target=target, execution_backend="cython")
# or
@tilelang.jit(target=target)
def compiled_kernel(*args):
return func(*args)
```
The same convention works for HIP or LLVM (e.g. `hip -mcpu=gfx940`, `llvm -mtriple=x86_64-linux-gnu`).
### Advanced: Specify Exact Hardware
When you already know the precise GPU model, you can encode it in the target string—either via `-arch=sm_XX` or by
using one of TVM’s pre-defined target tags such as `nvidia/nvidia-h100`. Supplying this detail is optional for
TileLang in general use, but it becomes valuable when the TVM cost model is enabled (e.g. during autotuning). The
cost model uses the extra attributes to make better scheduling predictions. If you skip this step (or do not use the
cost model), generic targets like `cuda` or `auto` are perfectly fine.
All CUDA compute capabilities recognised by TVM’s target registry are listed below. Pick the one that matches your
GPU and append it to the target string or use the corresponding target tag—for example `nvidia/nvidia-a100`.
| Architecture | GPUs (examples) |
| ------------ | ---------------- |
| `sm_20` | `nvidia/tesla-c2050`, `nvidia/tesla-c2070` |
| `sm_21` | `nvidia/nvs-5400m`, `nvidia/geforce-gt-520` |
| `sm_30` | `nvidia/quadro-k5000`, `nvidia/geforce-gtx-780m` |
| `sm_35` | `nvidia/tesla-k40`, `nvidia/quadro-k6000` |
| `sm_37` | `nvidia/tesla-k80` |
| `sm_50` | `nvidia/quadro-k2200`, `nvidia/geforce-gtx-950m` |
| `sm_52` | `nvidia/tesla-m40`, `nvidia/geforce-gtx-980` |
| `sm_53` | `nvidia/jetson-tx1`, `nvidia/jetson-nano` |
| `sm_60` | `nvidia/tesla-p100`, `nvidia/quadro-gp100` |
| `sm_61` | `nvidia/tesla-p4`, `nvidia/quadro-p6000`, `nvidia/geforce-gtx-1080` |
| `sm_62` | `nvidia/jetson-tx2` |
| `sm_70` | `nvidia/nvidia-v100`, `nvidia/quadro-gv100` |
| `sm_72` | `nvidia/jetson-agx-xavier` |
| `sm_75` | `nvidia/nvidia-t4`, `nvidia/quadro-rtx-8000`, `nvidia/geforce-rtx-2080` |
| `sm_80` | `nvidia/nvidia-a100`, `nvidia/nvidia-a30` |
| `sm_86` | `nvidia/nvidia-a40`, `nvidia/nvidia-a10`, `nvidia/geforce-rtx-3090` |
| `sm_87` | `nvidia/jetson-agx-orin-32gb`, `nvidia/jetson-agx-orin-64gb` |
| `sm_89` | `nvidia/geforce-rtx-4090` |
| `sm_90a` | `nvidia/nvidia-h100` (DPX profile) |
| `sm_100a` | `nvidia/nvidia-b100` |
Refer to NVIDIA’s [CUDA GPUs](https://developer.nvidia.com/cuda-gpus) page or the TVM source
(`3rdparty/tvm/src/target/tag.cc`) for the latest mapping between devices and compute capabilities.
## Creating targets programmatically
If you prefer working with TVM’s `Target` objects, TileLang exposes the helper
`tilelang.utils.target.determine_target` (returns a canonical target string by default, or the `Target`
object when `return_object=True`):
```python
from tilelang.utils.target import determine_target
tvm_target = determine_target("cuda -arch=sm_80", return_object=True)
kernel = tilelang.compile(func, target=tvm_target)
```
You can also build targets directly through TVM:
```python
from tvm.target import Target
target = Target("cuda", host="llvm")
target = target.with_host(Target("llvm -mcpu=skylake"))
```
TileLang accepts either `str` or `Target` inputs; internally they are normalised and cached using the canonical
string representation. **In user code we strongly recommend passing target strings rather than
`tvm.target.Target` instances—strings keep cache keys compact and deterministic across runs, whereas constructing
fresh `Target` objects may lead to slightly higher hashing overhead or inconsistent identity semantics.**
## Discovering supported targets in code
Looking for a quick reminder of the built-in base names and their descriptions? Use:
```python
from tilelang.utils.target import describe_supported_targets
for name, doc in describe_supported_targets().items():
print(f"{name:>6}: {doc}")
```
This helper mirrors the table above and is safe to call at runtime (for example when validating CLI arguments).
## Troubleshooting tips
- If you see `Target cuda -arch=sm_80 is not supported`, double-check the spellings and that the option is valid for
TVM. Any invalid switch will surface as a target-construction error.
- Runtime errors such as “no kernel image is available” usually mean the `-arch` flag does not match the GPU you are
running on. Try dropping the flag or switching to the correct compute capability.
- When targeting multiple environments, use `auto` for convenience and override with an explicit string only when
you need architecture-specific tuning.
# 👋 Welcome to Tile Language
[GitHub](https://github.com/tile-ai/tilelang)
Tile Language (tile-lang) is a concise domain-specific language designed to streamline
the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention).
By employing a Pythonic syntax with an underlying compiler infrastructure on top of TVM,
tile-lang allows developers to focus on productivity without sacrificing the
low-level optimizations necessary for state-of-the-art performance.
:::{toctree}
:maxdepth: 2
:caption: GET STARTED
get_started/Installation
get_started/overview
get_started/targets
:::
:::{toctree}
:maxdepth: 1
:caption: TUTORIALS
tutorials/debug_tools_for_tilelang
tutorials/auto_tuning
tutorials/logging
:::
:::{toctree}
:maxdepth: 1
:caption: PROGRAMMING GUIDES
programming_guides/overview
programming_guides/language_basics
programming_guides/instructions
programming_guides/control_flow
programming_guides/autotuning
programming_guides/type_system
:::
:::{toctree}
:maxdepth: 1
:caption: DEEP LEARNING OPERATORS
deeplearning_operators/elementwise
deeplearning_operators/gemv
deeplearning_operators/matmul
deeplearning_operators/matmul_sparse
deeplearning_operators/deepseek_mla
:::
:::{toctree}
:maxdepth: 1
:caption: COMPILER INTERNALS
compiler_internals/letstmt_inline
compiler_internals/inject_fence_proxy
compiler_internals/tensor_checks
:::
:::{toctree}
:maxdepth: 1
:caption: API Reference
autoapi/tilelang/index
:::
:::{toctree}
:maxdepth: 1
:caption: Privacy
privacy
:::
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd
# Privacy
All data stays in users' device and is not collected by the app.
# Autotuning
TileLang includes a built‑in autotuner that searches configuration spaces
for the best performing kernel, compiles candidates in parallel, validates
correctness, benchmarks them, and caches the best result for reuse.
This guide covers two workflows:
- Decorator‑based: `@tilelang.autotune(configs=...)` stacked on `@tilelang.jit`
- Programmatic: `AutoTuner.from_kernel(...).set_*().run()`
It also explains input tensor supply, validation, caching, and environment
variables that affect parallelism and cache behavior.
## 1) Decorator‑based Autotune
Use `@tilelang.autotune` above `@tilelang.jit` and expose tunable parameters as
function arguments with defaults. The autotuner overrides these parameters with
values from your config space.
```python
import tilelang
import tilelang.language as T
def matmul_configs(M, N, K):
# Example space — tailor to your target
tiles = [64, 128]
stages = [2, 3]
threads = [128, 256]
return [
dict(block_M=BM, block_N=BN, block_K=BK, num_stages=S, threads=TH)
for BM in tiles
for BN in tiles
for BK in [32, 64]
for S in stages
for TH in threads
]
@tilelang.autotune(configs=matmul_configs, warmup=25, rep=100, timeout=60)
@tilelang.jit(out_idx=[-1])
def matmul(M: int, N: int, K: int,
block_M: int = 128, block_N: int = 128, block_K: int = 32,
threads: int = 128, num_stages: int = 3,
dtype: str = 'float16', accum_dtype: str = 'float32'):
@T.prim_func
def kernel(A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_s = T.alloc_shared((block_M, block_K), dtype)
B_s = T.alloc_shared((block_K, block_N), dtype)
C_f = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_f)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, ko * block_K], A_s)
T.copy(B[ko * block_K, bx * block_N], B_s)
T.gemm(A_s, B_s, C_f)
T.copy(C_f, C[by * block_M, bx * block_N])
return kernel
# Usage
# Provide inputs via context (recommended for reproducibility across configs)
import torch
M = N = K = 1024
A = torch.randn(M, K, device='cuda', dtype=torch.float16)
B = torch.randn(K, N, device='cuda', dtype=torch.float16)
C = torch.empty(M, N, device='cuda', dtype=torch.float16)
from tilelang.autotuner import set_autotune_inputs
with set_autotune_inputs(A, B, C):
tuned_kernel = matmul(M, N, K) # compiles, tunes, returns best kernel
tuned_kernel(A, B, C) # run best kernel
```
Notes
- `configs` can be a list of dicts or a callable `(args...) -> list[dict]`. Each
dict’s keys must match the tunable function arguments (e.g., `block_M`).
- The decorator returns a callable that runs autotune once per argument tuple
and caches the resulting best kernel in‑process.
- For explicit input control during tuning, wrap the call with
`set_autotune_inputs(...)`. Otherwise, `supply_type` (below) is used.
## 2) Programmatic Autotune
Use the `AutoTuner` class to manage configs and arguments more explicitly.
```python
from tilelang.autotuner import AutoTuner
kernel_factory = matmul # the function above (already @tilelang.jit)
tuner = AutoTuner.from_kernel(kernel_factory(M, N, K), configs=matmul_configs(M, N, K))
tuner.set_profile_args(
warmup=25, rep=100, timeout=60,
supply_type=tilelang.TensorSupplyType.Auto, # or provide supply_prog/ref_prog
ref_prog=lambda A, B, C: torch.allclose(C, (A @ B).to(C.dtype), rtol=1e-2, atol=1e-2),
)
tuner.set_compile_args(
target='auto', # or 'cuda'/'hip'/'metal'
execution_backend='auto', # resolves per-target
out_idx=[-1], # which outputs to return if multiple
pass_configs={ # optional TVM passes/flags
# tilelang.PassConfigKey.EXAMPLE_KEY: value,
},
)
artifact = tuner.run() # compiles + runs + validates all configs
best_kernel = artifact.kernel # JITKernel
best_latency = artifact.latency
best_config = artifact.config
# Reuse best kernel
best_kernel(A, B, C)
```
### Example Gallery (in repo)
- examples/gdn/example_chunk_delta_h.py:101 — uses `@autotune` to sweep configs
- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py:451 — uses `@tilelang.autotune`
- examples/quickstart.py:84 — profiles a tuned kernel with `get_profiler`
- examples/hadamard_transform/example_hadamard.py:152 — profiler with custom warmup
- examples/dynamic_shape/example_dynamic.py:94 — profiler for dynamic shapes
- examples/gemm/example_gemm_persistent.py:135 — compare persistent vs non‑persistent
Click any path to open the code and compare patterns.
## Input Tensor Supply
The tuner needs inputs to compile and benchmark kernels. Provide them in one of
three ways (priority order):
1) Context manager (fixed inputs across configs)
```python
with set_autotune_inputs(A, B, C):
tuned = matmul(M, N, K)
```
2) Custom supplier program
```python
def supply_prog(signature):
# signature holds KernelParam objects describing shapes/dtypes
# Return a list of torch tensors matching the kernel’s arguments
return [A, B, C]
tuner.set_profile_args(supply_prog=supply_prog)
```
3) Built‑in generators via `supply_type`
- `TensorSupplyType.Auto` (default): heuristic per dtype (uniform ints / fp ranges)
- `Integer`, `Uniform`, `Normal`, `Randn`, `Zero`, `One`
Important
- Built‑in generators require static shapes; if your PrimFunc uses symbolic
dimensions (T.dyn), supply concrete inputs via (1) or (2).
- Float8 dtypes require PyTorch 2.1+ for `torch.float8_*` support.
## Correctness Checking and Tolerances
Use one of the following validation methods:
- `ref_prog`: Provide a reference program that receives the same inputs and
checks results. You can return a boolean or raise on mismatch.
- `manual_check_prog`: A callable that inspects outputs and raises on mismatch.
- `skip_check=True`: Skip correctness checks (faster, use with caution).
Control numeric drift via:
- `rtol` and `atol` (defaults 1e‑2)
- `max_mismatched_ratio` (default 1%)
## Configuration Spaces and Best Practices
What to tune
- Tile sizes: `block_M`, `block_N`, `block_K`
- Software pipelining: `num_stages`
- Threads per block: `threads` (or (x, y) tuple)
- Optional: dtype variants, epilogues, small scheduling knobs
Tips
- Start from a working baseline. Tune a small, meaningful space first.
- Respect hardware limits (shared memory bytes, registers per thread/block,
max threads per block). Eliminate impossible configs up‑front.
- Keep block sizes multiples of vector widths and warp sizes when relevant.
- Use `set_autotune_inputs` to ensure each config is measured on identical data.
- Record your best configs and bake them as defaults when stable.
## Parallel Compilation/Benchmarking and Timeouts
The tuner compiles configurations in parallel using a thread pool and benchmarks
them with a per‑config timeout. On CUDA, each worker sets the current device to
avoid context issues.
Notes
- `timeout` uses POSIX signals; on non‑Unix systems, it may not take effect.
- Logs are written to `autotuner.log` in the working directory.
## Caching
The autotuner caches best artifacts both in‑memory (per process) and on disk under
`$TILELANG_CACHE_DIR/autotuner`. The cache key includes:
- TileLang version, function source, closure free‑vars
- Config list, compile args, profile args
Disk cache contents (per key)
- Best config and latency: `best_config.json`, `latency.json`
- Kernel sources and library: `device_kernel.cu`, `host_kernel.cu`, `kernel_lib.so` (or `kernel.cubin`/`executable.so` depending on backend)
- Function and params: `function.pkl`, `params.pkl`
Control via env vars (tilelang.env)
- `TILELANG_CACHE_DIR` (default `~/.tilelang/cache`)
- `TILELANG_TMP_DIR` (default `$TILELANG_CACHE_DIR/tmp`)
- Disable all kernel caches: `TILELANG_DISABLE_CACHE=1`
- Disable autotune disk cache only: `TILELANG_AUTO_TUNING_DISABLE_CACHE=1`
CPU worker control
- `TILELANG_AUTO_TUNING_CPU_UTILITIES` (fraction, default 0.9)
- `TILELANG_AUTO_TUNING_CPU_COUNTS` (int, `-1` auto)
- `TILELANG_AUTO_TUNING_MAX_CPU_COUNT` (int, `-1` unlimited)
Backend notes
- NVRTC backend persists `.cubin` and a Python launcher.
- Torch/DLPack backend may not save artifacts to disk; in this case, only
in‑memory caching applies and a warning is logged.
## Alternative: Manual Sweeps with par_compile
If you prefer manual control, use `JITImpl.par_compile` to compile a batch of
configs and drive your own benchmarking:
```python
@tilelang.jit
def factory(M, N, K, block_M=128, block_N=128, block_K=32):
@T.prim_func
def k(A: T.Tensor((M, K), 'float16'),
B: T.Tensor((K, N), 'float16'),
C: T.Tensor((M, N), 'float16')):
...
return k
impl = factory # JITImpl
cfgs = [
dict(block_M=64, block_N=128, block_K=32),
dict(block_M=128, block_N=128, block_K=64),
]
kernels = impl.par_compile(cfgs, num_workers=4)
# Now benchmark kernels[i](A, B, C) yourself
```
## Recording and Reusing Best Configs
The programmatic path returns an `AutotuneResult` that can be saved and later
reloaded. This is useful for CI, multi‑host workflows, or shipping tuned configs.
```python
artifact = tuner.run() # AutotuneResult
# Save to disk
from pathlib import Path
save_dir = Path('out/best/matmul_1024')
artifact.save_to_disk(save_dir, verbose=True)
# Reload later
from tilelang.autotuner.param import AutotuneResult, CompileArgs
restored = AutotuneResult.load_from_disk(save_dir, CompileArgs())
best = restored.kernel
best(A, B, C)
```
Notes
- DLPack/Torch execution backend may not persist compiled binaries; in that
case, re‑compilation is needed on load or use a different backend.
- The directory contains human‑readable JSONs (best config/latency) and sources.
## Advanced: Config Space Callables
Derive config spaces from problem sizes to keep searches targeted and legal:
```python
def matmul_configs(M, N, K):
large = min(M, N, K) >= 1024
tiles = [128] if large else [64, 128]
for BM in tiles:
for BN in tiles:
for BK in [32, 64]:
for S in [2, 3]:
for TH in [128, 256]:
yield dict(block_M=BM, block_N=BN, block_K=BK,
num_stages=S, threads=TH)
```
## Device and Backend Selection
Tune compile‑time options explicitly:
- `target='auto'|'cuda'|'hip'|'metal'` (normalized to a TVM Target)
- `execution_backend='auto'|'tvm_ffi'|'ctypes'|'cython'|'nvrtc'|'torch'`
- `pass_configs={...}` to toggle TileLang/TVM passes for experiments
On CUDA with multiple GPUs, the tuner sets the current device per worker thread
to avoid context mixups.
## Troubleshooting
- “No configurations to tune”: Ensure `configs` is a non‑empty list or callable.
- Timeouts: Increase `timeout`; ensure inputs fit device memory; verify that
your reference check isn’t the bottleneck.
- Dynamic shapes: Provide concrete inputs via `set_autotune_inputs` or a custom
`supply_prog`.
- Disk cache disabled: Check `TILELANG_AUTO_TUNING_DISABLE_CACHE` and backend.
# Control Flow
This guide covers the control‑flow primitives in TileLang and how they lower to
efficient GPU code. You will use these to structure loops, handle boundaries,
and express pipelined compute.
## Overview
- Conditionals: `if` / `elif` / `else`, ternary (`x if c else y`)
- Loops: `T.serial`, `T.unroll`, `T.Parallel`, `T.Pipelined`
- While loops: `while` with a TIR condition
- Flow control: Python `break` / `continue`
- Safety: automatic OOB guards via the LegalizeSafeMemoryAccess pass
The examples assume `import tilelang.language as T`.
## Conditionals
Standard Python `if`/`elif`/`else` is supported inside `@T.prim_func` kernels.
Conditions should be TIR expressions (e.g., `i < N`). Python plain booleans are
treated as compile‑time constants and will be folded.
```python
for i in T.serial(N):
if i < N: # TIR condition
C[i] = A[i] + B[i]
else:
pass
# Ternary
x = (A[i] if i < N else 0)
```
Short‑circuit boolean ops are supported. For multi‑dimensional bounds, use
`T.any_of` / `T.all_of` for clarity:
```python
if T.all_of(i < M, j < N):
C[i, j] = A[i, j] + B[i, j]
```
Boundary handling note
- The LegalizeSafeMemoryAccess pass automatically inserts guards when an access
may be out‑of‑bounds, and elides them when proven safe. You can often omit
explicit `if` checks for simple edge handling, but keep them when you need
custom logic or clarity.
## Loops
### Serial
`T.serial` creates a plain for‑loop. Common forms:
```python
for i in T.serial(N):
... # 0..N-1
for i in T.serial(0, N, 2):
... # 0, 2, 4, ...
```
### Unroll
`T.unroll` requests loop unrolling for small trip counts.
```python
for k in T.unroll(K_TILE):
acc += a[k] * b[k]
```
Advanced: TileLang forwards unroll hints to TIR; factor/explicit knobs are
available for expert tuning.
### Parallel (elementwise)
`T.Parallel(ext0, ext1, ...)` builds nested loops that map well to elementwise
operations. The body receives all indices in one `for` header:
```python
for i, j in T.Parallel(M, N):
C[i, j] = A[i, j] + B[i, j]
```
Optional: `coalesced_width=` can hint memory coalescing for the innermost loop.
### Pipelined (software pipelining)
`T.Pipelined(iters, num_stages=...)` overlaps producer/consumer stages (e.g.,
Global→Shared copies with compute). This is the backbone of GEMM/attention
pipelines.
```python
for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3):
T.copy(A[by * BM, ko * BK], A_s) # stage: copy A tile
T.copy(B[ko * BK, bx * BN], B_s) # stage: copy B tile
T.gemm(A_s, B_s, C_f) # stage: compute
```
### Persistent (advanced)
`T.Persistent(domain, wave_size, index, group_size=...)` exposes persistent
thread‑block style looping. It is an advanced construct that TileLang lowers in
later passes and is typically used by specialized templates.
## While Loops
`while` is supported when the condition is a TIR expression. Avoid infinite
loops; TileLang will error if it detects a constant‑true condition.
```python
i = 0
while i < N:
...
if done:
break
i += 1
```
## Break and Continue
Use Python `break`/`continue` to exit or skip within `T.serial`/`T.unroll`/
`T.Parallel`/`while` loops. Keep the body clean after a `break`/`continue` for
readability; the compiler will ignore the dead path.
## Putting It Together: Residual Tile Handling
Below is a typical edge pattern for a 2D kernel. With LegalizeSafeMemoryAccess,
the explicit guard can be omitted when you don’t need a custom edge path.
```python
for i, j in T.Parallel(M, N):
gi = by * BM + i
gj = bx * BN + j
if T.all_of(gi < M, gj < N): # optional in many cases
C[gi, gj] = A[gi, gj] + B[gi, gj]
```
## Debugging Conditions
Use `T.print` to inspect values under predicates. For buffers, TileLang prints
from a single thread to avoid duplicate outputs.
```python
if i == 0:
T.print(C, msg='C tile:')
```
# Instructions
This page summarizes the core TileLang “instructions” available at the DSL
level, how they map to hardware concepts, and how to use them correctly.
## Quick Categories
- Data movement: `T.copy`, `T.c2d_im2col`, staging Global ↔ Shared ↔ Fragment
- Compute primitives: `T.gemm`/`T.gemm_sp`, elementwise math (`T.exp`, `T.max`),
reductions (`T.reduce_sum`, `T.cumsum`, warp reducers)
- Control helpers: `T.clear`/`T.fill`, `T.reshape`/`T.view`
- Diagnostics: `T.print`, `T.device_assert`
- Advanced: atomics, memory barriers, warp‑group ops
## Data Movement
Use `T.copy(src, dst, coalesced_width=None, disable_tma=False, eviction_policy=None)`
to move tiles between memory scopes. It accepts `tir.Buffer`, `BufferLoad`, or
`BufferRegion`; extents are inferred or broadcast when possible.
```python
# Global → Shared tiles (extents inferred from dst)
T.copy(A[by * BM, ko * BK], A_s)
T.copy(B[ko * BK, bx * BN], B_s)
# Fragment/Register → Global (store result)
T.copy(C_f, C[by * BM, bx * BN])
```
Semantics
- Extents are deduced from arguments; missing sides broadcast to the other’s rank.
- Access patterns are legalized and coalesced during lowering. Explicit
vectorization is not required in HL mode.
- Safety: the LegalizeSafeMemoryAccess pass inserts boundary guards when an
access may be out‑of‑bounds and drops them when proven safe.
Other helpers
- `T.c2d_im2col(img, col, ...)`: convenience for conv‑style transforms.
## Compute Primitives
GEMM and sparse GEMM
- `T.gemm(A_shared, B_shared, C_fragment)`: computes a tile GEMM using shared
inputs and a fragment accumulator; lowered to target‑specific tensor cores.
- `T.gemm_sp(...)`: 2:4 sparse tensor core variant (see examples and README).
Reductions and scans
- `T.reduce_sum`, `T.reduce_max`, `T.reduce_min`, `T.cumsum`, plus warp
reducers (`T.warp_reduce_sum`, etc.).
- Allocate and initialize accumulators via `T.alloc_fragment` + `T.clear` or
`T.fill`.
Elementwise math
- Most math ops mirror TVM TIR: `T.exp`, `T.log`, `T.max`, `T.min`, `T.rsqrt`,
`T.sigmoid`, etc. Compose freely inside loops.
Reshape/view (no copy)
- `T.reshape(buf, new_shape)` and `T.view(buf, shape=None, dtype=None)` create
new views that share storage, with shape/dtype checks enforced.
## Synchronization (HL usage)
In HL pipelines, you usually don’t need to write explicit barriers. Passes such
as PipelinePlanning/InjectSoftwarePipeline/InjectTmaBarrier orchestrate
producer/consumer ordering and thread synchronization behind the scenes.
If you need debugging or explicit checks:
- `T.device_assert(cond, msg='')` emits device‑side asserts on CUDA targets.
- `T.print(obj, msg='...')` prints scalars or buffers safely from one thread.
## Putting It Together: GEMM Tile
```python
@T.prim_func
def gemm(
A: T.Tensor((M, K), 'float16'),
B: T.Tensor((K, N), 'float16'),
C: T.Tensor((M, N), 'float16'),
):
with T.Kernel(T.ceildiv(N, BN), T.ceildiv(M, BM), threads=128) as (bx, by):
A_s = T.alloc_shared((BM, BK), 'float16')
B_s = T.alloc_shared((BK, BN), 'float16')
C_f = T.alloc_fragment((BM, BN), 'float32')
T.clear(C_f)
for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3):
T.copy(A[by * BM, ko * BK], A_s) # Global → Shared
T.copy(B[ko * BK, bx * BN], B_s)
T.gemm(A_s, B_s, C_f) # compute into fragment
T.copy(C_f, C[by * BM, bx * BN]) # store back
```
## Instruction Reference (Concise)
Below is a concise list of TileLang instructions grouped by category. For full
signatures, behaviors, constraints, and examples, refer to API Reference
(`autoapi/tilelang/index`).
Data movement
- `T.copy(src, dst, ...)`: Move tiles between Global/Shared/Fragment.
- `T.c2d_im2col(img, col, ...)`: 2D im2col transform for conv.
Memory allocation and descriptors
- `T.alloc_shared(shape, dtype, scope='shared.dyn')`: Allocate shared buffer.
- `T.alloc_fragment(shape, dtype, scope='local.fragment')`: Allocate fragment.
- `T.alloc_var(dtype, [init], scope='local.var')`: Scalar var buffer (1 elem).
- `T.alloc_barrier(arrive_count)`: Shared barrier buffer.
- `T.alloc_tmem(shape, dtype)`: Tensor memory (TMEM) buffer (Hopper+).
- `T.alloc_reducer(shape, dtype, op='sum', replication=None)`: Reducer buf.
- `T.alloc_descriptor(kind, dtype)`: Generic descriptor allocator.
- `T.alloc_wgmma_desc(dtype='uint64')`
- `T.alloc_tcgen05_smem_desc(dtype='uint64')`
- `T.alloc_tcgen05_instr_desc(dtype='uint32')`
- `T.empty(shape, dtype='float32')`: Declare function output tensors.
Compute primitives
- `T.gemm(A_s, B_s, C_f)`: Tile GEMM into fragment accumulator.
- `T.gemm_sp(...)`: Sparse (2:4) tensor core GEMM.
- Reductions: `T.reduce_sum/max/min/abssum/absmax`, bitwise `and/or/xor`.
- Scans: `T.cumsum`, finalize: `T.finalize_reducer`.
- Warp reducers: `T.warp_reduce_sum/max/min/bitand/bitor`.
- Elementwise math: TIR ops (`T.exp`, `T.log`, `T.max`, `T.min`, `T.rsqrt`, ...).
- Fast math: `T.__log/__log2/__log10/__exp/__exp2/__exp10/__sin/__cos/__tan`.
- IEEE math: `T.ieee_add/sub/mul/fmaf` (configurable rounding).
- Helpers: `T.clear(buf)`, `T.fill(buf, value)`.
- Views: `T.reshape(buf, shape)`, `T.view(buf, shape=None, dtype=None)`.
Diagnostics
- `T.print(obj, msg='')`: Print scalar/buffer from one thread.
- `T.device_assert(cond, msg='')`: Device-side assert (CUDA).
Logical helpers
- `T.any_of(a, b, ...)`, `T.all_of(a, b, ...)`: Multi-term predicates.
Annotation helpers
- `T.use_swizzle(panel_size=..., enable=True)`: Rasterization hint.
- `T.annotate_layout({...})`: Attach explicit layouts to buffers.
- `T.annotate_safe_value(var, ...)`: Safety/const hints.
- `T.annotate_l2_hit_ratio(buf, ratio)`: Cache behavior hint.
Atomics
- `T.atomic_add(dst, value, memory_order=None, return_prev=False, use_tma=False)`.
- `T.atomic_addx2(dst, value, return_prev=False)`; `T.atomic_addx4(...)`.
- `T.atomic_max(dst, value, memory_order=None, return_prev=False)`.
- `T.atomic_min(dst, value, memory_order=None, return_prev=False)`.
- `T.atomic_load(dst)`, `T.atomic_store(dst, value)`.
Custom intrinsics
- `T.dp4a(A, B, C)`: 4‑element dot‑product accumulate.
- `T.clamp(x, lo, hi)`: Clamp to [lo, hi].
- `T.loop_break()`: Break from current loop via intrinsic.
Barriers, TMA, warp‑group
- Barriers: `T.create_list_of_mbarrier(...)`, `T.get_mbarrier(i)`.
- Parity ops: `T.mbarrier_wait_parity(barrier, parity)`, `T.mbarrier_arrive(barrier)`.
- Expect tx: `T.mbarrier_expect_tx(...)`; sugar: `T.barrier_wait(id, parity=None)`.
- TMA: `T.create_tma_descriptor(...)`, `T.tma_load(...)`,
`T.tma_store_arrive(...)`, `T.tma_store_wait(...)`.
- Proxy/fences: `T.fence_proxy_async(...)`, `T.warpgroup_fence_operand(...)`.
- Warp‑group: `T.warpgroup_arrive()`, `T.warpgroup_commit_batch()`,
`T.warpgroup_wait(num_mma)`, `T.wait_wgmma(id)`.
Lane/warp index
- `T.get_lane_idx(warp_size=None)`: Lane id in warp.
- `T.get_warp_idx_sync(warp_size=None)`: Canonical warp id (sync).
- `T.get_warp_idx(warp_size=None)`: Canonical warp id (no sync).
- `T.get_warp_group_idx(warp_size=None, warps_per_group=None)`: Group id.
Register control
- `T.set_max_nreg(reg_count, is_inc)`, `T.inc_max_nreg(n)`, `T.dec_max_nreg(n)`.
- `T.annotate_producer_reg_dealloc(n=24)`, `T.annotate_consumer_reg_alloc(n=240)`.
- `T.no_set_max_nreg()`, `T.disable_warp_group_reg_alloc()`.
## Notes on Dtypes
Dtypes accept three equivalent forms:
- String: `'float32'`
- TileLang dtype: `T.float32`
- Framework dtype: `torch.float32`
All are normalized internally. See Type System for details.
# Language Basics
This page introduces the core TileLang (tile‑lang) DSL that you’ll use to write
high‑performance kernels. It focuses on how to define a kernel, express
iteration, move data across memory scopes, and run it with JIT.
The examples use the conventional aliases:
```python
import tilelang
import tilelang.language as T
from tilelang import jit
```
## 1. Defining a Kernel with `@T.prim_func`
TileLang kernels are TIR (TVM IR) functions produced by the `@T.prim_func`
decorator. Arguments are annotated with shapes and dtypes via `T.Tensor` or
`T.Buffer`.
Note on dtypes
- You can pass dtypes as a string (e.g., 'float32'), a TileLang dtype (e.g., `T.float32`),
or a framework dtype (e.g., `torch.float32`). TileLang normalizes all of these.
See Type System for details.
```python
@T.prim_func
def add_kernel(
A: T.Tensor((N,), dtype), # dtype could be 'float32' | T.float32 | torch.float32
B: T.Tensor((N,), dtype),
C: T.Tensor((N,), dtype),
):
... # kernel body
```
- Shapes may be concrete integers or symbolic. For symbolic, you can pass
Python ints through the outer `@jit` wrapper (shown below), or annotate with
`T.dyn` when you want a named symbolic dimension.
```python
# Named symbolic dimension (optional)
K = T.dyn['K']
@T.prim_func
def uses_dyn(A: T.Tensor((K,), 'float32')):
...
```
### Dynamic symbolic dimensions: two ways
TileLang supports two complementary ways to introduce symbolic (dynamic) dims:
- Type-level annotations via `T.dyn[...]` (recommended for function signatures)
- Use in `T.Tensor((T.dyn['K'], ...), dtype)` or bind once then reuse (as above).
- Inside the kernel body, prefer reading from the buffer’s shape, e.g. `M = A.shape[0]`.
- Term-level variables via `T.dynamic(name, dtype)`
- Creates a TIR `tir.Var` you can use directly in expressions/loops.
- Handy when you need to reference the dimension symbol in the body.
```python
# 1) Annotation-only symbol; read the bound size via shape
K = T.dyn['K'] # dtype defaults to int32
@T.prim_func
def foo(A: T.Tensor((K,), 'float32')):
N = A.shape[0]
for i in T.serial(N):
...
# 2) Explicit Var symbol usable in the body
K = T.dynamic('K', 'int32') # or T.dynamic('K') defaults to int32
@T.prim_func
def bar(A: T.Tensor((K,), 'float32')):
for i in T.serial(K):
...
```
Notes
- `T.symbolic(name, dtype)` is a deprecated alias of `T.dynamic`; prefer `T.dynamic`.
- Under `@jit`, concrete sizes come from the actual tensor arguments at the first call.
- Symbols in annotations do not need to be separate kernel arguments; TileLang binds them from argument shapes.
## 2. Launching Work with `T.Kernel`
`with T.Kernel(...)` declares a launch context and creates block/thread
bindings. For GPU backends, specify a grid and threads per block.
```python
with T.Kernel(grid_x, grid_y, threads=128) as (bx, by):
... # bx/by are blockIdx.x/y
```
You rarely need raw thread indices; most kernels use structured loops
(`T.serial`, `T.unroll`, `T.Parallel`, `T.Pipelined`) inside a `T.Kernel`.
## 3. Loops and Control Flow
Core loop constructs map to familiar hardware patterns:
- `T.serial(start, stop[, step])`: plain for‑loop
- `T.unroll(start, stop[, step])`: unrolled loop
- `T.Parallel(ext0, ext1, ...)`: nested parallel loops (elementwise‑friendly)
- `T.Pipelined(iters, num_stages=N)`: software pipelining for producer/consumer
```python
for i in T.serial(N):
...
for i, j in T.Parallel(M, N):
C[i, j] = A[i, j] + B[i, j]
for k in T.Pipelined(T.ceildiv(K, BK), num_stages=3):
# overlap copy/compute across stages
...
```
Conditionals use standard Python `if`/`else`. Guard edges with predicates when
tile sizes do not divide problem sizes evenly.
## 4. Memory Scopes and Allocation
TileLang exposes key software‑managed scopes:
- Global: device memory (default for `T.Tensor` arguments)
- Shared: on‑chip, block‑visible (`T.alloc_shared(shape, dtype)`)
- Fragment and scalars: per‑thread fragments and scalar vars but in Shared View
(`T.alloc_fragment`, `T.alloc_var`)
```python
A_shared = T.alloc_shared((BM, BK), 'float16')
B_shared = T.alloc_shared((BK, BN), 'float16')
C_local = T.alloc_fragment((BM, BN), 'float32')
T.clear(C_local) # zero accumulators
```
## 5. Moving Data: `T.copy`
Use `T.copy(src, dst)` to move tiles between scopes. It accepts buffers,
buffer regions, or buffer loads; extents are inferred or can be broadcast.
```python
# Global -> Shared (tile copy), extents inferred from dst
T.copy(A[by * BM, ko * BK], A_shared)
T.copy(B[ko * BK, bx * BN], B_shared)
# Fragment -> Global (store back)
T.copy(C_local, C[by * BM, bx * BN])
```
`T.copy` performs coalescing and scope‑specific lowering during compilation.
## 6. A Minimal End‑to‑End Example (Vector Add)
```python
import tilelang
import tilelang.language as T
from tilelang import jit
@jit # infers target from tensors at first call
def add(N: int, block: int = 256, dtype: str = 'float32'):
@T.prim_func
def add_kernel(
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
C: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block), threads=block) as bx:
for i in T.Parallel(block):
gi = bx * block + i
# Optional — LegalizeSafeMemoryAccess inserts a guard when an access may be OOB
C[gi] = A[gi] + B[gi]
return add_kernel
# Host side (PyTorch shown; NumPy/DLPack also supported)
import torch
N = 1 << 20
A = torch.randn(N, device='cuda', dtype=torch.float32)
B = torch.randn(N, device='cuda', dtype=torch.float32)
C = torch.empty(N, device='cuda', dtype=torch.float32)
kernel = add(N)
kernel(A, B, C) # runs on GPU
torch.testing.assert_close(C, A + B)
```
Notes
- The `@jit` wrapper returns a callable kernel after the first compilation.
- You can pass compile‑time tunables (tile sizes, dtypes) through the outer
Python function and bake them into the generated TIR.
## 7. Tiled GEMM Skeleton
Below is a minimal pattern for a tiled GEMM using shared memory staging and a
fragment accumulator. It mirrors the quickstart style found in the repository.
```python
@T.prim_func
def gemm(
A: T.Tensor((M, K), 'float16'),
B: T.Tensor((K, N), 'float16'),
C: T.Tensor((M, N), 'float16'),
):
with T.Kernel(T.ceildiv(N, BN), T.ceildiv(M, BM), threads=128) as (bx, by):
A_s = T.alloc_shared((BM, BK), 'float16')
B_s = T.alloc_shared((BK, BN), 'float16')
C_f = T.alloc_fragment((BM, BN), 'float32')
T.clear(C_f)
for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3):
T.copy(A[by * BM, ko * BK], A_s)
T.copy(B[ko * BK, bx * BN], B_s)
T.gemm(A_s, B_s, C_f) # lowered to tensor‑core/ISA specific kernels
T.copy(C_f, C[by * BM, bx * BN])
```
## 8. Debugging and Printing
Use `T.print` inside a kernel for quick introspection. TileLang emits printing
from a single thread for shared/fragment scopes to avoid floods.
```python
T.print(C_f, msg='accumulator:')
T.print(A_s, msg='A tile:')
T.print(C[0], msg='C[0] = ')
```
## 9. Where to Go Next
- Control flow details: see Programming Guides → Control Flow
- Memory topics: see Programming Guides → (removed cache/layout); basics are covered inline
- Autotuning tile sizes and mappings: Programming Guides → Autotuning
- Operator examples (GEMM, GEMV, attention): see Deep Learning Operators
# Programming Guides Overview
This section provides a practical guide to writing high‑performance kernels with Tile Language (tile‑lang).
It mirrors the structure of a similar guide in another project and adapts it to tile‑lang concepts and APIs.
- Audience: Developers implementing custom GPU/CPU kernels with tile‑lang
- Prereqs: Basic Python, NumPy/Tensor concepts, and familiarity with GPU programming notions
- Scope: Language basics, control flow, instructions, autotuning, and type system
## What You’ll Learn
- How to structure kernels with TileLang’s core DSL constructs
- How to move data across global/shared/fragment and pipeline compute
- How to apply autotuning to tile sizes and schedules
- How to specify and work with dtypes in kernels
## Suggested Reading Order
1. Language Basics
2. Control Flow
3. Instructions
4. Autotuning
5. Type System
## Related Docs
- Tutorials: see existing guides in `tutorials/`
- Operators: examples in `deeplearning_operators/`
> NOTE: This is a draft scaffold. Fill in code snippets and benchmarks as APIs evolve.
# Type System
This page lists the data types supported by TileLang and how to specify them in
kernels. For full details and the authoritative list, see the API Reference
(`autoapi/tilelang/index`) and `tilelang.language.v2.dtypes`.
How to specify dtypes
- Use any of the following forms; TileLang normalizes them internally:
- String: `'float32'`, `'int8'`, `'bfloat16'`, ...
- TileLang dtype object: `T.float32`, `T.int8`, `T.bfloat16`, ...
- Framework dtype: `torch.float32`, `torch.int8`, `torch.bfloat16`, ...
Common scalar types
- Boolean: `bool`
- Signed integers: `int8`, `int16`, `int32`, `int64`
- Unsigned integers: `uint8`, `uint16`, `uint32`, `uint64`
- Floating‑point: `float16` (half), `bfloat16`, `float32`, `float64`
Float8 and low‑precision families
- Float8: `float8_e3m4`, `float8_e4m3`, `float8_e4m3b11fnuz`, `float8_e4m3fn`,
`float8_e4m3fnuz`, `float8_e5m2`, `float8_e5m2fnuz`, `float8_e8m0fnu`
- Float6: `float6_e2m3fn`, `float6_e3m2fn`
- Float4: `float4_e2m1fn`
Vectorized element types (SIMD packs)
- For many base types, vector‑packed variants are available by lane count:
`x2`, `x4`, `x8`, `x16`, `x32`, `x64`.
- Examples:
- Integers: `int8x2`, `int8x4`, ..., `int32x2`, `int32x4`, ...
- Unsigned: `uint8x2`, `uint8x4`, ...
- Floats: `float16x2`, `float16x4`, `float32x2`, `float32x4`, ...
- Float8/6/4 families also provide `x2/x4/x8/x16/x32/x64` where applicable,
e.g., `float8_e4m3x2`, `float8_e4m3x4`, `float6_e2m3fnx8`, `float4_e2m1fnx16`.
Notes
- Availability of certain low‑precision formats (float8/6/4) depends on target
architecture and backend support.
- Choose accumulation dtypes explicitly for mixed‑precision compute (e.g.,
GEMM with `float16` inputs and `float32` accumulators).
- The complete, up‑to‑date list is exposed in
`tilelang.language.v2.dtypes` and rendered in the API Reference.
fastapi
pydantic
sphinx
sphinx-reredirects
sphinx-tabs
sphinx-toolbox
sphinxcontrib-napoleon
sphinxcontrib_httpdomain
furo
uvicorn
myst-parser
sphinx-autoapi == 3.6.0
astroid < 4
cancelled
hsa
ist
LOD
nd
NotIn
offen
te
Auto-Tuning Techniques for Performance Optimization
===================================================
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/yyttt6">yyttt6</a>
</div>
## Overview
Auto-tuning a Tile Language program involves three main steps:
1. Implement the target program using Tile Language with reserved optimization parameters
2. ​Provide candidate configurations through manual search or [auto-generation using Carver](#using-carver-to-auto-generate-candidate-configurations)
3. Parallel compile and benchmark candidate configurations to identify the best performance
## Matrix Multiplication Example
The following example demonstrates auto-tuning matrix multiplication. Code has been simplified for readability - see `examples/gemm/example_gemm.py` for complete implementation.
### Step 1: Implement with Reserved Parameters
Users can implement matrix multiplication in Tile Language while reserving parameters for optimization:
```python
# Reserved parameters for optimization
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
enable_rasteration=None,
):
dtype = "float16"
accum_dtype = "float"
# Matrix multiplication implementation
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((M, N), dtype),
):
# ...existing code...
return main
```
### Step 2: Generate Candidate Configurations
Manually define configurations or use combinatorial generation:
```python
configs = [
{
"block_M": 128,
"block_N": 128,
"block_K": 128,
"num_stages": 3,
"thread_num": 128,
"enable_rasteration": True
},
{
"block_M": 32,
"block_N": 32,
"block_K": 32,
"num_stages": 0,
"thread_num": 32,
"enable_rasteration": False
},
# ...additional configurations...
]
```
It can also be given by combinatorial traversal of different parameters
```python
import itertools
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5]
} for c in _configs
]
```
### Step 3: Compile and Benchmark
Configure JIT compilation and benchmarking settings:
```python
autotuner = AutoTuner.from_kernel(
kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
out_idx=[-1],
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=False,
target="auto",
)
result = autotuner.run(warmup=3, rep=20)
out_c = result.kernel(a, b)
```
The result object contains optimized kernel implementation which can be used by users directly
## Using Carver to Auto-Generate Candidate Configurations
Carver is a lightweight framework for generating and ranking tile configurations (also known as tiling strategies, blocking schemes, or scheduling hints) for common GPU, CPU, and accelerator backends. It helps you explore efficient mappings of loops for operations such as matrix multiplication, elementwise transforms, and other reduction-oriented kernels.
or common operators, Carver provides pre-built templates (e.g., `MatmulTemplate`):
```python
# Configure Matmul template
arch = CUDA("cuda")
carve_template = MatmulTemplate(
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float",
).with_arch(arch)
# Generate top-k optimization hints (topk=10 recommended)
roller_hints = carve_template.recommend_hints(topk=10)
# Configure candidate parameters
for hint in roller_hints:
# ...existing code...
config["block_M"] = block_m
config["block_N"] = block_n
config["block_K"] = hint.rstep[0]
config["num_stages"] = hint.pipeline_stage
config["thread_num"] = block_rows * block_cols * 32
config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
```
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment