`tl.InjectFenceProxy` is a TIR-level transform that keeps the GPU proxy state consistent on NVIDIA Hopper (SM90+) by inserting `fence.proxy.async` instructions when control flow switches from generic memory operations to asynchronous proxy operations.
## Why Fences Are Needed
Hopper separates memory instructions into generic and asynchronous proxy paths. When an asynchronous instruction (for example, `cp.async` or `tma.load`) issues after generic traffic (like `ldmatrix` or plain buffer stores), the hardware requires a `fence.proxy.async` to guarantee ordering. Missing fences can lead to race conditions or undefined behavior.
## What the Pass Does
- Walks every statement in the `PrimFunc`, tracking whether it behaves as a **generic**, **async**, or **neutral** proxy (neutral statements, such as an explicit fence, reset the state).
- Automatically lowers `tma_store` intrinsics into the required `arrive`/`wait` handshake so that TMA stores participate correctly in synchronization.
- Injects an explicit `fence.proxy.async` whenever a generic statement is followed by an async statement without an intervening neutral barrier.
The pass is conservative: unknown extern calls are treated as async so that the fence is inserted rather than accidentally omitted.
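Schematically, the tracker reacts to sequences like the following (an illustrative sketch, not exact TIR; `ptx_cp_async` stands in for any async-proxy intrinsic):
```python
smem[tx] = A[tx]        # generic proxy: plain shared-memory store
T.ptx_cp_async(...)     # async proxy: cp.async-style copy issued right after
```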
The proxy tracker scans the sequence from left to right. The moment it detects a transition from generic to async (between the store and `cp.async` above), it synthesizes a `fence.proxy.async` to reset the hardware proxy state before the async path runs.
## Coverage of Intrinsics
The tracker understands the TileLang intrinsics for TMA load/store, shared-memory MMA (`wgmma`), and TVM/PTX async copy intrinsics (`cp.async` variants). Generic operations currently include `ldmatrix`, `stmatrix`, and descriptor initialization. Other IR nodes (loops, blocks, attributes) receive a proxy kind derived from their bodies so that the analysis survives structured control flow.
## Usage
The pass is part of the default TileLang lowering pipeline. To apply it manually:
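A hedged sketch, assuming the pass is exposed as `tilelang.transform.InjectFenceProxy` like other TileLang passes:
```python
import tilelang

mod = tilelang.transform.InjectFenceProxy()(mod)  # mod: an IRModule containing the PrimFunc
```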
After the pass runs, the only change in the IR is the inserted `fence.proxy.async` between the generic descriptor setup / shared-memory write and the async `wgmma`. In larger kernels the pass performs the same operation across nested blocks, loops, and conditional branches.
## Extending the Pass
If you introduce a new intrinsic that behaves like an async proxy, add it to `IsAsyncIntrinsic` in `src/transform/inject_fence_proxy.cc`. Likewise, extend `IsKnownGeneric` for additional generic operations. When adding new neutral barriers, make sure they set the proxy kind to `kNeutral` so the state resets correctly.
This document explains how `LetStmt` inlining works in TileLang's simplification pipeline, which is an important optimization that affects code generation and performance.
## Overview
A `LetStmt` (Let Statement) is a temporary variable binding in the IR (Intermediate Representation). During compilation, TileLang's simplifier may choose to inline these temporary variables to simplify the code. TileLang also provides a standalone `LetInline` pass that performs eager substitution before the main legalization pipeline. However, not all `LetStmt` nodes can be safely inlined.
## When Does LetStmt Get Inlined?
The inlining logic is implemented in `src/transform/simplify.cc`. A `LetStmt` will be inlined if **both** of the following conditions are met:
### 1. The value satisfies `CanInlineLetStmt`
The `CanInlineLetStmt` helper returns `true` when:
- **The value is a constant** (`is_const_number(op->value)` returns true)
- **The value is a variable** (`op->value.as<VarNode>()` returns a node)
- **The value is an integer expression without side effects**:
- The value has `int` dtype
- The side effect level is `kPure` or lower (no observable side effects)
```cpp
bool CanInlineLetStmt(const LetStmtNode* op) {
  if (is_const_number(op->value)) return true;
  if (op->value.as<VarNode>()) return true;
  // Won't face the deep expression explosion problem as in Let expression.
  // Attempt to inline as much as possible if the value is an integer type
  // (i.e., can be used as an index) and has no side effects.
  bool can_inline = op->value.dtype().is_int() &&
                    SideEffect(op->value) <= CallEffectKind::kPure;
  return can_inline;
}
```
### 2. The variable is NOT used in buffer definitions
Even if `CanInlineLetStmt` returns true, the variable will **not** be inlined if it's used in a buffer's definition (shape, strides, elem_offset, or data fields).
This protection exists because:
- Buffer definitions are not updated during the simplification pass
- If a variable used in a buffer definition is inlined, later references to that buffer would fail to find the variable definition
- This would cause compilation errors or incorrect behavior
The mutator checks this before dropping the binding. As a concrete illustration, consider a variable `stride` bound by a `LetStmt` and used in a buffer's `strides` field (a schematic sketch, with hypothetical names, follows):
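```python
# Schematic TIR-like pseudocode (illustrative, not exact TVMScript):
stride = M * 16                                 # LetStmt binding
buffer_a = T.decl_buffer((M, 16), "float32",
                         strides=[stride, 1])   # `stride` appears in the buffer definition
# ... body reads and writes buffer_a ...
```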
- `stride` satisfies `CanInlineLetStmt` (it's an int expression with no side effects)
- However, `stride` is used in `buffer_a`'s `strides` field
- If we inline it, the buffer definition becomes `strides=[M*16, 1]`
- But the Buffer object's fields are not updated during simplification
- Later code accessing `buffer_a` would fail to find the `stride` variable
Therefore, `stride` is added to `used_in_buffer_def_` and will **not** be inlined.
## How Variables Are Collected
The `CollectVarsUsedInBufferDefinition` helper traverses all `BufferLoad` and `BufferStore` nodes and collects variables used in their buffer definitions:
```cpp
void VisitBuffer(const Buffer& buf) {
  // Collect variables that should remain defined
  VarUseDefAnalyzer usage(Array<Var>{});
  usage(buf->data);
  for (const auto& dim : buf->shape) {
    usage(dim);
  }
  for (const auto& dim : buf->strides) {
    usage(dim);
  }
  usage(buf->elem_offset);
  // Track for use in LetStmtNode mutator
  for (const auto& var : usage.undefined_) {
    used_in_buffer_def_.insert(var.get());
  }
}
```
## Practical Example: Temporary Variable Issue
Consider this TileLang code:
```python
for i in T.Parallel(block_N):
    idx = bx * block_N + i
    tmp = T.max(A[idx], 1)
    B[idx] = tmp / 2
    A[idx] = tmp * 2
```
In this case:
- `tmp` is an integer-like temporary variable
- It satisfies `CanInlineLetStmt` (pure int expression)
- It's **not** used in any buffer definition
- Therefore, `tmp` **will be inlined**
This means the IR becomes:
```python
for i in T.Parallel(block_N):
    idx = bx * block_N + i
    B[idx] = T.max(A[idx], 1) / 2
    A[idx] = T.max(A[idx], 1) * 2
```
If this causes issues (e.g., `A[idx]` being read twice with different values due to the first write), it indicates a potential problem with the inlining heuristic or the code pattern.
## Controlling Let Inlining via Pass Config
TileLang exposes an explicit pass configuration key, `tilelang.PassConfigKey.TL_FORCE_LET_INLINE` (`"tl.force_let_inline"`), that allows users to force the eager `LetInline` pass to run before the legalization pipeline begins. When enabled, the pipeline invokes `tilelang.transform.LetInline()` at the start of `LowerAndLegalize` (see `tilelang/engine/phase.py`). This knob is useful when debugging LetStmt-related issues or when deterministic inlining behavior is desired across different environments.
If the flag is left unset (the default), the eager pass is only applied when downstream transforms opt in (for example, by calling `_Simplify(..., inline_let=True)` inside Tile operators). The guard in `tilelang/engine/phase.py` ensures the eager pass is only triggered when the user explicitly requests it.
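A hedged usage sketch (the kernel function name is hypothetical; `pass_configs` follows `tilelang.compile`'s configuration mechanism):
```python
import tilelang
from tilelang import PassConfigKey

kernel = tilelang.compile(
    my_prim_func,  # a user-defined T.prim_func
    pass_configs={PassConfigKey.TL_FORCE_LET_INLINE: True},
)
```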
## Summary
The LetStmt inlining mechanism is a **conservative optimization** that:
1. Aggressively inlines simple, pure integer expressions to simplify the IR
2. Protects variables used in buffer definitions to avoid breaking buffer access
3. Helps reduce IR complexity and improve code generation
4. Can be forced through `TL_FORCE_LET_INLINE` when deterministic eager inlining is required
Understanding when inlining happens is crucial for:
- Debugging compilation issues
- Understanding generated code
- Writing efficient TileLang programs
- Identifying potential optimization opportunities or bugs
## Related Files
-`src/transform/simplify.cc`: Main Simplify implementation
This page explains the host-side checks that TileLang automatically inserts into the generated host stub for kernels. When you pass `torch.Tensor` or any DLPack-compatible object to a TileLang kernel, the host stub validates argument count, pointer kinds, dtype, shape, strides, device, and more — so you don’t need to handwrite Python checks. This keeps the ABI stable and significantly reduces Python overhead compared to doing equivalent checks in Python or via pybind.
## Why Host-Side Checks
- ABI stability: the entry is based on TVM FFI + DLPack, consistently accepting tensors and scalars.
- Lower overhead: shifting checks from Python into C reduces interpreter/property-access costs; the call overhead is lower than pybind-based approaches.
- Focused error reporting: assertions are raised close to the call site with precise “which field failed” messages.
## How To Inspect Host Source
You can inspect the auto-generated host source (with all checks and the final device-kernel call) for debugging:
```python
print(matmul_relu_kernel.get_host_source())
```
---
## What The Host Checks
### 1) Argument count and pointer kind
- `num_args` must match the number of formal parameters; otherwise the kernel returns `-1` with an error message.
- Each argument’s FFI type must be a pointer kind (for DLTensor/handle) or a valid scalar type; otherwise you’ll see errors like `Expect arg[i] to be pointer` or a scalar type error.
### 2) Tensor checks (per tensor, after nullability decision)
- Nullability
  - If the tensor is “statically reachable/used” by the function body, the handle must be non-NULL; otherwise: `xxx is expected to have non-NULL pointer`.
  - If an input tensor is not used by the function (statically unreachable), NULL is allowed; other field checks are executed only when `handle != NULL`.
- Rank (`ndim`)
  - Runtime `ndim` must equal the compile-time rank.
- Data type (`dtype`)
  - Match the triple `(code, bits, lanes)` with tolerance:
    - `bool`: accept `int8/uint8` with `bits=8` (same lanes), `kDLBool(code=6, bits=1 or 8)`, and any `bitwidth=1` (lanes must match).
    - For packed-bit dtypes (e.g., `Int(1)`, `Int(4)`, `UInt(4)`), strict dtype checking is skipped.
- Shape
  - Each runtime dimension is bound to the compile-time shape (constants or symbols) and checked for consistency.
  - Linear equations among symbolic dims can be solved on the fly (when there’s only one unknown at a given check point), enabling cross-tensor constraints.
- Strides
  - If `buffer_type = AutoBroadcast`: allow `strides == NULL` and derive strides from `shape`. If explicit `strides` is present, bind to compile-time constraints and check for equality.
  - Otherwise: check per-dimension; if `strides == NULL`, derive from `shape` and compare (e.g., contiguous: `strides[-1] == 1`, `strides[-2] == shape[-1]`).
- `byte_offset`
  - Must be 0 (non-zero raises an error) to keep addressing simple and aligned.
- Device info
  - Assert `device_type == target backend` (CUDA/ROCM/Metal/OneAPI/WebGPU/CPU, etc.). Error messages include a DLPack code legend.
  - When multiple tensors participate, assert that `device_id` matches across them.
- Data pointer
  - Must be non-NULL when the tensor is required to be non-null by the nullability rule.
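For instance, the per-dimension stride rule is what rejects transposed PyTorch views (the kernel call is hypothetical):
```python
import torch

A = torch.randn(64, 128, device="cuda", dtype=torch.float16)
At = A.t()                 # shape (128, 64), strides (1, 128): not contiguous
# kernel(At)               # would fail the per-dimension stride check
# kernel(At.contiguous())  # passes: strides become (64, 1)
```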
### 3) Scalar checks
- `T.int*` family: require integer; error: `Expect arg[i] to be int`.
- `T.bool`: require boolean; error: `Expect arg[i] to be boolean`.
---
## Shapes and Symbolic Equations: Linear Solving
When shapes are symbolic, the host binds and (when possible) solves linear relations at runtime (only one unknown per check point). Example:
```python
@T.prim_func
def main(
    A: T.Tensor((m,), dtype),
    B: T.Tensor((m + n,), dtype),
    C: T.Tensor((n * k,), dtype),
):
    ...
```
This enables enforcing cross-tensor relationships like `len(B) == m + n` and `len(C) == n * k` at runtime.
---
## Nullability Rules and Examples
Which tensors may be NULL?
- Rule: If an input tensor is not used by the function under static analysis (i.e., the access is statically unreachable), it is considered Nullable; otherwise it must be non-NULL.
- Examples:
1) Must be non-NULL (used)
```python
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
    A[0] = 1
```
Passing `None` raises: `main.A_handle is expected to have non-NULL pointer`.
Other common failures follow the same pattern:
- Device mismatch
  - Trigger: participating tensors have differing `device_id`s
  - Error: `Argument <kernel>.<name>.device_id has an unsatisfied constraint: ... == ...`
- NULL data pointer
  - Trigger: tensor required to be non-null has a NULL data pointer
  - Error: `<kernel>.<name> is expected to have non-NULL data pointer, but got NULL`
- Scalar type mismatch
  - Trigger: passing float to `T.int32`, or non-boolean to `T.bool`
  - Error: `<kernel>: Expect arg[i] to be int/boolean`
---
## Troubleshooting Tips
- Print the host source: `print(fn.get_host_source())` to see the exact assertion and expected vs. actual fields.
- Fix strides: call `.contiguous()` for non-contiguous tensors, or avoid generating transposed/sliced layouts that break assumptions.
- Align devices: ensure all participating tensors share the same `device_type` and `device_id`.
- Align dtype: use `.to(<dtype>)` or construct tensors with the correct dtype; pay attention to `float8` and `bool` tolerance.
- Dynamic shapes: ensure cross-tensor linear relations can be uniquely determined at the check point (only one unknown at a time).
---
## FAQ
- Can I disable the checks?
- Not recommended and usually not supported. Checks are done on the host to preserve ABI stability and fail early close to the device call.
- Is the overhead noticeable?
- The checks are lightweight (branches and field reads), so they are considerably cheaper than equivalent Python-side checks; the dominating cost remains the Python→C boundary.
### 8. Device mismatch (multi-GPU)
Trigger: participating tensors placed on different GPUs (e.g., `cuda:0` and `cuda:1`).
Expected: `Argument <kernel>.B_handle.device_id has an unsatisfied constraint: ... == ...`.
Fix: place all tensors on the same GPU (e.g., `cuda:0`).
### 9. NULL data pointer (advanced)
This usually comes from hand-constructed DLTensor/NDArray, or external frameworks passing unallocated/freed storage. Regular `torch.Tensor` allocations rarely hit this.
Expected: `<kernel>.<name> is expected to have non-NULL data pointer, but got NULL`.
Fix: ensure valid underlying storage; in PyTorch scenarios, avoid constructing tensors from invalid external handles.
### 10. Scalar type mismatch (int / bool)
```python
import tilelang.language as T

@T.prim_func
def scalar_check(x: T.int32, flag: T.bool()):
    T.evaluate(0)

scalar_check(1.0, True)  # x is float -> Expect arg[0] to be int
scalar_check(1, 2.5)     # flag is float -> Expect arg[1] to be boolean
```
TileLang is a user-friendly AI programming language that significantly lowers the barrier to kernel programming, helping users quickly build customized operators. However, users still need to master certain programming techniques to better leverage TileLang's powerful capabilities. Here, we'll use MLA as an example to demonstrate how to write high-performance kernels with TileLang.
## Introduction to MLA
DeepSeek's MLA (Multi-Head Latent Attention) is a novel attention mechanism known for its hardware efficiency and significant improvements in model inference speed. Several deep learning compilers (such as [Triton](https://github.com/triton-lang/triton)) and libraries (such as [FlashInfer](https://github.com/flashinfer-ai/flashinfer)) have developed their own implementations of MLA. In February 2025, [FlashMLA](https://github.com/deepseek-ai/FlashMLA) was open-sourced on GitHub. FlashMLA utilizes [CUTLASS](https://github.com/NVIDIA/cutlass) templates and incorporates optimization techniques from [FlashAttention](https://github.com/Dao-AILab/flash-attention), achieving impressive performance.
## Benchmark Results
We benchmarked the performance of FlashMLA, TileLang, Torch, Triton, and FlashInfer under batch sizes of 64 and 128, with float16 data type, as shown in the figures below.
As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton.
Notably, **TileLang accomplishes this with just around 80 lines of Python code**, demonstrating its exceptional ease of use and efficiency. Let's dive in and see how TileLang achieves this.
## Implementation
First, let's review the core computation logic of traditional FlashAttention:
```python
# acc_s: [block_M, block_N]
# scores_max: [block_M]
# scores_scale: [block_M]
# acc_o: [block_M, dim]
for i in range(loop_range):
    acc_s = Q @ K[i]
    scores_max_prev = scores_max
    scores_max = max(acc_s, dim=1)
    scores_scale = exp(scores_max_prev - scores_max)
    acc_o *= scores_scale
    acc_s = exp(acc_s - scores_max)
    acc_o += acc_s @ V[i]
    ...
```
Here, `acc_s` represents the `Q @ K` result in each iteration with dimensions `[block_M, block_N]`, while `acc_o` represents the current iteration's output with dimensions `[block_M, dim]`. Both `acc_s` and `acc_o` need to be stored in registers to reduce latency.
Compared to traditional attention operators like MHA (Multi-Headed Attention) or GQA (Grouped Query Attention), a major challenge in optimizing MLA is its large head dimensions - `query` and `key` have head dimensions of 576 (512 + 64), while `value` has a head dimension of 512. This raises a significant issue: `acc_o` becomes too large, and with insufficient threads (e.g., 128 threads), register spilling occurs, severely impacting performance.
This raises the question of how to partition the matrix multiplication operation. On the Hopper architecture, most computation kernels use [`wgmma.mma_async`](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) instructions for optimal performance. The `wgmma.mma_async` instruction organizes 4 warps (128 threads) into a warpgroup for collective MMA operations. However, `wgmma.mma_async` instructions require a minimum M dimension of 64. This means each warpgroup's minimum M dimension can only be reduced to 64, but a tile size of 64*512 is too large for a single warpgroup, leading to register spilling.
Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right part of `acc_o` respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input.
Our solution is to have each warpgroup compute half of `acc_s` during `Q @ K` computation, then obtain the other half computed by the other warpgroup through shared memory.
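A hedged sketch of this exchange (buffer names are illustrative, and the exact enum path for the policy may differ; TileLang's layout inference determines the per-warpgroup shapes described in the next section):
```python
# Each warpgroup computes its half of acc_s; the policy splits the result column-wise.
T.gemm(Q_shared, K_shared, acc_s_0, transpose_B=True,
       policy=T.GemmWarpPolicy.FullCol)
# Publish this warpgroup's half, then read back the full [block_M, block_N] tile.
T.copy(acc_s_0, S_shared)
T.copy(S_shared, acc_s)
```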
### Layout Inference
While the above process may seem complex, don't worry: TileLang handles all these intricacies for you.
Figure 3 and Figure 4 illustrate the frontend TileLang script and its corresponding execution plan for MLA. Here, `T.gemm` represents matrix multiplication operations, `transpose_B=True` indicates transposition of matrix B, and `policy=FullCol` specifies that each warpgroup computes one column (i.e., the result matrix is split along the vertical dimension). `T.copy` represents buffer-to-buffer copying operations.
The mapping from TileLang frontend code to execution plan is accomplished through Layout Inference. Layout inference is a core optimization technique in TileLang. It automatically deduces the required buffer shapes and optimal layouts based on Tile-Operators (like `T.gemm`, `T.copy`, etc.), then generates the corresponding code. Here, we demonstrate a concrete example of buffer shape inference in MLA.
For instance, when computing `Q @ K`, TileLang infers that each warpgroup's `acc_s_0` shape should be `[blockM, blockN / 2]` based on the `policy=FullCol` annotation in `T.gemm`. Since this is followed by an `acc_s @ V` operation with `policy=FullCol`, which requires each warpgroup to have the complete `acc_s` result, TileLang deduces that `acc_s`'s shape at this point should be `[blockM, blockN]`. Consequently, TileLang can continue the inference process forward, determining that both `S_shared` and `acc_s` in `T.copy(S_shared, acc_s)` should have shapes of `[blockM, blockN]`.
It's worth noting that our scheduling approach differs from FlashMLA's implementation strategy. In FlashMLA, `Q @ K` is assigned to a single warpgroup, while the `acc_o` partitioning scheme remains consistent with ours. Nevertheless, our scheduling approach still achieves comparable performance.
### Threadblock Swizzling
Threadblock swizzling is a common performance optimization technique in GPU kernel optimization. In GPU architecture, the L2 cache is a high-speed cache shared among multiple SMs (Streaming Multiprocessors). Threadblock swizzling optimizes data access patterns by remapping the scheduling order of threadblocks, thereby improving L2 cache hit rates. Traditional scheduling typically executes threadblocks in the natural order of the grid, which can lead to non-contiguous data access patterns between adjacent threadblocks, resulting in inefficient utilization of cached data. The swizzle technique employs mathematical mapping methods (such as diagonal or interleaved mapping) to adjust the execution order of threadblocks, ensuring that consecutively scheduled threadblocks access adjacent or overlapping data regions.
In TileLang, threadblock swizzling optimization can be implemented with just a single line of Python code:
```python
T.use_swizzle(panel_size: int, order: str = "row")
```
Here, `panel_size` specifies the width of the swizzled threadblock group, and `order` determines the swizzling pattern, which can be either "row" or "col".
### Shared Memory Swizzling
In CUDA programming, shared memory is divided into multiple memory banks, with each bank capable of servicing one thread request per clock cycle in parallel. Bank conflicts occur when multiple threads simultaneously access different addresses mapped to the same bank, forcing these accesses to be serialized and degrading performance.
One common strategy to address bank conflicts is shared memory swizzling. This technique rearranges how data is stored in shared memory by remapping addresses that would originally fall into the same bank to different banks, thereby reducing conflicts. For example, XOR operations or other bit manipulations can be incorporated into address calculations to alter the data layout, resulting in more evenly distributed memory accesses across consecutive threads. This approach is particularly crucial for implementing high-performance computing tasks like matrix multiplication and convolution, as it can significantly improve memory access parallelism and overall execution efficiency.
Similarly, TileLang also supports shared memory swizzling. Users only need to add a single line of Python code:
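A hedged sketch (buffer names are illustrative; `make_swizzled_layout` is the helper mentioned below):
```python
T.annotate_layout({
    A_shared: tilelang.layout.make_swizzled_layout(A_shared),
    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
})
```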
Here, `T.annotate_layout` allows users to specify any desired layout for a buffer. For convenience, TileLang provides the `make_swizzled_layout` primitive to automatically generate a swizzled layout.
### Warp-Specialization
The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (Tensor Memory Accelerator), while the remaining warpgroups serve as consumers performing computations. However, this programming pattern is complex, requiring developers to manually manage the execution logic for producers and consumers, including synchronization through the `mbarrier` objects.
In TileLang, users are completely shielded from these implementation details. The frontend script is automatically transformed into a warp-specialized form, where TileLang handles all producer-consumer synchronization automatically, enabling efficient computation.
### Pipeline
Pipelining is a technique that improves memory access efficiency by overlapping memory access with computation. In TileLang, a pipeline can be expressed through the `T.Pipelined` annotation:
```python
T.Pipelined(range: int, stage: int)
```
Here, `range` specifies the loop range being pipelined, and `stage` the number of pipeline stages. Multi-stage pipelining enables overlapping of computation and memory access, which can significantly improve performance for memory-intensive operators. However, setting a higher number of stages consumes more shared memory resources, so the optimal configuration needs to be determined based on specific use cases.
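Typical usage, as in the matmul example elsewhere in these docs (in current releases the stage count is passed as `num_stages`):
```python
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
    T.copy(A[by * block_M, k * block_K], A_shared)
    T.copy(B[k * block_K, bx * block_N], B_shared)
    T.gemm(A_shared, B_shared, C_local)
```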
### Split-KV
We have also implemented Split-KV optimization similar to [FlashDecoding](https://pytorch.org/blog/flash-decoding/). Specifically, when the batch size is small, parallel SM resources cannot be fully utilized due to low parallelism. In such cases, we can split the kv_ctx dimension across multiple SMs for parallel computation and then merge the results.
In our implementation, we have developed both split and combine kernels, allowing users to control the split size through a `num_split` parameter.
## 🚀 On AMD MI300X Accelerators
Following our previous demonstration of [high-performance FlashMLA implementation on NVIDIA Hopper architectures using TileLang](https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_mla/README.md), this work presents an optimized implementation for AMD MI300X accelerators. We examine architectural differences and corresponding optimization strategies between these platforms.
### Architectural Considerations and Optimization Strategies
Key implementation differences between Hopper and MI300X architectures include:
1. **Instruction Set Variations**: The MI300X architecture eliminates the need for explicit Tensor Memory Accelerator (TMA) instructions and warp specialization, which are automatically handled by the compiler on Hopper architectures, so the source code looks identical on both platforms.
2. **Shared Memory Constraints**: With 64KB of shared memory compared to Hopper's 228KB, MI300X implementations require careful memory management. Our optimization strategy includes:
   - Reducing software pipeline stages
   - Register-based caching of Q matrices instead of shared memory (see the sketch after this list)
3. **Tile Size Flexibility**: The absence of WGMMA instructions on MI300X permits more flexible tile size selection, removing the requirement for block_m to be a multiple of 64.
4. **Memory Bank Conflict Swizzling**: MI300X has different memory bank conflict rules than NVIDIA GPUs, so different swizzling strategies are needed. This is also handled automatically by TileLang, with no visible differences in the code.
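A hedged sketch of the register-resident Q tile mentioned in item 2 (buffer names are illustrative):
```python
Q_local = T.alloc_fragment((block_M, dim), dtype)  # registers instead of shared memory
T.copy(Q[bx * block_M, 0], Q_local)
```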
### Performance Evaluation
We conducted comparative performance analysis across multiple frameworks using float16 precision with batch sizes 64 and 128. The experimental results demonstrate:
*Figure 1: Computational throughput comparison across frameworks (batch sizes 64 and 128).*
Notably, TileLang achieves performance parity with hand-optimized assembly kernels (aiter-asm) in most test cases, while significantly outperforming both Triton (1.98×) and PyTorch (3.76×) implementations. This performance is achieved through a concise 80-line Python implementation, demonstrating TileLang's efficiency and programmability advantages.
### Future Optimization Opportunities
1. **Memory Bank Conflict Mitigation**: Current implementations primarily address bank conflicts in NT layouts through TileLang's automatic optimization. Further investigation of swizzling techniques for alternative memory layouts remains an open research direction.
2. **Dimension Parallelization**: For large MLA dimensions (e.g., 576 elements), we propose investigating head dimension partitioning strategies to:
- Reduce shared memory pressure
- Improve compute-to-memory access ratios
- Enhance parallelism through dimension-wise task distribution
:::{warning}
This document is still **experimental** and may be incomplete.
Suggestions and improvements are highly encouraged—please submit a PR!
:::
Elementwise operators are widely used in deep learning and often serve as the first example encountered by those beginning to explore parallel programming. This tutorial will analyze several implementations of the elementwise addition operator using TileLang and compare them with the corresponding CUDA implementation. By the end of this tutorial, you will learn:
1. How to implement an elementwise operator using TileLang.
2. How to compile operators with dynamic shapes.
3. How TileLang addresses boundary-related issues.
4. The similarities and differences between operators implemented in TileLang and those implemented in CUDA/CuTe.
Please note that this tutorial does not delve deeply into the design principles of TileLang. For a broader understanding of TileLang, we recommend consulting the [Overview](../get_started/overview.md).
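To make the discussion concrete, here is a minimal sketch of the 1D elementwise-add kernel described below (all names are ours; `bfloat16` matches the generated code shown later):
```python
import tilelang
import tilelang.language as T

def elementwise_add(N, threads=256, dtype="bfloat16"):
    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N,), dtype),
        C: T.Tensor((N,), dtype),
    ):
        # One block per tile of `threads` elements; bx maps to blockIdx.x.
        with T.Kernel(T.ceildiv(N, threads), threads=threads) as bx:
            for i in T.Parallel(threads):
                C[bx * threads + i] = A[bx * threads + i] + B[bx * threads + i]
    return main
```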
All logic for TileLang kernels must be implemented within the `T.Kernel(...)` scope. In this example, initializing `T.Kernel(...)` requires specifying both the grid size and the number of threads per block. The returned value `bx` corresponds to `blockIdx.x` in CUDA. In the provided implementation, `T.Parallel` is used to process the data tile (of size `1 x threads`) assigned to the block for computation.
Those familiar with CUDA programming might wonder where `threadIdx` fits into this. Note that the code inside `T.Kernel` operates at the **block level**, not the **thread level**. In this example, your focus is solely on defining the block-level logic. During compilation, TileLang automatically maps computations to the corresponding threads and applies further optimizations. The optimized code generated by TileLang may closely align with carefully handcrafted computational logic, as demonstrated in Section 2 with a concrete example. While TileLang also supports thread-level programming semantics, this will be covered in subsequent discussions.
The program can be compiled using the following code:
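A hedged sketch, reusing the `elementwise_add` factory defined above:
```python
func = elementwise_add(1024)                 # fixed shape N = 1024
kernel = tilelang.compile(func, out_idx=-1)  # out_idx marks C as the returned output
```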
Launching the kernel is straightforward, just call it directly like a function:
```python
C = kernel(A, B)
```
The vector add operation can also be extended to two-dimensional cases, where both implementations demonstrate comparable efficiency in practice. Below is an example from the test section that readers can refer to: [example](https://github.com/tile-ai/tilelang/blob/main/testing/python/kernel/test_tilelang_kernel_element_wise_add.py). The code for this kernel is provided below:
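A hedged 2D sketch (the linked test contains the canonical version; tile names are ours):
```python
@T.prim_func
def main(
    A: T.Tensor((M, N), dtype),
    B: T.Tensor((M, N), dtype),
    C: T.Tensor((M, N), dtype),
):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
        for i, j in T.Parallel(block_M, block_N):
            C[by * block_M + i, bx * block_N + j] = (
                A[by * block_M + i, bx * block_N + j] + B[by * block_M + i, bx * block_N + j]
            )
```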
In the compilation process above, a fixed shape was used. However, in practical usage, we often want the kernel to support dynamic shapes. So, how can we compile a kernel in TileLang to handle dynamic shapes? In TileLang, we can replace the target size with a dynamic symbolic value, making the dimension dynamic. The following example illustrates this:
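A hedged sketch (assuming `T.symbolic` for declaring the dynamic dimension, as in other TileLang examples):
```python
N = T.symbolic("N")  # dynamic dimension resolved at kernel launch
kernel = tilelang.compile(elementwise_add(N), out_idx=-1)
```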
The resulting CUDA code for the kernel will include an additional `int N` parameter after the `bfloat16_t* __restrict__ A`, `bfloat16_t* __restrict__ B`, and `bfloat16_t* __restrict__ C` parameters.
### How TileLang addresses boundary-related issues
TileLang automatically incorporates boundary-checking conditions; however, this comes at a cost. These boundary conditions may prevent TileLang from performing more advanced optimizations. I will introduce an example from the next section in advance. The corresponding code is also provided below, but note that it involves the associated CUDA code. Readers are encouraged to first review the next section before returning to this paragraph for a clearer understanding.
For example, suppose we compile the kernel below with `N` set to 2047. Inspecting the generated CUDA code, we can observe that TileLang did not apply optimizations such as vectorization or coalesced memory access; in fact, except for the tail group of data, all other threads could have executed more optimized code.
## Comparison of TileLang, CUDA, and CuTe
For the subsequent examples, this tutorial will use the vector add operation for simplicity and brevity.
Typically, those new to CUDA programming write kernels in which each thread computes a single element (`C[i] = A[i] + B[i]` guarded by a bounds check). This style is evidently inefficient, since common acceleration techniques like coalesced memory access and vectorization are not utilized. However, TileLang code written with similar logic (e.g., loop-based traversal) can be optimized by the compiler into highly efficient implementations, making it more accessible for beginners. Additionally, the final generated code from the compiler remains observable, providing transparency into the optimization process.
The CUDA code generated by TileLang for the compiled kernel can be retrieved using the `kernel.get_kernel_source()` method. Below is the CUDA code produced for the vector addition example from Section 1:
```cu
extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) {
In the code above, TileLang not only automatically maps block-level parallelism to threads but also applies optimizations such as vectorization and coalesced memory access.
While TileLang incorporates various optimizations for the aforementioned case, its behavior may sometimes appear counterintuitive. For example, when targeting 256 threads for task processing, applying vectorization can result in each thread computing 8 data elements—effectively utilizing only 32 active threads. Interestingly, the kernel launch configuration still retains the original allocation of 256 threads.
In such scenarios, explicitly specifying the number of elements computed per thread can help "guide" TileLang's code generation process, leading to implementations that are more closely aligned with the intended design.
Aha, this CUDA code aligns closely with conventional programming practices, making it more familiar and intuitive.
But what happens if we provide additional hints to TileLang? For instance, by explicitly specifying register copies using the `T.copy(...)` operation. The example below demonstrates a vector addition implementation. Unlike the previous examples, this code explicitly loads data into registers before performing computations.
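A hedged sketch of such a kernel (names like `A_register` follow the discussion below; `elems=8` matches the 8-elements-per-thread analysis):
```python
def elementwise_add_regs(N, threads=256, elems=8, dtype="bfloat16"):
    tile = threads * elems
    @T.prim_func
    def main(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N,), dtype),
        C: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, tile), threads=threads) as bx:
            A_register = T.alloc_fragment((tile,), dtype)
            B_register = T.alloc_fragment((tile,), dtype)
            T.copy(A[bx * tile], A_register)  # block-level copy into registers
            T.copy(B[bx * tile], B_register)
            for i in T.Parallel(tile):
                C[bx * tile + i] = A_register[i] + B_register[i]
    return main
```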
In the example above, each thread is responsible for computing 8 elements. The `T.copy(...)` method functions at the block level, and TileLang automatically maps data movement operations to individual threads. This design may resonate more intuitively with CUDA developers. Let us now analyze the CUDA code generated from this implementation.
We observed the emergence of two additional registers, `A_register` and `B_register`. However, during the actual computation, these registers are simply reassigned to `v_` and `v__1`, respectively.
To evaluate complexity, one could implement the same elementwise addition operator using CuTe and compare it with the TileLang version. The corresponding CuTe code is provided below:
This tutorial showcases the implementation of the elementwise addition operator using TileLang, while also comparing various design approaches. TileLang significantly reduces the complexity of CUDA programming, enabling high performance with minimal code. Nevertheless, working with TileLang demands careful attention to specific implementation details. To ensure computational efficiency, it is essential to thoroughly examine the generated CUDA kernels.
---
**Reference:**
[1] The CuTe code implementation draws inspiration from the techniques discussed in this blog: https://zhuanlan.zhihu.com/p/690703999
:::{warning}
This document is still **experimental** and may be incomplete.
Suggestions and improvements are highly encouraged—please submit a PR!
:::
:::{tip}
Example code can be found at `examples/gemv/example_gemv.py`.
:::
General matrix-vector multiplication (GEMV) can be viewed as a specialized case of general matrix-matrix multiplication (GEMM). It plays a critical role in deep learning, especially during the inference phase of large language models. In this tutorial, we will optimize GEMV from a thread-level perspective step by step using `TileLang`.
## Triton Implementation
When implementing a GEMV kernel, you might start with a high-level approach using a tool like `Triton`.
A simple Triton kernel for GEMV might look like this:
```python
@triton.jit
def _gemv_naive(
    x_ptr, A_ptr, y_ptr,
    N, K,
    BLOCK_SIZE_K: tl.constexpr,
):
    n = tl.program_id(0)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    mask = offs_k < K
    a_ptrs = A_ptr + n * K + offs_k
    a_vals = tl.load(a_ptrs, mask=mask, other=0.0)
    x_vals = tl.load(x_ptr + offs_k, mask=mask, other=0.0)
    dot = tl.sum(a_vals * x_vals, axis=0)
    tl.store(y_ptr + n, dot)
```
`Triton` is straightforward to use, as it operates at the block level. However, this approach may not allow for fine-grained thread-level optimization. In this tutorial, we will demonstrate how to write an optimized GEMV kernel in `TileLang` that exposes more low-level control.
## Naive Implementation in TileLang
If you have a basic understanding of CUDA C, it is natural to start with a naive GEMV kernel by adapting a GEMM tiling strategy: you can think of GEMV as a `(1, k) * (k, n)` GEMM (see `examples/gemv/example_gemv.py` for the full kernel).
In this design, the first 128 threads act as the data producer and the last 128 threads as the consumer within a block (assuming a 1D block).
At this level, we extract very little of the GPU's compute capability: the kernel takes around **~0.17 ms**, compared to torch/cuBLAS's **~0.008 ms**, roughly 20x slower.
## More Concurrency
To further increase the concurrency of our kernel, we can exploit finer thread-level parallelism. Instead of assigning each thread to compute a single output element in C, you can introduce parallelism along the K dimension. Each thread computes a partial accumulation, and you then combine these partial results. This approach requires primitives like `atomicAdd` in CUDA.
By introducing parallelism along the K dimension, our kernel now achieves **~0.024 ms**, an improvement, but still not on par with torch/cuBLAS.
### Customizing Parallelism in K Dimension
If your K dimension is large, you can further customize how many elements each thread processes by introducing a `reduce_threads` parameter, so that each thread handles multiple elements per iteration.
GEMV is less compute-intensive than GEMM, so arithmetic intensity and memory throughput become the optimization bottleneck. One effective strategy is to use vectorized load/store operations (e.g., `float2`, `float4`). In `TileLang`, you can specify vectorized operations via `T.vectorized`:
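For illustration, a hedged fragment (loop bounds and names such as `TILE_K` and `k_base` are hypothetical):
```python
A_local = T.alloc_local((TILE_K,), dtype)  # per-thread registers
x_local = T.alloc_local((TILE_K,), dtype)
for k in T.vectorized(TILE_K):             # vectorized (e.g., 128-bit) loads
    A_local[k] = A[n, k_base + k]
    x_local[k] = x[k_base + k]
```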
With vectorized read, now the kernel finishes in **~0.0084 ms**, which is getting close to cuBLAS performance.
## `tvm_thread_allreduce` Instead of `atomicAdd`
[`tvm_thread_allreduce`](https://tvm.apache.org/docs/reference/api/python/tir/tir.html#tvm.tir.tvm_thread_allreduce) implements an optimized all-reduce across a number of threads, which should outperform our plain shared-memory + `atomicAdd` approach:
With this optimization, the kernel latency now reduces from **~0.0084 ms** to **~0.0069 ms**, which is faster than torch/cuBLAS!
## Autotune
`BLOCK_N`, `BLOCK_K`, `reduce_threads` are hyperparameters in our kernel, which can be tuned to improve performance. We can use the `tilelang.autotune` feature to automatically search for optimal configurations:
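A hedged sketch of one plausible setup (decorator names follow `tilelang.autotune`/`tilelang.jit`; check the documentation for exact signatures):
```python
import itertools
import tilelang

# Candidate configurations to sweep over (values are illustrative).
configs = [
    dict(BLOCK_N=bn, BLOCK_K=bk, reduce_threads=rt)
    for bn, bk, rt in itertools.product([2, 4, 8], [64, 128, 256], [16, 32])
]

@tilelang.autotune(configs=configs)
@tilelang.jit(out_idx=-1)
def gemv(N, K, BLOCK_N=None, BLOCK_K=None, reduce_threads=None):
    ...  # construct and return the T.prim_func as in the kernels above
```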
The generated code corresponds closely to our `TileLang` program, with the necessary synchronization and low-level optimizations inserted automatically.
## Conclusion
### Benchmark Table on Hopper GPU
| Kernel Name | Latency |
|------------|------------|
| torch/cuBLAS | 0.00784 ms |
| Triton | 0.00773 ms |
| naive_gemv | 0.16607 ms |
| splitk_gemv | 0.02419 ms |
| splitk_gemv_vectorized | 0.00809 ms |
| splitk_gemv_vectorized_tvm | 0.00675 ms |
In this tutorial, we implemented a simple GEMV kernel step by step and learned that `TileLang` exposes low-level control to users, such as thread-level programming and CUDA-style primitives.
:::{warning}
This document is still **experimental** and may be incomplete.
Suggestions and improvements are highly encouraged—please submit a PR!
:::
TileLang is a domain-specific language (DSL) designed for writing high-performance GPU kernels. It provides three main levels of abstraction:
* **Level 1:** A user writes pure compute logic without knowledge of or concern for hardware details (e.g., GPU caches, tiling, etc.). The compiler or runtime performs automatic scheduling and optimization. This level is conceptually similar to the idea behind TVM.
* **Level 2:** A user is aware of GPU architecture concepts—such as shared memory, tiling, and thread blocks—but does not necessarily want to drop down to the lowest level of explicit thread control. This mode is somewhat comparable to Triton's programming model, where you can write tile-level operations and let the compiler do layout inference, pipelining, etc.
* **Level 3:** A user takes full control of thread-level primitives and can write code that is almost as explicit as a hand-written CUDA/HIP kernel. This is useful for performance experts who need to manage every detail, such as PTX inline assembly, explicit thread behavior, etc.
```{figure} ../_static/img/overview.png
:width: 50%
:alt: Overview
:align: center
Figure 1: High-level overview of the TileLang compilation flow.
```
In this tutorial, we introduce Level 2 with a matrix multiplication example in TileLang. We will walk through how to allocate shared memory, set up thread blocks, perform parallel copying, pipeline the computation, and invoke the tile-level GEMM intrinsic. We will then show how to compile and run the kernel in Python, comparing results and measuring performance.
## Why Another GPU DSL?
TileLang emerged from the need for a DSL that:
1. Balances high-level expressiveness (like TVM or Triton) with enough flexibility to control finer details when needed.
2. Supports efficient code generation and scheduling for diverse hardware backends (NVIDIA GPUs, AMD GPUs, CPU, etc.).
3. Simplifies scheduling and memory pipelines with built-in primitives (such as `T.Pipelined`, `T.Parallel`, `T.gemm`), yet retains options for expert-level tuning.
While Level 1 in TileLang can be very comfortable for general users—since it requires no scheduling or hardware-specific knowledge—it can incur longer auto-tuning times and may not handle some complex kernel fusion patterns (e.g., Flash Attention) as easily. Level 3 gives you full control but demands more effort, similar to writing raw CUDA/HIP kernels. Level 2 thus strikes a balance for users who want to write portable and reasonably concise code while expressing important architectural hints.
## Matrix Multiplication Example
```{figure} ../_static/img/MatmulExample.png
:alt: Matmul Example
:align: center
```
### Basic Structure
Below is a simplified code snippet for a 1024 x 1024 x 1024 matrix multiplication. It uses:
* **`T.Kernel(...)`** to initialize the thread block configuration (grid dimensions, block size, etc.).
* **`T.alloc_shared(...)`** to allocate GPU shared memory.
* **`T.alloc_fragment(...)`** to allocate a register fragment for accumulation.
* **`T.Pipelined(...)`** to express software pipelining across the K dimension.
* **`T.Parallel(...)`** to parallelize data copy loops.
* **`T.gemm(...)`** to perform tile-level GEMM operations (which map to the appropriate backends, such as MMA instructions on NVIDIA GPUs).
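Putting the pieces together, a sketch consistent with the bullets above (tile sizes and thread count are illustrative; this mirrors the canonical TileLang matmul example):
```python
import tilelang
import tilelang.language as T

def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            # Software pipeline over the K dimension.
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            T.copy(C_local, C[by * block_M, bx * block_N])
    return main
```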
- `T.alloc_shared` allocates shared memory across the entire thread block.
- `T.alloc_fragment` allocates register space for local accumulation. Though it is written as `(block_M, block_N)`, the compiler’s layout inference assigns slices of this space to each thread.
- `T.Parallel` marks the loop for thread-level parallelization.
- The compiler will map these loops to the available threads in the block.
5. **Tile-Level GEMM**:
```python
T.gemm(A_shared, B_shared, C_local)
```
- A single call that performs a tile-level matrix multiplication using the specified buffers.
- Under the hood, for NVIDIA targets, it can use CUTLASS/Cute or WMMA instructions. On AMD GPUs, TileLang uses a separate HIP or composable kernel approach.
6. **Copying Back Results**:
```python
T.copy(C_local, C[by * block_M, bx * block_N])
```
- After computation, data in the local register fragment is written back to global memory.
## Comparison with Other DSLs
TileLang Level 2 is conceptually similar to Triton in that the user can control tiling and parallelization, while letting the compiler handle many low-level details. However, TileLang also:
- Supports a flexible pipeline pass (`T.Pipelined`) that can be automatically inferred or manually defined.
- Enables mixing different levels in a single program—for example, you can write some parts of your kernel in Level 3 (thread primitives) for fine-grained PTX/inline-assembly and keep the rest in Level 2.
When appropriately tuned (e.g., by using an auto-tuner), TileLang achieves performance comparable to or better than vendor libraries and Triton on various GPUs. In internal benchmarks, for an FP16 matrix multiply (e.g., 4090, A100, H100, MI300X), TileLang has shown:
- ~1.1x speedup over cuBLAS on RTX 4090
- ~0.97x on A100 (on par with cuBLAS)
- ~1.0x on H100
- ~1.04x on MI300X
- Compared to Triton, speedups range from 1.08x to 1.25x depending on the hardware.
These measurements will vary based on tile sizes, pipeline stages, and the hardware’s capabilities.
## Conclusion
This tutorial demonstrated a Level 2 TileLang kernel for matrix multiplication. With just a few lines of code:
1. We allocated shared memory and register fragments.
2. We pipelined the loading and computation along the K dimension.
3. We used parallel copying to efficiently load tiles from global memory.
4. We invoked `T.gemm` to dispatch a tile-level matrix multiply.
5. We verified correctness against PyTorch and examined performance.
By balancing high-level abstractions (like `T.copy`, `T.Pipelined`, `T.gemm`) with the ability to annotate layouts or drop to thread primitives (Level 3) when needed, TileLang can be both user-friendly and highly tunable. We encourage you to experiment with tile sizes, pipeline depths, or explicit scheduling to see how performance scales across different GPUs.
For more advanced usage—including partial lowering, explicitly controlling thread primitives, or using inline assembly—you can explore Level 3. Meanwhile, for purely functional expressions and high-level scheduling auto-tuning, consider Level 1.
:::{warning}
This document is still **experimental** and may be incomplete.
This feature is still **experimental** and needs further optimization.
Suggestions and improvements are highly encouraged—please submit a PR!
:::
:::{tip}
It's suggested to go through `docs/deeplearning_operators/matmul.md` first.
Example code can be found at `examples/gemm_sp`.
:::
## Structured sparsity in the NVIDIA Ampere architecture
Since the Ampere architecture (sm80 and above), sparsity support has been integrated into Tensor Cores. This allows a 2:4 (or 1:2 for 32-bit data types) semi-structured matrix to be compressed into its non-zero values along with associated metadata, which can then be fed into the Tensor Core. This enables up to **2x throughput** compared to the equivalent dense computation.
:::{warning}
This tutorial primarily focuses on CUDA, as this feature is not yet supported on ROCm. However, AMD provides a similar capability in the matrix cores of GPUs such as the MI300X.
:::
To utilize sparse Tensor Cores, a dense tensor must first be **compressed** into its non-zero values along with the corresponding metadata.
Both `PyTorch` and `vLLM` use `CUTLASS` as their computation backend (see references [here](https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu#L47) and [here](https://github.com/vllm-project/vllm/blob/a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh#L116)), leveraging `CUTLASS`’s built-in compressor (or reimplementing it in `PyTorch`).
A set of **CUTLASS-compatible** compressors is provided in `tilelang.utils.sparse`, where a dense tensor—along with other required arguments (e.g., block_K for sm90, transpose options)—can be passed in to perform the compression.
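A hedged sketch of the call (the helper name and signature are assumptions; see `tilelang.utils.sparse` for the actual entry points):
```python
from tilelang.utils.sparse import compress  # hypothetical entry point

# A: dense (M, K) tensor; block_k and transposition flags follow the kernel config.
A_sparse, E = compress(A, block_k=block_K, transposed=False)
```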
Here, `A_sparse` contains all the non-zero elements of `A`, while `E` stores the corresponding metadata (indexing information) required to reconstruct the original sparse pattern.
> NOTE: When using the CUTLASS compressor, there is no naive positional correspondence between `A_sparse`/`A` and `E` (i.e., the 4-element group at `[n, k]` does not match the 4-bit metadata at `[n, k]` if you view the metadata as an int4 tensor).
The metadata is reordered internally to optimize memory access patterns (e.g., for ldsm instructions and vectorized loads).
For more information, see **A note on `gemm_sp` and `gemm_sp_v2`**.
## `T.gemm_sp` with CUTLASS's compressor
:::{warning}
It is strongly recommended to use `T.gemm_sp_v2` due to its greater flexibility and faster compilation time.
:::
A 2:4 sparse GEMM kernel is similar to its dense counterpart, except that it also requires handling the associated metadata.
See the comments in the kernel code below for the required modifications.
```python
def matmul_sp_sm80(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
    trans_A,
    trans_B,
):
    is_8_bit = "8" in in_dtype
    metadata_dtype = 'int32' if is_8_bit else 'int16'
    E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype]  # Calculate shapes for the given datatypes
    # ... (kernel construction elided) ...
    T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)  # Call gemm_sp with non-zero values and metadata
    T.copy(C_frag, C[by * block_M, bx * block_N])
    return main
```
Under the hood, `gemm_sp` invokes templates adapted from `CUTLASS`, and a compatible metadata layout must be specified using `T.annotate_layout`.
## `T.gemm_sp_v2` with a custom compressor
To migrate to `gemm_sp_v2`, simply replace occurrences of `gemm_sp` with `gemm_sp_v2`.
Unlike `gemm_sp`, `gemm_sp_v2` can operate without `T.annotate_layout`, and it also supports user-defined layouts and compressors.
The metadata is stored in a `(u)int8`/`(u)int16`/`(u)int32` tensor, where **each 4-bit chunk represents two 2-bit indices** of non-zero elements within four consecutive elements. Here, we start with an `int16` example, which is the **default dtype** for `bf16` and `fp16` on Ampere GPUs.
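As a toy illustration of the encoding (pure Python, independent of any TileLang API):
```python
# A 2:4 group keeps 2 of 4 values; their positions are packed as two 2-bit
# indices per nibble. E.g., group [5, 0, 0, 7] keeps positions 0 and 3:
positions = (0, 3)
nibble = positions[0] | (positions[1] << 2)  # 0b00 | 0b1100
assert nibble == 0b1100
# Four such nibbles fill an int16 metadata element (eight for int32).
```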
The compressor can be implemented at either the `PyTorch`/`NumPy` level or the kernel level.
For example, `PyTorch` provides an Ampere compressor [here](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L47-L179). Note that in this implementation, a [permutation](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L173-L175) is applied to match CUTLASS’s metadata layout. If you do not annotate a metadata layout when using `gemm_sp_v2`, your compressor should replicate the same behavior as the PyTorch example—but without using the `_calculate_meta_reordering_scatter_offsets` function.
If you want to use a custom metadata layout in your kernel, one approach is to define the layout in `TileLang` and then apply the same layout to both your compressor kernel and the matmul_sp kernel.
Initially, `T.gemm_sp` followed the same design as `T.gemm`, lowering to a `CUTLASS` template. This inherently requires metadata to be reordered offline following a predetermined layout.
However, fixing a specific layout introduces several potential issues:
1. Painful debugging experience: Debugging a failed kernel becomes difficult due to the reordered indexing, including permutations and swizzling.
2. Limited flexibility: For example, concatenating two compressed tensors, such as `A_sparse_0` and `A_sparse_1`, into a new `A_sparse` makes sense. However, concatenating their metadata `E_0` and `E_1` may not be valid unless the layout allows it mathematically.
3. Alignment requirements: `CUTLASS` enforces strict alignment checks, and many hyperparameter configurations can lead to compilation errors. (For reference, sm8x was implemented in `CUTLASS 2`.)
`T.gemm_sp_v2` was designed to address these limitations, following the approach of `T.gemm_v2`. It lowers directly to PTX, removing the need for a fixed metadata layout.
If you prefer Docker, please skip to the [Install Using Docker](#install-using-docker) section. This section focuses on building from source on a native Linux environment.
First, install the OS-level prerequisites on Ubuntu/Debian-based systems (a C++ toolchain, CMake, and Python development headers).
Then, clone the tilelang repository and install it using pip. The `-v` flag enables verbose output during the build process.
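For example (the repository URL follows the project's GitHub organization):
```bash
git clone --recursive https://github.com/tile-ai/tilelang.git
cd tilelang
pip install . -v
```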
> **Note**: Use the `--recursive` flag to include necessary submodules. Tilelang currently depends on a customized version of TVM, which is included as a submodule. If you prefer [Building with Existing TVM Installation](#using-existing-tvm), you can skip cloning the TVM submodule (but still need other dependencies).
If you want to install tilelang in development mode, you can use the `-e` flag so that any changes to the Python files will be reflected immediately without reinstallation.
```bash
pip install -e . -v
```
> **Note**: changes to C++ files require rebuilding the tilelang C++ library. See [Faster Rebuild for Developers](#faster-rebuild-for-developers) below. A default `build` directory will be created if you use `pip install`, so you can also directly run `make` in the `build` directory to rebuild it as [Working from Source via PYTHONPATH](#working-from-source-via-pythonpath) suggested below.
(working-from-source-via-pythonpath)=
### Working from Source via `PYTHONPATH` (Recommended for Developers)
If you prefer to work directly from the source tree via `PYTHONPATH` instead of using pip, make sure the native extension (`libtilelang.so`) is built first:
```bash
mkdir -p build
cd build
cmake .. -DUSE_CUDA=ON
make -j
```
We also recommend using `ninja` to speed up compilation:
```bash
cmake .. -DUSE_CUDA=ON -G Ninja
ninja
```
Then add the repository root to `PYTHONPATH` before importing `tilelang`, for example:
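```bash
export PYTHONPATH=/path/to/tilelang:$PYTHONPATH   # adjust to your checkout
python -c "import tilelang; print(tilelang.__version__)"
```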
Some useful CMake options you can toggle while configuring:
- `-DUSE_CUDA=ON|OFF` builds against NVIDIA CUDA (default ON when CUDA headers are found).
- `-DUSE_ROCM=ON` selects ROCm support when building on AMD GPUs.
- `-DNO_VERSION_LABEL=ON` disables the backend/git suffix in `tilelang.__version__`.
(using-existing-tvm)=
### Building with Customized TVM Path
If you already have a TVM codebase, use the `TVM_ROOT` environment variable to specify the location of your existing TVM repository when building tilelang:
```bash
TVM_ROOT=<your-tvm-repo> pip install . -v
```
> **Note**: This will still rebuild the TVM-related libraries (stored in `TL_LIBS`). This method often leads to path issues; check `env.py` for environment variables that may not be set properly.
(install-using-docker)=
## Install Using Docker
For users who prefer a containerized environment with all dependencies pre-configured, tilelang provides Docker images for different CUDA versions. This method is particularly useful for ensuring consistent environments across different systems.
**Prerequisites:**
- Docker installed on your system
- An NVIDIA Docker runtime or GPU is not necessary for building tilelang; you can build on a host without a GPU and use the built image on another machine.
```bash
# or pip install tilelang --find-links https://tile-ai.github.io/whl/nightly/cu121/
```
> **Note:** Nightly builds contain the most recent code changes but may be less stable than official releases. They're ideal for testing new features or if you need a specific bugfix that hasn't been released yet.
## Install Configs
### Build-time environment variables
`USE_CUDA`: Whether to enable CUDA support. Default: `ON` on Linux; set to `OFF` to build a CPU-only version. By default, `/usr/local/cuda` is used for building tilelang; set `CUDAToolkit_ROOT` to use a different CUDA toolkit.
`USE_ROCM`: Whether to enable ROCm support. Default: `OFF`. If your ROCm SDK is not located in `/opt/rocm`, set `USE_ROCM=<rocm_sdk>` to build against a custom SDK path.
`USE_METAL`: Whether to enable Metal support. Default: `ON` on Darwin.
`TVM_ROOT`: TVM source root to use.
`NO_VERSION_LABEL` and `NO_TOOLCHAIN_VERSION`:
When building tilelang, we try to embed SDK and version information into the package version, as shown below,
where the local version label looks like `<sdk>.git<git_hash>`. Set `NO_VERSION_LABEL=ON` to disable this behavior.
```
$ python -mbuild -w
...
Successfully built tilelang-0.1.6.post1+cu116.git0d4a74be-cp38-abi3-linux_x86_64.whl
```
where `<sdk>={cuda,rocm,metal}`. Specifically, when `<sdk>=cuda` and `CUDA_VERSION` is provided via the environment, `<sdk>=cu<cuda_major><cuda_minor>`, similar to PyTorch's convention.
Set `NO_TOOLCHAIN_VERSION=ON` to disable this.
### Run-time environment variables
Please refer to the `env.py` file for a full list of supported run-time environment variables.
## Other Tips
### IDE Configs
Building tilelang locally will automatically generate a `compile_commands.json` file in `build` dir.
VSCode with clangd and [clangd extension](https://marketplace.visualstudio.com/items?itemName=llvm-vs-code-extensions.vscode-clangd) should be able to index that without extra configuration.
### Compile Cache
The default path of the compile cache is `~/.tilelang/cache`. `ccache` will be automatically used if found.
### Repairing Wheels
If you plan to use your wheel in other environments, it's recommended to repair it with auditwheel (on Linux) or delocate (on Darwin) so that external shared libraries are bundled into the wheel.
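For example (the output directory is illustrative):
```bash
# Linux: bundle external shared libraries into the wheel
auditwheel repair dist/tilelang-*.whl -w wheelhouse/
# Darwin: the delocate equivalent
delocate-wheel -w wheelhouse/ -v dist/tilelang-*.whl
```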
(faster-rebuild-for-developers)=
### Faster Rebuild for Developers
`pip install` introduces extra [un]packaging steps and takes roughly 30 seconds to complete even when no source has changed.
Developers who need to recompile frequently can use:
```bash
pip install -r requirements-dev.txt
# For the first compilation
pip install -e . -v --no-build-isolation
# Or manually compile with cmake/ninja. Remember to set PYTHONPATH properly.
mkdir build
cd build
cmake .. -G Ninja
ninja
# Rebuild when you change the cpp code
cd build; ninja
```
When running in editable/developer mode, you'll see logs like the following:
```console
$ python -c 'import tilelang'
2025-10-14 11:11:29 [TileLang:tilelang.env:WARNING]: Loading tilelang libs from dev root: /Users/yyc/repo/tilelang/build
```
If you encounter the error `ModuleNotFoundError: No module named 'tvm_ffi'`, it means the TVM foreign function interface package was not installed. This often happens when the submodules were built manually. Fix it by running:
```bash
# Navigate to the tvm_ffi directory
cd 3rdparty/tvm/3rdparty/tvm_ffi
# Install the package in editable mode
pip install .
# Return to the project root
cd ../../../..
```
### DTK Path Configuration
If you encounter errors related to DTK path detection (e.g., `hipcc` not found, or failure to retrieve the GPU architecture), you may need to specify the DTK installation path manually in the source code.
Locate `tilelang/contrib/rocm.py` and modify the default value of the `rocm_path` parameter of the `get_rocm_arch` function (around line 231):
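A minimal sketch of the edit, assuming DTK is installed under `/opt/dtk` (the actual signature in your checkout may differ):
```python
# tilelang/contrib/rocm.py (illustrative)
def get_rocm_arch(rocm_path="/opt/dtk"):  # point the default at your DTK installation
    ...
```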
The figure below depicts how **TileLang** programs are progressively lowered from a high-level description to hardware-specific executables. We provide three different programming interfaces—targeted at **Beginner**, **Developer**, and **Expert** users—that each reside at different levels in this lowering pipeline. The **Tile Language** also allows mixing these interfaces within the same kernel, enabling users to work at whichever level of abstraction best suits their needs.
```{figure} ../_static/img/overview.png
:width: 50%
:alt: Overview
:align: center
Figure 1: High-level overview of the TileLang compilation flow.
```
## Programming Interfaces
1. **Beginner Level (Hardware-Unaware)**
   - Intended for users who need to write code that is independent of specific hardware details.
   - The goal is to let developers focus on the basic logic without worrying about memory hierarchies or hardware-specific optimizations.
   - *Note:* This interface is not yet fully implemented.
2. **Developer Level (Hardware-Aware with Tile Library)**
   - Designed for developers who have a basic understanding of GPU memory hierarchies and performance considerations.
   - Provides a **Tile Library**, containing predefined operations and patterns optimized for various hardware architectures.
   - Users at this level can leverage these ready-made primitives without diving into low-level threading details.
3. **Expert Level (Hardware-Aware with Thread Primitives)**
   - For highly experienced users who have an in-depth understanding of low-level hardware characteristics (e.g., threading models, memory coalescing).
   - Offers direct access to **thread primitives** and other low-level constructs, allowing for fine-grained control of performance-critical kernels.
   - This level grants maximum flexibility for specialized optimizations tailored to specific GPU or multi-core architectures.
## Compilation Flow
1. **Tile Program**
   A high-level specification of the computation. Depending on the user’s expertise, they may write a purely hardware-unaware tile program or incorporate constructs from the Tile Library or thread primitives.
2. **Tile Program with Tile Library**
   When developers choose from the Tile Library, the original Tile Program is expanded with specialized library calls. These calls encapsulate efficient implementation patterns for different operations.
3. **Tile Program with Thread Primitives**
   Expert-level developers can explicitly use low-level threading constructs to hand-optimize data layout, synchronization, and memory usage.
4. **IRModule**
   After the program is composed with libraries or thread primitives, it is lowered to an intermediate representation (IR) that captures the necessary hardware details.
5. **Source Code Generation (C/CUDA/HIP/LLVM/…)**
   From the IR, the system generates target-specific source code. This source code is tuned for the desired backends or GPU architectures (e.g., NVIDIA, AMD).
6. **Hardware-Specific Executable/Runtime**
   Finally, the generated source is compiled into hardware-specific executables, ready to run on the corresponding devices. The pipeline supports multiple GPU backends and can be extended to additional architectures.
## Tile-based Programming Model
[Figure 2](#fig-overview-gemm) provides a concise matrix multiplication (GEMM) example in ``TileLang``,
illustrating how developers can employ high-level constructs such as tiles, memory placement, pipelining,
and operator calls to manage data movement and computation with fine-grained control.
In particular, this snippet ([Figure 2](#fig-overview-gemm)(a)) demonstrates how multi-level tiling
leverages different memory hierarchies (global, shared, and registers) to optimize bandwidth utilization
and reduce latency.
Overall, [Figure 2](#fig-overview-gemm)(b) showcases how the Python-like syntax of ``TileLang``
allows developers to reason about performance-critical optimizations within a user-friendly programming model.
```{figure} ../_static/img/MatmulExample.png
:align: center
:width: 100%
:alt: GEMM with Multi-Level Tiling on GPUs
:name: fig-overview-gemm
Figure 2: Optimizing GEMM with Multi-Level Tiling on GPUs via ``TileLang``.
```
### Tile Declarations
At the heart of our approach is the notion of *tiles* as first-class objects in the programming model. A tile represents a shaped portion of data, which can be owned and manipulated by a warp, thread block, or equivalent parallel unit. In the `Matmul` example, the `A` and `B` buffers are read in tiled chunks (determined by `block_M`, `block_N`, `block_K`) inside the kernel loop. With `T.Kernel`, `TileLang` defines the execution context, which includes the thread block index (`bx` and `by`) and the number of threads. These contexts can help compute the index for each thread block and make it easier for `TileLang` to automatically infer and optimize memory access and computation. Additionally, these contexts allow users to manually control the behavior of each independent thread within a thread block.
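For concreteness, here is a minimal sketch in the spirit of the `Matmul` example from Figure 2 (tile sizes, buffer names, and the pipelining depth are illustrative choices, and the exact annotation spelling may differ across TileLang versions):
```python
import tilelang
import tilelang.language as T

def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        # One thread block per (block_M x block_N) output tile; bx/by index the grid.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            # Software-pipelined loop over K tiles: copy into shared memory, then accumulate.
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])
    return main
```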
### Explicit Hardware Memory Allocation
A hallmark of `TileLang` is the ability to explicitly place these tile buffers in the hardware memory hierarchy. Rather than leaving it to a compiler's opaque optimization passes, `TileLang` exposes user-facing intrinsics that map directly to physical memory spaces or accelerator-specific constructs. In particular:
- `T.alloc_shared`: Allocates memory in a fast, on-chip storage space, which corresponds to shared memory on NVIDIA GPUs. Shared memory is ideal for caching intermediate data during computations, as it is significantly faster than global memory and allows for efficient data sharing between threads in the same thread block. For example, in matrix multiplication, tiles of matrices can be loaded into shared memory to reduce global memory bandwidth demands and improve performance.
- `T.alloc_fragment`: Allocates accumulators in fragment memory, which corresponds to register files on NVIDIA GPUs. By keeping inputs and partial sums in registers or hardware-level caches, latency is further minimized. Note that in this tile program, each tile allocates local buffers of the same shape as the shared-memory buffers, which might seem counterintuitive, since register file space is faster but far more limited than shared memory. This is because the allocation here refers to the register files for an entire thread block. `TileLang` uses a Layout Inference Pass during compilation to derive a Layout object `T.Fragment`, which determines how to allocate the corresponding register files for each thread. This process will be discussed in detail in subsequent sections.
Data transfer between global memory and hardware-specific memory can be managed using `T.copy`. Furthermore, hardware-specific buffers can be initialized using `T.clear` or `T.fill`. For data assignments, operations can also be performed in parallel using `T.Parallel`, as demonstrated in the Layout Inference Pass discussion in the following sections.
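As an end-to-end usage sketch (assuming the `matmul` factory above, a CUDA device, and that `out_idx` marks `C` as the runtime-allocated output; the exact compile API may vary between releases):
```python
import torch
import tilelang

# Compile the sketch above into a CUDA kernel callable from Python.
kernel = tilelang.compile(matmul(1024, 1024, 1024, 128, 128, 32), out_idx=[2], target="cuda")

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
c = kernel(a, b)  # runs the generated kernel; c holds A @ B
```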