# InjectFenceProxy Pass
`tl.InjectFenceProxy` is a TIR-level transform that keeps the GPU proxy state consistent on NVIDIA Hopper (SM90+) by inserting `fence.proxy.async` instructions when control flow switches from generic memory operations to asynchronous proxy operations.
## Why Fences Are Needed
Hopper separates memory instructions into generic and asynchronous proxy paths. When an asynchronous instruction (for example, `cp.async` or `tma.load`) issues after generic traffic (like `ldmatrix` or plain buffer stores), the hardware requires a `fence.proxy.async` to guarantee ordering. Missing fences can lead to race conditions or undefined behavior.
## What the Pass Does
- Walks every statement in the `PrimFunc`, tracking whether it behaves as a **generic**, **async**, or **neutral** proxy (neutral statements reset the state, such as an explicit fence).
- Automatically lowers `tma_store` intrinsics into the required `arrive`/`wait` handshake so that TMA stores participate correctly in synchronization.
- Injects an explicit `fence.proxy.async` whenever a generic statement is followed by an async statement without an intervening neutral barrier.
The pass is conservative: unknown extern calls are treated as async so that the fence is inserted rather than accidentally omitted.
### Timeline View
```
generic initialize_descriptor → generic shared-store → async wgmma
       (generic proxy)             (generic proxy)   ↑   (async proxy)
                                                     │
                                fence.proxy.async inserted here
```
The proxy tracker scans the sequence from left to right. The moment it detects a transition from generic to async (between the shared-store and the `wgmma` above), it synthesizes a `fence.proxy.async` to reset the hardware proxy state before the async path runs.
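To make the scan concrete, here is a minimal Python sketch of the tracking logic. It is illustrative only, not the actual C++ implementation in `src/transform/inject_fence_proxy.cc`; the classification sets use only intrinsics mentioned in this document.
```python
# Illustrative sketch of the generic -> async transition check.
GENERIC = {"initialize_descriptor", "ldmatrix", "stmatrix", "buffer_store"}
NEUTRAL = {"fence_proxy_async"}

def inject_fences(stmts):
    out, prev = [], "neutral"
    for op in stmts:
        if op in NEUTRAL:
            kind = "neutral"
        elif op in GENERIC:
            kind = "generic"
        else:
            # Unknown extern calls are treated as async (conservative default).
            kind = "async"
        if prev == "generic" and kind == "async":
            out.append("fence_proxy_async")  # reset the hardware proxy state
        out.append(op)
        prev = kind
    return out

# -> ["initialize_descriptor", "buffer_store", "fence_proxy_async", "ptx_wgmma_ss"]
print(inject_fences(["initialize_descriptor", "buffer_store", "ptx_wgmma_ss"]))
```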
## Coverage of Intrinsics
The tracker understands the TileLang intrinsics for TMA load/store, shared-memory MMA (`wgmma`), and TVM/PTX async copy intrinsics (`cp.async` variants). Generic operations currently include `ldmatrix`, `stmatrix`, and descriptor initialization. Other IR nodes (loops, blocks, attributes) receive a proxy kind derived from their bodies so that the analysis survives structured control flow.
## Usage
The pass is part of the default TileLang lowering pipeline. To apply it manually:
```python
import tvm
from tvm import IRModule
from tilelang import tl

mod = IRModule({"main": prim_func})
with tvm.transform.PassContext():
    mod = tl.transform.InjectFenceProxy()(mod)
```
## End-to-End Example
Before the pass:
```python
@T.prim_func
def kernel():
    with T.Kernel(1):
        desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
        smem = T.decl_buffer((128,), "float16", scope="shared")
        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
        smem[0] = T.float16(0)
        T.ptx_wgmma_ss(
            "float16",
            "m64n64k16",
            T.bool(True),
            T.bool(True),
            "fp16",
            "fp16",
            "fp16",
            desc.data,
            T.int32(0),
            desc.data,
            T.int32(0),
            smem.data,
            T.int32(0),
            T.bool(True),
            1,
            1,
        )
```
After `tl.transform.InjectFenceProxy`:
```python
@T.prim_func
def kernel():
    with T.Kernel(1):
        desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
        smem = T.decl_buffer((128,), "float16", scope="shared")
        T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
        smem[0] = T.float16(0)
        T.fence_proxy_async()
        T.ptx_wgmma_ss(
            "float16",
            "m64n64k16",
            T.bool(True),
            T.bool(True),
            "fp16",
            "fp16",
            "fp16",
            desc.data,
            T.int32(0),
            desc.data,
            T.int32(0),
            smem.data,
            T.int32(0),
            T.bool(True),
            1,
            1,
        )
```
The only change is the `fence_proxy_async` between the generic descriptor setup / shared-memory write and the async `wgmma`. In larger kernels the pass performs the same operation across nested blocks, loops, and conditional branches.
## Extending the Pass
If you introduce a new intrinsic that behaves like an async proxy, add it to `IsAsyncIntrinsic` in `src/transform/inject_fence_proxy.cc`. Likewise, extend `IsKnownGeneric` for additional generic operations. When adding new neutral barriers, make sure they set the proxy kind to `kNeutral` so the state resets correctly.
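A quick way to check a newly classified intrinsic is to run the pass on a small `PrimFunc` and inspect the result. The snippet below is a hedged sketch that reuses the usage pattern from the section above; `prim_func` is assumed to contain the new intrinsic preceded by a generic operation such as a shared-memory store.
```python
import tvm
from tvm import IRModule
from tilelang import tl

mod = IRModule({"main": prim_func})
with tvm.transform.PassContext():
    mod = tl.transform.InjectFenceProxy()(mod)

# A fence should now appear before the async intrinsic.
assert "fence_proxy_async" in mod["main"].script()
```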
# LetStmt Inlining in TileLang
This document explains how `LetStmt` inlining works in TileLang's simplification pipeline, which is an important optimization that affects code generation and performance.
## Overview
A `LetStmt` (Let Statement) is a temporary variable binding in the IR (Intermediate Representation). During compilation, TileLang's simplifier may choose to inline these temporary variables to simplify the code. TileLang also provides a standalone `LetInline` pass that performs eager substitution before the main legalization pipeline. However, not all `LetStmt` nodes can be safely inlined.
## When Does LetStmt Get Inlined?
The inlining logic is implemented in `src/transform/simplify.cc`. A `LetStmt` will be inlined if **both** of the following conditions are met:
### 1. The value satisfies `CanInlineLetStmt`
The `CanInlineLetStmt` helper returns `true` when:
- **The value is a constant** (`is_const_number(op->value)` returns true)
- **The value is a variable** (`op->value.as<VarNode>()` returns a node)
- **The value is an integer expression without side effects**:
- The value has `int` dtype
- The side effect level is `kPure` or lower (no observable side effects)
```cpp
bool CanInlineLetStmt(const LetStmtNode *op) {
  if (is_const_number(op->value))
    return true;
  if (op->value.as<VarNode>())
    return true;
  // Won't face the deep expression explosion problem as in Let expression.
  // Attempt to inline as much as possible if the value is an integer type
  // (can be an index).
  if (!op->value.dtype().is_int())
    return false;
  return SideEffect(op->value) <= CallEffectKind::kPure;
}
```
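To summarize the two branches of `CanInlineLetStmt`, the illustrative pseudo-IR comments below show typical let bindings and whether they qualify (the extern call is hypothetical and stands for any opaque call with side effects beyond `kPure`):
```python
# Illustrative classification (pseudo-IR, not runnable TileLang code):
#   let x = 5                -> inlined  (constant number)
#   let y = n                -> inlined  (plain variable)
#   let z = i * 4 + j        -> inlined  (pure int expression, usable as index)
#   let f = a * 0.5          -> kept     (float dtype, not int)
#   let e = call_extern(...) -> kept     (side effects above kPure)
```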
### 2. The variable is NOT used in buffer definitions
Even if `CanInlineLetStmt` returns true, the variable will **not** be inlined if it's used in a buffer's definition (shape, strides, elem_offset, or data fields).
This protection exists because:
- Buffer definitions are not updated during the simplification pass
- If a variable used in a buffer definition is inlined, later references to that buffer would fail to find the variable definition
- This would cause compilation errors or incorrect behavior
The mutator checks this before dropping the binding:
```cpp
bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());
if (can_inline && !used_in_buffer_def) {
  return body; // Inline: remove LetStmt and return body directly
}
```
## Example: Why Buffer Definition Variables Are Protected
Consider this code:
```python
let stride = M * 16
let buffer_a = Buffer(data, shape=[M, N], strides=[stride, 1])
buffer_a[i, j] = ...
```
- `stride` satisfies `CanInlineLetStmt` (it's an int expression with no side effects)
- However, `stride` is used in `buffer_a`'s `strides` field
- If we inline it, the buffer definition becomes `strides=[M*16, 1]`
- But the Buffer object's fields are not updated during simplification
- Later code accessing `buffer_a` would fail to find the `stride` variable
Therefore, `stride` is added to `used_in_buffer_def_` and will **not** be inlined.
## How Variables Are Collected
The `CollectVarsUsedInBufferDefinition` helper traverses all `BufferLoad` and `BufferStore` nodes and collects variables used in their buffer definitions:
```cpp
void VisitBuffer(const Buffer &buf) {
  // Collect variables that should remain defined
  VarUseDefAnalyzer usage(Array<Var>{});
  usage(buf->data);
  for (const auto &dim : buf->shape) {
    usage(dim);
  }
  for (const auto &dim : buf->strides) {
    usage(dim);
  }
  usage(buf->elem_offset);
  // Track for use in LetStmtNode mutator
  for (const auto &var : usage.undefined_) {
    used_in_buffer_def_.insert(var.get());
  }
}
```
## Practical Example: Temporary Variable Issue
Consider this TileLang code:
```python
for i in T.Parallel(block_N):
    idx = bx * block_N + i
    tmp = T.max(A[idx], 1)
    B[idx] = tmp / 2
    A[idx] = tmp * 2
```
In this case:
- `tmp` is an integer-like temporary variable
- It satisfies `CanInlineLetStmt` (pure int expression)
- It's **not** used in any buffer definition
- Therefore, `tmp` **will be inlined**
This means the IR becomes:
```python
for i in T.Parallel(block_N):
    idx = bx * block_N + i
    B[idx] = T.max(A[idx], 1) / 2
    A[idx] = T.max(A[idx], 1) * 2
```
If this causes issues (e.g., `A[idx]` being read twice with different values due to the first write), it indicates a potential problem with the inlining heuristic or the code pattern.
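One possible workaround, shown as a sketch only, is to route the intermediate value through an explicit buffer so it is materialized once rather than re-derived from `A[idx]` at each use; the fragment allocation via `T.alloc_fragment`, its shape, and its dtype are assumptions here.
```python
# Sketch: keep the intermediate in a per-thread fragment instead of a let
# binding, so the LetStmt inliner has nothing to fold into both uses.
tmp = T.alloc_fragment((block_N,), "float16")
for i in T.Parallel(block_N):
    idx = bx * block_N + i
    tmp[i] = T.max(A[idx], 1)
    B[idx] = tmp[i] / 2
    A[idx] = tmp[i] * 2
```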
## Controlling Let Inlining via Pass Config
TileLang exposes an explicit pass configuration key, `tilelang.PassConfigKey.TL_FORCE_LET_INLINE` (`"tl.force_let_inline"`), that allows users to force the eager `LetInline` pass to run before the legalization pipeline begins. When enabled, the pipeline invokes `tilelang.transform.LetInline()` at the start of `LowerAndLegalize` (see `tilelang/engine/phase.py`). This knob is useful when debugging LetStmt-related issues or when deterministic inlining behavior is desired across different environments.
```python
import tilelang
from tilelang import transform
from tilelang.engine.phase import LowerAndLegalize

with transform.PassContext(
        config={tilelang.PassConfigKey.TL_FORCE_LET_INLINE: True}):
    lowered_mod = LowerAndLegalize(input_mod, target)
```
If the flag is left unset (the default), the eager pass is only applied when downstream transforms opt in (for example, by calling `_Simplify(..., inline_let=True)` inside Tile operators). The guard in `tilelang/engine/phase.py` ensures the eager pass is only triggered when the user explicitly requests it.
## Summary
The LetStmt inlining mechanism is a **conservative optimization** that:
1. Aggressively inlines simple, pure integer expressions to simplify the IR
2. Protects variables used in buffer definitions to avoid breaking buffer access
3. Helps reduce IR complexity and improve code generation
4. Can be forced through `TL_FORCE_LET_INLINE` when deterministic eager inlining is required
Understanding when inlining happens is crucial for:
- Debugging compilation issues
- Understanding generated code
- Writing efficient TileLang programs
- Identifying potential optimization opportunities or bugs
## Related Files
- `src/transform/simplify.cc`: Main Simplify implementation
- `src/transform/frontend_legalize.cc`: Standalone LetInline pass
- `tilelang/engine/phase.py`: Pipeline integration for eager LetInlining
- `testing/python/transform/test_tilelang_transform_let_inline.py`: Regression coverage for the pass
# General information about the project.
project = "Tile Language <br>"
author = "Tile Lang Contributors"
copyright = f"2025-2025, {author}"
# Version information.
with open("../VERSION") as f:
version = f.read().strip()
release = version
extensions = [
    "sphinx_tabs.tabs",
    "sphinx_toolbox.collapse",
    "sphinxcontrib.httpdomain",
    "sphinx.ext.napoleon",
    "sphinx.ext.intersphinx",
    "sphinx_reredirects",
    "sphinx.ext.mathjax",
    "myst_parser",
    "autoapi.extension",
]
autoapi_type = 'python'
autoapi_dirs = ['../tilelang']
autoapi_options = [
    'members',
    'undoc-members',
    'show-inheritance',
    'show-module-summary',
    'special-members',
]
autoapi_keep_files = False # Useful for debugging the generated rst files
autoapi_generate_api_docs = True
autodoc_typehints = 'description'
autoapi_ignore = ["*language/ast*", "*version*", "*libinfo*", "*parser*"]
source_suffix = {
    '.rst': 'restructuredtext',
    '.md': 'markdown',
}
myst_enable_extensions = [
    "colon_fence",
    "deflist",
]
redirects = {"get_started/try_out": "../index.html#getting-started"}
language = "en"
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "README.md", "**/*libinfo*", "**/*version*"]
pygments_style = "sphinx"
todo_include_todos = False
# -- Options for HTML output ----------------------------------------------
html_theme = "furo"
templates_path = []
html_static_path = ["_static"]
footer_copyright = "© 2025-2025 Tile Language"
footer_note = " "
html_theme_options = {
    "light_logo": "img/logo-row.svg",
    "dark_logo": "img/logo-row.svg",
}
header_links = [
    ("Home", "https://github.com/tile-ai/tilelang"),
    ("Github", "https://github.com/tile-ai/tilelang"),
]
html_context = {
    "footer_copyright": footer_copyright,
    "footer_note": footer_note,
    "header_links": header_links,
    "display_github": True,
    "github_user": "tile-ai",
    "github_repo": "tilelang",
    "github_version": "main/docs/",
    "theme_vcs_pageview_mode": "edit",
}
# 🚀 Write High Performance FlashMLA with TileLang on Hopper
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/chengyupku">Yu Cheng</a>
<em>Author:</em> <a href="https://github.com/LeiWang1999">Lei Wang</a>
</div>
TileLang is a user-friendly AI programming language that significantly lowers the barrier to kernel programming, helping users quickly build customized operators. However, users still need to master certain programming techniques to better leverage TileLang's powerful capabilities. Here, we'll use MLA as an example to demonstrate how to write high-performance kernels with TileLang.
## Introduction to MLA
DeepSeek's MLA (Multi-Head Latent Attention) is a novel attention mechanism known for its hardware efficiency and significant improvements in model inference speed. Several deep learning compilers (such as [Triton](https://github.com/triton-lang/triton)) and libraries (such as [FlashInfer](https://github.com/flashinfer-ai/flashinfer)) have developed their own implementations of MLA. In February 2025, [FlashMLA](https://github.com/deepseek-ai/FlashMLA) was open-sourced on GitHub. FlashMLA utilizes [CUTLASS](https://github.com/NVIDIA/cutlass) templates and incorporates optimization techniques from [FlashAttention](https://github.com/Dao-AILab/flash-attention), achieving impressive performance.
## Benchmark Results
We benchmarked the performance of FlashMLA, TileLang, Torch, Triton, and FlashInfer under batch sizes of 64 and 128, with float16 data type, as shown in the figures below.
```{figure} ../_static/img/mla_hopper/bs64_float16.png
:width: 50%
:alt: Overview
:align: center
Figure 1: Performance under batch size=64
```
```{figure} ../_static/img/mla_hopper/bs128_float16.png
:width: 50%
:alt: Overview
:align: center
Figure 2: Performance under batch size=128
```
As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton.
Notably, **TileLang accomplishes this with just around 80 lines of Python code**, demonstrating its exceptional ease of use and efficiency. Let's dive in and see how TileLang achieves this.
## Implementation
First, let's review the core computation logic of traditional FlashAttention:
```python
# acc_s: [block_M, block_N]
# scores_max: [block_M]
# scores_scale: [block_M]
# acc_o: [block_M, dim]
for i in range(loop_range):
    acc_s = Q @ K[i]
    scores_max_prev = scores_max
    scores_max = max(acc_s, dim=1)
    scores_scale = exp(scores_max_prev - scores_max)
    acc_o *= scores_scale
    acc_s = exp(acc_s - scores_max)
    acc_o += acc_s @ V[i]
    ...
```
Here, `acc_s` represents the `Q @ K` result in each iteration with dimensions `[block_M, block_N]`, while `acc_o` represents the current iteration's output with dimensions `[block_M, dim]`. Both `acc_s` and `acc_o` need to be stored in registers to reduce latency.
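In TileLang these register-resident tiles are typically declared as fragments. The sketch below mirrors the pseudocode above; the buffer names and the accumulation dtype are illustrative assumptions.
```python
# Register-resident accumulators for one threadblock (illustrative shapes).
acc_s = T.alloc_fragment((block_M, block_N), accum_dtype)   # Q @ K scores
acc_o = T.alloc_fragment((block_M, dim), accum_dtype)       # running output
scores_max = T.alloc_fragment((block_M,), accum_dtype)
scores_scale = T.alloc_fragment((block_M,), accum_dtype)
```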
Compared to traditional attention operators like MHA (Multi-Head Attention) or GQA (Grouped-Query Attention), a major challenge in optimizing MLA is its large head dimensions - `query` and `key` have head dimensions of 576 (512 + 64), while `value` has a head dimension of 512. This raises a significant issue: `acc_o` becomes too large, and with too few threads (e.g., 128), register spilling occurs, severely impacting performance.
This raises the question of how to partition the matrix multiplication operation. On the Hopper architecture, most computation kernels use [`wgmma.mma_async`](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) instructions for optimal performance. The `wgmma.mma_async` instruction organizes 4 warps (128 threads) into a warpgroup for collective MMA operations. However, `wgmma.mma_async` instructions require a minimum M dimension of 64. This means each warpgroup's minimum M dimension can only be reduced to 64, but a tile size of 64*512 is too large for a single warpgroup, leading to register spilling.
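A quick back-of-the-envelope calculation makes the pressure explicit, assuming fp32 accumulation and the usual 255-register per-thread ceiling on NVIDIA GPUs:
```python
# Why a 64 x 512 accumulator tile does not fit in one warpgroup's registers.
elements = 64 * 512                   # 32,768 fp32 accumulator values
threads = 128                         # one warpgroup
regs_per_thread = elements // threads
print(regs_per_thread)                # 256 -> already above the ~255-register
                                      # per-thread limit, before counting the
                                      # operand fragments and address math
```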
Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right part of `acc_o` respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input.
Our solution is to have each warpgroup compute half of `acc_s` during `Q @ K` computation, then obtain the other half computed by the other warpgroup through shared memory.
### Layout Inference
While the above process may seem complex, don't worry - TileLang handles all these intricacies for you.
Figure 3 and Figure 4 illustrate the frontend TileLang script and its corresponding execution plan for MLA. Here, `T.gemm` represents matrix multiplication operations, `transpose_B=True` indicates transposition of matrix B, and `policy=FullCol` specifies that each warpgroup computes one column of the result (i.e., the result matrix is split along the vertical dimension). `T.copy` represents buffer-to-buffer copying operations.
```{figure} ../_static/img/mla_hopper/qk_layout.jpg
:width: 50%
:alt: Overview
:align: center
Figure 3: Buffer shapes in Q @ K
```
```{figure} ../_static/img/mla_hopper/pv_layout.jpg
:width: 50%
:alt: Overview
:align: center
Figure 4: Buffer shapes in acc_s @ V
```
The mapping from TileLang frontend code to execution plan is accomplished through Layout Inference. Layout inference is a core optimization technique in TileLang. It automatically deduces the required buffer shapes and optimal layouts based on Tile-Operators (like `T.gemm`, `T.copy`, etc.), then generates the corresponding code. Here, we demonstrate a concrete example of buffer shape inference in MLA.
For instance, when computing `Q @ K`, TileLang infers that each warpgroup's `acc_s_0` shape should be `[blockM, blockN / 2]` based on the `policy=FullCol` annotation in `T.gemm`. Since this is followed by an `acc_s @ V` operation with `policy=FullCol`, which requires each warpgroup to have the complete `acc_s` result, TileLang deduces that `acc_s`'s shape at this point should be `[blockM, blockN]`. Consequently, TileLang can continue the inference process forward, determining that both `S_shared` and `acc_s` in `T.copy(S_shared, acc_s)` should have shapes of `[blockM, blockN]`.
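For reference, the surrounding frontend code looks roughly like the sketch below. Buffer names follow the figures; the exact spelling of the warp policy enum (`T.GemmWarpPolicy.FullCol`), the `T.gemm` argument order, and the intermediate dtype handling are assumptions here, so consult the full example under `examples/deepseek_mla/` for the authoritative version.
```python
# Sketch: each warpgroup produces half of acc_s in Q @ K (FullCol policy),
# the halves are exchanged through S_shared, and the full [blockM, blockN]
# tile then feeds the acc_s @ V gemm (a dtype cast may be needed in between).
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True,
       policy=T.GemmWarpPolicy.FullCol)
T.copy(acc_s, S_shared)   # publish this warpgroup's half of acc_s
T.copy(S_shared, acc_s)   # read back the combined tile including the other half
T.gemm(acc_s, V_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
```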
It's worth noting that our scheduling approach differs from FlashMLA's implementation strategy. In FlashMLA, `Q @ K` is assigned to a single warpgroup, while the `acc_o` partitioning scheme remains consistent with ours. Nevertheless, our scheduling approach still achieves comparable performance.
### Threadblock Swizzling
Threadblock swizzling is a common performance optimization technique in GPU kernel optimization. In GPU architecture, the L2 cache is a high-speed cache shared among multiple SMs (Streaming Multiprocessors). Threadblock swizzling optimizes data access patterns by remapping the scheduling order of threadblocks, thereby improving L2 cache hit rates. Traditional scheduling typically executes threadblocks in the natural order of the grid, which can lead to non-contiguous data access patterns between adjacent threadblocks, resulting in inefficient utilization of cached data. The swizzle technique employs mathematical mapping methods (such as diagonal or interleaved mapping) to adjust the execution order of threadblocks, ensuring that consecutively scheduled threadblocks access adjacent or overlapping data regions.
In TileLang, threadblock swizzling optimization can be implemented with just a single line of Python code:
```python
T.use_swizzle(panel_size: int, order: str = "row")
```
Here, `panel_size` specifies the width of the swizzled threadblock group, and `order` determines the swizzling pattern, which can be either "row" or "col".
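For example, the annotation is placed at the top of the kernel body; the grid dimensions, thread count, and panel size below are illustrative values, not prescriptions.
```python
with T.Kernel(grid_m, grid_n, threads=128) as (bx, by):
    # Remap the threadblock schedule so consecutively launched blocks touch
    # adjacent data and hit in L2 more often.
    T.use_swizzle(panel_size=4, order="row")
    ...
```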
### Shared Memory Swizzling
In CUDA programming, shared memory is divided into multiple memory banks, with each bank capable of servicing one thread request per clock cycle in parallel. Bank conflicts occur when multiple threads simultaneously access different addresses mapped to the same bank, forcing these accesses to be serialized and degrading performance.
One common strategy to address bank conflicts is shared memory swizzling. This technique rearranges how data is stored in shared memory by remapping addresses that would originally fall into the same bank to different banks, thereby reducing conflicts. For example, XOR operations or other bit manipulations can be incorporated into address calculations to alter the data layout, resulting in more evenly distributed memory accesses across consecutive threads. This approach is particularly crucial for implementing high-performance computing tasks like matrix multiplication and convolution, as it can significantly improve memory access parallelism and overall execution efficiency.
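To make the idea concrete, below is a small, self-contained sketch (plain Python, not TileLang API) of a classic XOR-based swizzle for a row-major tile of 16-bit elements; the actual permutation TileLang generates may differ:
```python
def xor_swizzle_index(row: int, col: int, cols: int = 64) -> int:
    """Map a logical (row, col) element of a [rows, 64] fp16 tile to a
    swizzled linear offset in shared memory."""
    # Group columns into 8-element (16-byte) chunks, then XOR the chunk index
    # with the low bits of the row, so that threads reading the same column of
    # different rows land in different shared-memory banks.
    chunk, lane = col // 8, col % 8
    swizzled_chunk = chunk ^ (row % 8)
    return row * cols + swizzled_chunk * 8 + lane
```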
Similarly, TileLang also supports shared memory swizzling. Users only need to add a single line of Python code:
```python
T.annotate_layout({
    S_shared: tilelang.layout.make_swizzled_layout(S_shared),
})
```
Here, `T.annotate_layout` allows users to specify any desired layout for a buffer. For convenience, TileLang provides the `make_swizzled_layout` primitive to automatically generate a swizzled layout.
### Warp-Specialization
The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (Tensor Memory Accelerator), while the remaining warpgroups serve as consumers performing computations. However, this programming pattern is complex, requiring developers to manually manage the execution logic for producers and consumers, including synchronization through the `mbarrier` objects.
In TileLang, users are completely shielded from these implementation details. The frontend script is automatically transformed into a warp-specialized form, where TileLang handles all producer-consumer synchronization automatically, enabling efficient computation.
### Pipeline
Pipelining is a technique that improves memory access efficiency by overlapping memory accesses with computation. In TileLang, a pipelined loop is written with the `T.Pipelined` construct:
```python
for ko in T.Pipelined(iterations, num_stages=stages):
    ...
```
Here, `iterations` specifies the loop extent and `num_stages` specifies the number of pipeline stages. Multi-stage pipelining overlaps computation with memory access, which can significantly improve performance for memory-bound operators. However, using more stages consumes more shared memory, so the optimal configuration should be determined for each use case.
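For a concrete usage, the matmul tutorial later in this document pipelines its K loop as follows (fragment shown for context):
```python
# Three pipeline stages: the loads for iterations ko+1 and ko+2 overlap with
# the GEMM of iteration ko.
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
    T.copy(A[by * block_M, ko * block_K], A_shared)
    T.copy(B[ko * block_K, bx * block_N], B_shared)
    T.gemm(A_shared, B_shared, C_local)
```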
### Split-KV
We have also implemented Split-KV optimization similar to [FlashDecoding](https://pytorch.org/blog/flash-decoding/). Specifically, when the batch size is small, parallel SM resources cannot be fully utilized due to low parallelism. In such cases, we can split the kv_ctx dimension across multiple SMs for parallel computation and then merge the results.
In our implementation, we have developed both split and combine kernels, allowing users to control the split size through a `num_split` parameter.
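For intuition, the combine step merges the per-split partial outputs using their log-sum-exp statistics. Below is a hedged PyTorch sketch of that reduction (tensor names and layouts are illustrative; the actual combine kernel is written in TileLang and fused):
```python
import torch

def combine_splits(o_partial: torch.Tensor, lse_partial: torch.Tensor) -> torch.Tensor:
    # o_partial:   [num_split, batch, heads, dim]  per-split attention outputs
    # lse_partial: [num_split, batch, heads]       per-split log-sum-exp values
    lse_max = lse_partial.max(dim=0, keepdim=True).values
    w = torch.exp(lse_partial - lse_max)            # re-scale each split safely
    num = (w.unsqueeze(-1) * o_partial).sum(dim=0)  # weighted sum of partials
    return num / w.sum(dim=0).unsqueeze(-1)
```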
## 🚀 On AMD MI300X Accelerators
Following our previous demonstration of [high-performance FlashMLA implementation on NVIDIA Hopper architectures using TileLang](https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_mla/README.md), this work presents an optimized implementation for AMD MI300X accelerators. We examine architectural differences and corresponding optimization strategies between these platforms.
### Architectural Considerations and Optimization Strategies
Key implementation differences between Hopper and MI300X architectures include:
1. **Instruction Set Variations**: The MI300X architecture does not require explicit Tensor Memory Accelerator (TMA) instructions or warp specialization. On Hopper these are handled automatically by the compiler, so the TileLang source code remains identical across the two platforms.
2. **Shared Memory Constraints**: With 64KB of shared memory compared to Hopper's 228KB, MI300X implementations require careful memory management. Our optimization strategy includes:
- Reducing software pipeline stages
- Register-based caching of Q matrices instead of shared memory utilization:
```python
# Original shared memory allocation
Q_shared = T.alloc_shared([block_H, dim], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
# Optimized register allocation
Q_local = T.alloc_fragment([block_H, dim], dtype)
Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype)
```
3. **Tile Size Flexibility**: The absence of WGMMA instructions on MI300X permits more flexible tile size selection, removing the requirement for `block_m` to be a multiple of 64.
4. **Memory Bank Conflict Swizzling**: MI300X has different shared memory bank conflict rules than NVIDIA GPUs, so different swizzling strategies are required. This is also handled automatically by TileLang, with no visible difference in the source code.
### Performance Evaluation
We conducted comparative performance analysis across multiple frameworks using float16 precision with batch sizes 64 and 128. The experimental results demonstrate:
<figure style="text-align: center">
<a href="../figures/flashmla-amd.png">
<img src="../figures/flashmla-amd.png" alt="AMD FlashMLA Performance Comparison">
</a>
<figcaption style="text-align: center;">Figure 1: Computational throughput comparison across frameworks (Batch sizes 64 and 128)</figcaption>
</figure>
Notably, TileLang achieves performance parity with hand-optimized assembly kernels (aiter-asm) in most test cases, while significantly outperforming both Triton (1.98×) and PyTorch (3.76×) implementations. This performance is achieved through a concise 80-line Python implementation, demonstrating TileLang's efficiency and programmability advantages.
### Future Optimization Opportunities
1. **Memory Bank Conflict Mitigation**: Current implementations primarily address bank conflicts in NT layouts through TileLang's automatic optimization. Further investigation of swizzling techniques for alternative memory layouts remains an open research direction.
2. **Dimension Parallelization**: For large MLA dimensions (e.g., 576 elements), we propose investigating head dimension partitioning strategies to:
- Reduce shared memory pressure
- Improve compute-to-memory access ratios
- Enhance parallelism through dimension-wise task distribution
# ElementWise Operators
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/chenghuaWang">Chenghua Wang</a>
</div>
:::{warning}
:class: myclass1 myclass2
:name: a-tip-reference
This document is still **experimental** and may be incomplete.
Suggestions and improvements are highly encouraged—please submit a PR!
:::
Elementwise operators are widely used in deep learning and often serve as the first example encountered by those beginning to explore parallel programming. This tutorial will analyze several implementations of the elementwise addition operator using TileLang and compare them with the corresponding CUDA implementation. By the end of this tutorial, you will learn:
1. How to implement an elementwise operator using TileLang.
2. How to compile operators with dynamic shapes.
3. How TileLang addresses boundary-related issues.
4. The similarities and differences between operators implemented in TileLang and those implemented in CUDA/CuTe.
Please note that this tutorial does not delve deeply into the design principles of TileLang. For a broader understanding of TileLang, we recommend consulting the [Overview](../get_started/overview.md).
## Elementwise add in TileLang
```python
def elementwise_add(N, threads=256, dtype="bfloat16"):
@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
with T.Kernel(T.ceildiv(N, threads), threads=threads) as (b_x):
# vector add.
for i in T.Parallel(threads):
C[b_x * threads + i] = A[b_x * threads + i] + B[b_x * threads + i]
return main
```
All logic for TileLang kernels must be implemented within the `T.Kernel(...)` scope. In this example, initializing `T.Kernel(...)` requires specifying both the grid size and the number of threads per block. The returned value `b_x` corresponds to `blockIdx.x` in CUDA. In the provided implementation, `T.Parallel` is used to process the data tile (of size `1 x threads`) assigned to the block.
Those familiar with CUDA programming might wonder where `threadIdx` fits into this. Note that the code inside `T.Kernel` operates at the **block level**, not the **thread level**. In this example, your focus is solely on defining the block-level logic. During compilation, TileLang automatically maps computations to the corresponding threads and applies further optimizations. The optimized code generated by TileLang may closely align with carefully handcrafted computational logic, as demonstrated in Section 2 with a concrete example. While TileLang also supports thread-level programming semantics, this will be covered in subsequent discussions.
The program can be compiled using the following code:
```python
program = elementwise_add(1024, threads=256, dtype="bfloat16")
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
```
Launching the kernel is straightforward, just call it directly like a function:
```python
C = kernel(A, B)
```
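For completeness, here is a minimal end-to-end sketch with PyTorch tensors (assuming a CUDA device; the tensor names are illustrative):
```python
import torch

A = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
B = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
C = kernel(A, B)                      # out_idx=-1 makes the kernel return C
torch.testing.assert_close(C, A + B)  # validate against the PyTorch reference
```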
The vector add operation can also be extended to two-dimensional cases, where both implementations demonstrate comparable efficiency in practice. Below is an example from the test section that readers can refer to: [example](https://github.com/tile-ai/tilelang/blob/main/testing/python/kernel/test_tilelang_kernel_element_wise_add.py). The code for this kernel is provided below:
```python
import tilelang.language as T
def elementwise_add(
M,
N,
block_M,
block_N,
in_dtype,
out_dtype,
threads,
):
@T.prim_func
def main(
A: T.Tensor((M, N), in_dtype),
B: T.Tensor((M, N), in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
start_x = bx * block_N
start_y = by * block_M
for (local_y, local_x) in T.Parallel(block_M, block_N):
y = start_y + local_y
x = start_x + local_x
C[y, x] = A[y, x] + B[y, x]
return main
```
### How to compile operators with dynamic shapes?
In the compilation process above, a fixed shape was used. However, in practical usage, we often want the kernel to support dynamic shapes. So, how can we compile a kernel in TileLang to handle dynamic shapes? In TileLang, we can replace the target size with a dynamic symbolic value, making the dimension dynamic. The following example illustrates this:
```python
program = elementwise_add(T.dynamic("N"), threads=256, dtype="bfloat16")
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
```
The resulting CUDA code for the kernel will include an additional `int N` parameter after the `bfloat16_t* __restrict__ A`, `bfloat16_t* __restrict__ B`, and `bfloat16_t* __restrict__ C` parameters.
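Once compiled this way, the same kernel object can be invoked with inputs of different lengths; a hedged sketch:
```python
import torch

for n in (512, 1024, 4096):
    A = torch.randn(n, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(n, device="cuda", dtype=torch.bfloat16)
    C = kernel(A, B)  # the runtime value of N is taken from the tensor shapes
    torch.testing.assert_close(C, A + B)
```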
### How TileLang addresses boundary-related issues.
TileLang automatically incorporates boundary-checking conditions; however, this comes at a cost. These boundary conditions may prevent TileLang from performing more advanced optimizations. I will introduce an example from the next section in advance. The corresponding code is also provided below, but note that it involves the associated CUDA code. Readers are encouraged to first review the next section before returning to this paragraph for a clearer understanding.
When compiling the example below, let's set `N` to 2047:
```python
def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"):
@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
with T.Kernel(T.ceildiv(N, threads * num_per_thread), threads=threads) as (b_x):
# vector add.
for i, j in T.Parallel(threads, num_per_thread):
offsets = (b_x * threads + i) * num_per_thread
C[offsets + j] = A[offsets + j] + B[offsets + j]
return main
```
TileLang will generate the following CUDA code:
```c++
extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) {
#pragma unroll
for (int i = 0; i < 8; ++i) {
if (((i * 256) + ((int)threadIdx.x)) < 2047) {
C[((i * 256) + ((int)threadIdx.x))] = (A[((i * 256) + ((int)threadIdx.x))] + B[((i * 256) + ((int)threadIdx.x))]);
}
}
}
```
We can observe that TileLang did not apply optimizations such as vectorization or coalesced memory access. In fact, except for the tail group of data, all other threads could have executed more optimized code.
## Comparison of TileLang, CUDA, and CuTe
For the subsequent examples, this tutorial will use the vector add operation for simplicity and brevity.
Those new to CUDA programming often write code in a style similar to this:
```c++
// vector add
__global__ void elementwise_add(float* a, float* b, float* c, int N) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < N) {
c[idx] = a[idx] + b[idx];
}
}
```
The code above assigns each thread to compute a single element, which is evidently inefficient since common acceleration techniques like coalesced memory access and vectorization are not utilized. However, TileLang code written with similar logic (e.g., loop-based traversal) can be optimized by the compiler into highly efficient implementations, making it more accessible for beginners. Additionally, the final generated code from the compiler remains observable, providing transparency into the optimization process.
The CUDA code generated by TileLang for the compiled kernel can be retrieved using the `kernel.get_kernel_source()` method. Below is the CUDA code produced for the vector addition example from Section 1:
```c++
extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) {
if (((int)threadIdx.x) < 32) {
uint4 __1;
uint4 v_ = *(uint4*)(A + ((((int)blockIdx.x) * 256) + (((int)threadIdx.x) * 8)));
uint4 v__1 = *(uint4*)(B + ((((int)blockIdx.x) * 256) + (((int)threadIdx.x) * 8)));
((nv_bfloat162*)(&(__1.x)))->x = (((nv_bfloat162*)(&(v_.x)))->x+((nv_bfloat162*)(&(v__1.x)))->x);
((nv_bfloat162*)(&(__1.x)))->y = (((nv_bfloat162*)(&(v_.x)))->y+((nv_bfloat162*)(&(v__1.x)))->y);
((nv_bfloat162*)(&(__1.y)))->x = (((nv_bfloat162*)(&(v_.y)))->x+((nv_bfloat162*)(&(v__1.y)))->x);
((nv_bfloat162*)(&(__1.y)))->y = (((nv_bfloat162*)(&(v_.y)))->y+((nv_bfloat162*)(&(v__1.y)))->y);
((nv_bfloat162*)(&(__1.z)))->x = (((nv_bfloat162*)(&(v_.z)))->x+((nv_bfloat162*)(&(v__1.z)))->x);
((nv_bfloat162*)(&(__1.z)))->y = (((nv_bfloat162*)(&(v_.z)))->y+((nv_bfloat162*)(&(v__1.z)))->y);
((nv_bfloat162*)(&(__1.w)))->x = (((nv_bfloat162*)(&(v_.w)))->x+((nv_bfloat162*)(&(v__1.w)))->x);
((nv_bfloat162*)(&(__1.w)))->y = (((nv_bfloat162*)(&(v_.w)))->y+((nv_bfloat162*)(&(v__1.w)))->y);
*(uint4*)(C + ((((int)blockIdx.x) * 256) + (((int)threadIdx.x) * 8))) = __1;
}
}
```
In the code above, TileLang not only automatically maps block-level parallelism to threads but also applies optimizations such as vectorization and coalesced memory access.
While TileLang incorporates various optimizations for the aforementioned case, its behavior may sometimes appear counterintuitive. For example, when targeting 256 threads for task processing, applying vectorization can result in each thread computing 8 data elements—effectively utilizing only 32 active threads. Interestingly, the kernel launch configuration still retains the original allocation of 256 threads.
In such scenarios, explicitly specifying the number of elements computed per thread can help "guide" TileLang's code generation process, leading to implementations that are more closely aligned with the intended design.
```python
def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"):
@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
with T.Kernel(T.ceildiv(N, threads * num_per_thread), threads=threads) as (b_x):
# vector add.
for i, j in T.Parallel(threads, num_per_thread):
offsets = (b_x * threads + i) * num_per_thread
C[offsets + j] = A[offsets + j] + B[offsets + j]
return main
```
The corresponding CUDA code generated for the above example is presented below:
```c++
extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) {
uint4 __1;
uint4 v_ = *(uint4*)(A + (((int)threadIdx.x) * 8));
uint4 v__1 = *(uint4*)(B + (((int)threadIdx.x) * 8));
((nv_bfloat162*)(&(__1.x)))->x = (((nv_bfloat162*)(&(v_.x)))->x+((nv_bfloat162*)(&(v__1.x)))->x);
((nv_bfloat162*)(&(__1.x)))->y = (((nv_bfloat162*)(&(v_.x)))->y+((nv_bfloat162*)(&(v__1.x)))->y);
((nv_bfloat162*)(&(__1.y)))->x = (((nv_bfloat162*)(&(v_.y)))->x+((nv_bfloat162*)(&(v__1.y)))->x);
((nv_bfloat162*)(&(__1.y)))->y = (((nv_bfloat162*)(&(v_.y)))->y+((nv_bfloat162*)(&(v__1.y)))->y);
((nv_bfloat162*)(&(__1.z)))->x = (((nv_bfloat162*)(&(v_.z)))->x+((nv_bfloat162*)(&(v__1.z)))->x);
((nv_bfloat162*)(&(__1.z)))->y = (((nv_bfloat162*)(&(v_.z)))->y+((nv_bfloat162*)(&(v__1.z)))->y);
((nv_bfloat162*)(&(__1.w)))->x = (((nv_bfloat162*)(&(v_.w)))->x+((nv_bfloat162*)(&(v__1.w)))->x);
((nv_bfloat162*)(&(__1.w)))->y = (((nv_bfloat162*)(&(v_.w)))->y+((nv_bfloat162*)(&(v__1.w)))->y);
*(uint4*)(C + (((int)threadIdx.x) * 8)) = __1;
}
```
Aha, this CUDA code aligns closely with conventional programming practices, making it more familiar and intuitive.
But what happens if we give TileLang additional hints, for instance by explicitly staging data through registers with the `T.copy(...)` operation? The example below demonstrates a vector addition implementation that, unlike the previous examples, explicitly loads data into registers before performing computations.
```python
def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype="bfloat16"):
@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
with T.Kernel(T.ceildiv(N, threads * NUM_ELE_PER_THREAD), threads=threads) as (b_x):
A_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype)
B_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype)
C_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype)
s_start = b_x * threads * NUM_ELE_PER_THREAD
s_end = (b_x + 1) * threads * NUM_ELE_PER_THREAD
# LDG. 128
T.copy(
A[s_start:s_end],
A_register,
)
T.copy(
B[s_start:s_end],
B_register,
)
# vector add.
for tid, i in T.Parallel(threads, NUM_ELE_PER_THREAD):
C_register[tid * NUM_ELE_PER_THREAD + i] = (
A_register[tid * NUM_ELE_PER_THREAD + i] +
B_register[tid * NUM_ELE_PER_THREAD + i])
# STG. 128
T.copy(
C_register,
C[s_start:s_end],
)
return main
```
In the example above, each thread is responsible for computing 8 elements. The `T.copy(...)` method functions at the block level, and TileLang automatically maps data movement operations to individual threads. This design may resonate more intuitively with CUDA developers. Let us now analyze the CUDA code generated from this implementation.
```c++
// N is set to 8192 * 8192 when compiling
extern "C" __global__ void __launch_bounds__(256) main_kernel(bfloat16_t* __restrict__ A, bfloat16_t* __restrict__ B, bfloat16_t* __restrict__ C) {
bfloat16_t A_register[8];
bfloat16_t B_register[8];
*(uint4*)(A_register + 0) = *(uint4*)(A + ((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 8)));
*(uint4*)(B_register + 0) = *(uint4*)(B + ((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 8)));
uint4 __1;
uint4 v_ = *(uint4*)(A_register + 0);
uint4 v__1 = *(uint4*)(B_register + 0);
((nv_bfloat162*)(&(__1.x)))->x = (((nv_bfloat162*)(&(v_.x)))->x+((nv_bfloat162*)(&(v__1.x)))->x);
((nv_bfloat162*)(&(__1.x)))->y = (((nv_bfloat162*)(&(v_.x)))->y+((nv_bfloat162*)(&(v__1.x)))->y);
((nv_bfloat162*)(&(__1.y)))->x = (((nv_bfloat162*)(&(v_.y)))->x+((nv_bfloat162*)(&(v__1.y)))->x);
((nv_bfloat162*)(&(__1.y)))->y = (((nv_bfloat162*)(&(v_.y)))->y+((nv_bfloat162*)(&(v__1.y)))->y);
((nv_bfloat162*)(&(__1.z)))->x = (((nv_bfloat162*)(&(v_.z)))->x+((nv_bfloat162*)(&(v__1.z)))->x);
((nv_bfloat162*)(&(__1.z)))->y = (((nv_bfloat162*)(&(v_.z)))->y+((nv_bfloat162*)(&(v__1.z)))->y);
((nv_bfloat162*)(&(__1.w)))->x = (((nv_bfloat162*)(&(v_.w)))->x+((nv_bfloat162*)(&(v__1.w)))->x);
((nv_bfloat162*)(&(__1.w)))->y = (((nv_bfloat162*)(&(v_.w)))->y+((nv_bfloat162*)(&(v__1.w)))->y);
*(uint4*)(A_register + 0) = __1;
*(uint4*)(C + ((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 8))) = *(uint4*)(A_register + 0);
}
```
We observe the emergence of two additional register arrays, `A_register` and `B_register`. However, during the actual computation, their contents are simply copied into `v_` and `v__1`, respectively.
To evaluate complexity, one could implement the same elementwise addition operator using CuTe and compare it with the TileLang version. The corresponding CuTe code is provided below:
```c++
template<int NUM_ELE_PER_THREAD=8>
__global__ void elementwise_add(nv_bfloat16* C,
const nv_bfloat16* A,
const nv_bfloat16* B,
int N) {
using namespace cute;
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
Tensor t_C = make_tensor(make_gmem_ptr(C), make_shape(N));
Tensor t_A = make_tensor(make_gmem_ptr(A), make_shape(N));
Tensor t_B = make_tensor(make_gmem_ptr(B), make_shape(N));
Tensor t_C_tile = local_tile(t_C, make_shape(Int<NUM_ELE_PER_THREAD>{}), make_coord(idx));
Tensor t_A_tile = local_tile(t_A, make_shape(Int<NUM_ELE_PER_THREAD>{}), make_coord(idx));
Tensor t_B_tile = local_tile(t_B, make_shape(Int<NUM_ELE_PER_THREAD>{}), make_coord(idx));
Tensor reg_buffer_A = make_tensor_like(t_A_tile);
Tensor reg_buffer_B = make_tensor_like(t_B_tile);
Tensor reg_buffer_C = make_tensor_like(t_C_tile);
// LDG. 128
copy(t_A_tile, reg_buffer_A);
copy(t_B_tile, reg_buffer_B);
auto reg_C_vector = recast<nv_bfloat162>(reg_buffer_C);
auto reg_A_vector = recast<nv_bfloat162>(reg_buffer_A);
auto reg_B_vector = recast<nv_bfloat162>(reg_buffer_B);
// Perform vectorized addition
#pragma unroll
for (int vec_idx = 0; vec_idx < size(reg_C_vector); ++vec_idx) {
reg_C_vector(vec_idx) = reg_A_vector(vec_idx) + reg_B_vector(vec_idx);
}
auto reg_C_flat = recast<nv_bfloat16>(reg_C_vector);
// STG. 128
copy(reg_C_flat, t_C_tile);
}
```
## Conclusion
This tutorial showcases the implementation of the elementwise addition operator using TileLang, while also comparing various design approaches. TileLang significantly reduces the complexity of CUDA programming, enabling high performance with minimal code. Nevertheless, working with TileLang demands careful attention to specific implementation details. To ensure computational efficiency, it is essential to thoroughly examine the generated CUDA kernels.
---
**Reference:**
[1] The CuTe code implementation draws inspiration from the techniques discussed in this blog: https://zhuanlan.zhihu.com/p/690703999
# General Matrix-Vector Multiplication (GEMV)
<div style="text-align: left;">
<em>Contributor: </em> <a href="https://github.com/botbw">@botbw</a>
</div>
:::{warning}
This document is still **experimental** and may be incomplete.
Suggestions and improvements are highly encouraged—please submit a PR!
:::
:::{tip}
Example code can be found at `examples/gemv/example_gemv.py`.
:::
General matrix-vector multiplication (GEMV) can be viewed as a specialized case of general matrix-matrix multiplication (GEMM). It plays a critical role in deep learning, especially during the inference phase of large language models. In this tutorial, we will optimize GEMV from a thread-level perspective step by step using `TileLang`.
## Triton Implementation
When implementing a GEMV kernel, you might start with a high-level approach using a tool like `Triton`.
A simple Triton kernel for GEMV might look like this:
```python
@triton.jit
def _gemv_naive(
x_ptr, A_ptr, y_ptr,
N, K,
BLOCK_SIZE_K: tl.constexpr,
):
n = tl.program_id(0)
offs_k = tl.arange(0, BLOCK_SIZE_K)
mask = offs_k < K
a_ptrs = A_ptr + n * K + offs_k
a_vals = tl.load(a_ptrs, mask=mask, other=0.0)
x_vals = tl.load(x_ptr + offs_k, mask=mask, other=0.0)
dot = tl.sum(a_vals * x_vals, axis=0)
tl.store(y_ptr + n, dot)
```
`Triton` is straightforward to use, as it operates at the block level. However, this approach may not allow for fine-grained thread-level optimization. In this tutorial, we will demonstrate how to write an optimized GEMV kernel in `TileLang` that exposes more low-level control.
## Naive Implementation in TileLang
If you have a basic understanding of CUDA C, it is natural to start with a naive GEMV kernel by adapting a GEMM tiling strategy. You can think of GEMV as a `(1, k) * (k, n)` GEMM. Below is a simple example:
```python
def naive_gemv(
N: int,
K: int,
BLOCK_N: int,
BLOCK_K: int,
dtype: str = "float16",
accum_dtype: str = "float",
):
@T.prim_func
def main(
A: T.Buffer((K,), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((N,), dtype),
):
with T.Kernel(T.ceildiv(N, BLOCK_N)) as bn:
tn = T.get_thread_binding(0) # tn = threadIdx.x
A_shared = T.alloc_shared((BLOCK_K,), dtype)
B_shared = T.alloc_shared((BLOCK_N, BLOCK_K), dtype)
C_reg = T.alloc_local((1,), accum_dtype)
T.clear(C_reg)
for bk in T.serial(T.ceildiv(K, BLOCK_K)):
for tk in T.serial(BLOCK_K):
A_shared[tk] = A[bk * BLOCK_K + tk]
B_shared[tn, tk] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk]
for tk in T.serial(BLOCK_K):
C_reg[0] += A_shared[tk].astype(accum_dtype) * B_shared[tn,
tk].astype(accum_dtype)
C[bn * BLOCK_N + tn] = C_reg[0]
return main
```
And your kernel will be compiled into CUDA by `TileLang` (in `~/.tilelang/cache`):
```C++
extern "C" __global__ void __launch_bounds__(256, 1) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) {
extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
float C_reg[1];
__shared__ uint64_t _mbarrier[2];
if (((int)threadIdx.x) == 0) {
tl::mbarrier_init(_mbarrier[0], 128);
tl::mbarrier_init(_mbarrier[1], 128);
}
__syncthreads();
if (128 <= ((int)threadIdx.x)) {
tl::warpgroup_reg_dealloc<24>();
for (int bk = 0; bk < 8; ++bk) {
tl::mbarrier_wait(_mbarrier[1], ((bk & 1) ^ 1));
for (int tk = 0; tk < 128; ++tk) {
((half_t*)buf_dyn_shmem)[tk] = A[((bk * 128) + tk)];
((half_t*)buf_dyn_shmem)[(((((int)threadIdx.x) * 128) + tk) - 16256)] = B[(((((((int)blockIdx.x) * 131072) + (((int)threadIdx.x) * 1024)) + (bk * 128)) + tk) - 131072)];
}
tl::fence_proxy_async();
tl::mbarrier_cp_async_arrive(_mbarrier[0]);
tl::mbarrier_arrive(_mbarrier[0]);
}
} else {
tl::warpgroup_reg_alloc<240>();
C_reg[0] = 0.000000e+00f;
for (int bk_1 = 0; bk_1 < 8; ++bk_1) {
tl::mbarrier_wait(_mbarrier[0], (bk_1 & 1));
for (int tk_1 = 0; tk_1 < 128; ++tk_1) {
C_reg[0] = (C_reg[0] + (((float)((half_t*)buf_dyn_shmem)[tk_1]) * ((float)((half_t*)buf_dyn_shmem)[(((((int)threadIdx.x) * 128) + tk_1) + 128)])));
}
tl::fence_proxy_async();
tl::mbarrier_arrive(_mbarrier[1]);
}
C[((((int)blockIdx.x) * 128) + ((int)threadIdx.x))] = ((half_t)C_reg[0]);
}
}
```
In this design, the first 128 threads act as the data producer and the last 128 threads as the consumer within a block (assuming a 1D block).
At this point we extract very little of the GPU's compute capability: the kernel takes around **~0.17 ms**, roughly 20x slower than torch/cuBLAS at **~0.008 ms**.
## More Concurrency
To further increase the concurrency of our kernel, we can exploit finer thread-level parallelism. Instead of assigning each thread to compute a single output element in C, you can introduce parallelism along the K dimension. Each thread computes a partial accumulation, and you then combine these partial results. This approach requires primitives like `atomicAdd` in CUDA.
Here’s a simplified version:
```python
def naive_splitk_gemv(
N: int,
K: int,
BLOCK_N: int,
BLOCK_K: int,
dtype: str = "float16",
accum_dtype: str = "float",
):
@T.prim_func
def main(
A: T.Buffer((K,), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((N,), dtype),
):
with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, BLOCK_K)) as bn:
tn = T.get_thread_binding(0)
tk = T.get_thread_binding(1)
A_local = T.alloc_local((1,), dtype)
B_local = T.alloc_local((1,), dtype)
C_accum = T.alloc_local((1,), accum_dtype)
C_shared = T.alloc_shared((BLOCK_N,), accum_dtype)
if tk == 0:
C_shared[tn] = 0
T.clear(C_accum)
for bk in T.serial(T.ceildiv(K, BLOCK_K)):
A_local[0] = A[bk * BLOCK_K + tk]
B_local[0] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk]
C_accum[0] += A_local[0].astype(accum_dtype) * B_local[0].astype(accum_dtype)
T.atomic_add(C_shared[tn], C_accum[0])
C[bn * BLOCK_N + tn] = C_shared[tn]
return main
```
By introducing parallelism along the K dimension, our kernel now achieves **~0.024 ms**: an improvement, but still not on par with torch/cuBLAS.
### Customizing Parallelism in K Dimension
If your K dimension is large, you can further customize how many elements each thread processes by introducing a `reduce_threads` parameter. This way, each thread handles multiple elements per iteration:
```python
def splitk_gemv(
N: int,
K: int,
BLOCK_N: int,
BLOCK_K: int,
reduce_threads: int,
dtype: str = "float16",
accum_dtype: str = "float",
):
TILE_K = T.ceildiv(BLOCK_K, reduce_threads)
@T.prim_func
def main(
A: T.Buffer((K,), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((N,), dtype),
):
with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
tn = T.get_thread_binding(0)
tk = T.get_thread_binding(1)
A_local = T.alloc_local((TILE_K,), dtype)
B_local = T.alloc_local((TILE_K,), dtype)
C_shared = T.alloc_shared((BLOCK_N,), accum_dtype)
C_accum = T.alloc_local((1,), accum_dtype)
if tk == 0:
C_shared[tn] = 0
T.clear(C_accum)
for bk in T.serial(T.ceildiv(K, BLOCK_K)):
for k in T.serial(TILE_K):
A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
for k in T.serial(TILE_K):
C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
T.atomic_add(C_shared[tn], C_accum[0])
C[bn * BLOCK_N + tn] = C_shared[tn]
return main
```
## Vectorized Reads
GEMV is less compute-intensive than GEMM; its arithmetic intensity is low, so memory throughput becomes the optimization bottleneck. One effective strategy is to use vectorized load/store operations (e.g., `float2`, `float4`). In `TileLang`, you can specify vectorized operations via `T.vectorized`:
```python
def splitk_gemv_vectorized(
N: int,
K: int,
BLOCK_N: int,
reduce_threads: int,
dtype: str = "float16",
accum_dtype: str = "float",
):
MAX_TRANSACTION_SIZE_IN_BITS = 128
TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
BLOCK_K = reduce_threads * TILE_K
@T.prim_func
def main(
A: T.Buffer((K,), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((N,), dtype),
):
with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
tn = T.get_thread_binding(0)
tk = T.get_thread_binding(1)
A_local = T.alloc_local((TILE_K,), dtype)
B_local = T.alloc_local((TILE_K,), dtype)
C_shared = T.alloc_shared((BLOCK_N,), accum_dtype)
C_accum = T.alloc_local((1,), accum_dtype)
if tk == 0:
C_shared[tn] = 0
T.clear(C_accum)
for bk in T.serial(T.ceildiv(K, BLOCK_K)):
for k in T.vectorized(TILE_K):
A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
for k in T.serial(TILE_K):
C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
T.atomic_add(C_shared[tn], C_accum[0])
C[bn * BLOCK_N + tn] = C_shared[tn]
return main
```
With vectorized reads, the kernel now finishes in **~0.0084 ms**, which is close to cuBLAS performance.
## `tvm_thread_allreduce` Instead of `atomicAdd`
[`tvm_thread_allreduce`](https://tvm.apache.org/docs/reference/api/python/tir/tir.html#tvm.tir.tvm_thread_allreduce) implements an optimized all-reduce across a group of threads, which should outperform our plain shared-memory + `atomicAdd` approach:
```python
def splitk_gemv_vectorized_tvm(
N: int,
K: int,
BLOCK_N: int,
reduce_threads: int,
dtype: str = "float16",
accum_dtype: str = "float",
):
MAX_TRANSACTION_SIZE_IN_BITS = 128
TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
BLOCK_K = reduce_threads * TILE_K
@T.prim_func
def main(
A: T.Buffer((K,), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((N,), dtype),
):
with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
tn = T.get_thread_binding(0)
tk = T.get_thread_binding(1)
A_local = T.alloc_local((TILE_K,), dtype)
B_local = T.alloc_local((TILE_K,), dtype)
C_accum = T.alloc_local((1,), accum_dtype)
T.clear(C_accum)
for bk in T.serial(T.ceildiv(K, BLOCK_K)):
for k in T.vectorized(TILE_K):
A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
for k in T.serial(TILE_K):
C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
C_reduced = T.alloc_local((1,), accum_dtype)
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
C_accum[0],
True,
C_reduced[0],
tk,
dtype="handle",
))
C[bn * BLOCK_N + tn] = C_reduced[0]
return main
```
With this optimization, the kernel latency now reduces from **~0.0084 ms** to **~0.0069 ms**, which is faster than torch/cuBLAS!
## Autotune
`BLOCK_N`, `BLOCK_K`, `reduce_threads` are hyperparameters in our kernel, which can be tuned to improve performance. We can use the `tilelang.autotune` feature to automatically search for optimal configurations:
```python
def get_best_config(N, K):
def get_configs():
BLOCK_N = [2, 4, 8, 32, 64, 128]
reduce_threads = [4, 8, 32]
_configs = list(itertools.product(
BLOCK_N,
reduce_threads,
))
configs = [{
"BLOCK_N": c[0],
"reduce_threads": c[1],
} for c in _configs]
return configs
@autotune(
configs=get_configs(),
warmup=3,
rep=20,
)
@jit(
out_idx=[-1],
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=False,
target="auto",
)
def kernel(
BLOCK_N=None,
reduce_threads=None,
):
dtype = "float16"
accum_dtype = "float"
MAX_TRANSACTION_SIZE_IN_BITS = 128
TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
BLOCK_K = reduce_threads * TILE_K
@T.prim_func
def main(
A: T.Buffer((K,), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((N,), dtype),
):
with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
tn = T.get_thread_binding(0)
tk = T.get_thread_binding(1)
A_local = T.alloc_local((TILE_K,), dtype)
B_local = T.alloc_local((TILE_K,), dtype)
C_accum = T.alloc_local((1,), accum_dtype)
T.clear(C_accum)
for bk in T.serial(T.ceildiv(K, BLOCK_K)):
for k in T.vectorized(TILE_K):
A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
for k in T.serial(TILE_K):
C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
C_reduced = T.alloc_local((1,), accum_dtype)
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
C_accum[0],
True,
C_reduced[0],
tk,
dtype="handle",
))
C[bn * BLOCK_N + tn] = C_reduced[0]
return main
return kernel()
```
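The autotuner validates each candidate against `ref_program`, which is not shown above; a minimal sketch in PyTorch, assuming `A` is the `(K,)` vector and `B` the `(N, K)` matrix as in the kernels in this tutorial:
```python
import torch

def ref_program(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # C[n] = sum_k B[n, k] * A[k], accumulated in float32 and cast back to fp16
    return (B.float() @ A.float()).to(A.dtype)
```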
After autotuning, our kernel now achieves **~0.0067 ms**. The final generated CUDA kernel may look like this:
```C++
extern "C" __global__ void __launch_bounds__(64, 1) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) {
float C_accum[1];
half_t A_local[8];
half_t B_local[8];
__shared__ float red_buf0[64];
C_accum[0] = 0.000000e+00f;
for (int bk = 0; bk < 4; ++bk) {
*(uint4*)(A_local + 0) = *(uint4*)(A + ((bk * 256) + (((int)threadIdx.y) * 8)));
*(uint4*)(B_local + 0) = *(uint4*)(B + ((((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 1024)) + (bk * 256)) + (((int)threadIdx.y) * 8)));
for (int k = 0; k < 8; ++k) {
C_accum[0] = (C_accum[0] + (((float)A_local[k]) * ((float)B_local[k])));
}
}
tl::fence_proxy_async();
__syncthreads();
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = C_accum[0];
__syncthreads();
if (((int)threadIdx.y) < 16) {
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 16)]);
}
__syncthreads();
if (((int)threadIdx.y) < 8) {
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 8)]);
}
__syncthreads();
if (((int)threadIdx.y) < 4) {
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 4)]);
}
__syncthreads();
if (((int)threadIdx.y) < 2) {
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 2)]);
}
__syncthreads();
if (((int)threadIdx.y) < 1) {
red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] = (red_buf0[((((int)threadIdx.x) * 32) + ((int)threadIdx.y))] + red_buf0[(((((int)threadIdx.x) * 32) + ((int)threadIdx.y)) + 1)]);
}
__syncthreads();
C[((((int)blockIdx.x) * 2) + ((int)threadIdx.x))] = ((half_t)red_buf0[(((int)threadIdx.x) * 32)]);
}
```
This corresponds closely to our `TileLang` program, with necessary synchronization and low-level optimizations inserted automatically.
## Conclusion
### Benchmark Table on Hopper GPU
| Kernel Name | Latency |
|------------|------------|
| torch/cuBLAS | 0.00784 ms |
| Triton | 0.00773 ms |
| naive_gemv | 0.16607 ms |
| splitk_gemv | 0.02419 ms |
| splitk_gemv_vectorized | 0.00809 ms |
| splitk_gemv_vectorized_tvm | 0.00675 ms |
In this tutorial, we implemented a simple GEMV kernel and learned that `TileLang` exposes low-level control to users, such as thread-level programming and CUDA primitives.
# General Matrix-Matrix Multiplication with Tile Library
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/LeiWang1999">Lei Wang</a>
</div>
:::{warning}
:class: myclass1 myclass2
:name: a-tip-reference
This document is still **experimental** and may be incomplete.
Suggestions and improvements are highly encouraged—please submit a PR!
:::
TileLang is a domain-specific language (DSL) designed for writing high-performance GPU kernels. It provides three main levels of abstraction:
* **Level 1:** A user writes pure compute logic without knowledge of or concern for hardware details (e.g., GPU caches, tiling, etc.). The compiler or runtime performs automatic scheduling and optimization. This level is conceptually similar to the idea behind TVM.
* **Level 2:** A user is aware of GPU architecture concepts—such as shared memory, tiling, and thread blocks—but does not necessarily want to drop down to the lowest level of explicit thread control. This mode is somewhat comparable to Triton's programming model, where you can write tile-level operations and let the compiler do layout inference, pipelining, etc.
* **Level 3:** A user takes full control of thread-level primitives and can write code that is almost as explicit as a hand-written CUDA/HIP kernel. This is useful for performance experts who need to manage every detail, such as PTX inline assembly, explicit thread behavior, etc.
```{figure} ../_static/img/overview.png
:width: 50%
:alt: Overview
:align: center
Figure 1: High-level overview of the TileLang compilation flow.
```
In this tutorial, we introduce Level 2 with a matrix multiplication example in TileLang. We will walk through how to allocate shared memory, set up thread blocks, perform parallel copying, pipeline the computation, and invoke the tile-level GEMM intrinsic. We will then show how to compile and run the kernel in Python, comparing results and measuring performance.
## Why Another GPU DSL?
TileLang emerged from the need for a DSL that:
1. Balances high-level expressiveness (like TVM or Triton) with enough flexibility to control finer details when needed.
2. Supports efficient code generation and scheduling for diverse hardware backends (NVIDIA GPUs, AMD GPUs, CPU, etc.).
3. Simplifies scheduling and memory pipelines with built-in primitives (such as `T.Pipelined`, `T.Parallel`, `T.gemm`), yet retains options for expert-level tuning.
While Level 1 in TileLang can be very comfortable for general users—since it requires no scheduling or hardware-specific knowledge—it can incur longer auto-tuning times and may not handle some complex kernel fusion patterns (e.g., Flash Attention) as easily. Level 3 gives you full control but demands more effort, similar to writing raw CUDA/HIP kernels. Level 2 thus strikes a balance for users who want to write portable and reasonably concise code while expressing important architectural hints.
## Matrix Multiplication Example
```{figure} ../_static/img/MatmulExample.png
:alt: Matmul Example
:align: center
```
### Basic Structure
Below is a simplified code snippet for a 1024 x 1024 x 1024 matrix multiplication. It uses:
* **`T.Kernel(...)`** to initialize the thread block configuration (grid dimensions, block size, etc.).
* **`T.alloc_shared(...)`** to allocate GPU shared memory.
* **`T.alloc_fragment(...)`** to allocate a register fragment for accumulation.
* **`T.Pipelined(...)`** to express software pipelining across the K dimension.
* **`T.Parallel(...)`** to parallelize data copy loops.
* **`T.gemm(...)`** to perform tile-level GEMM operations (which map to the appropriate backends, such as MMA instructions on NVIDIA GPUs).
```python
import tilelang
import tilelang.language as T
from tilelang.intrinsics import make_mma_swizzle_layout
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Optional layout hints (commented out by default)
# T.annotate_layout({
# A_shared: make_mma_swizzle_layout(A_shared),
# B_shared: make_mma_swizzle_layout(B_shared),
# })
# Optional: Enabling swizzle-based rasterization
# T.use_swizzle(panel_size=10, enable=True)
# Clear local accumulation
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy tile of A from global to shared memory
T.copy(A[by * block_M, ko * block_K], A_shared)
# Parallel copy tile of B from global to shared memory
for k, j in T.Parallel(block_K, block_N):
B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
# Perform a tile-level GEMM
T.gemm(A_shared, B_shared, C_local)
# Copy result from local (register fragment) to global memory
T.copy(C_local, C[by * block_M, bx * block_N])
return main
# 1. Create the TileLang function
func = matmul(1024, 1024, 1024, 128, 128, 32)
# 2. JIT-compile the kernel for NVIDIA GPU
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
import torch
# 3. Prepare input tensors in PyTorch
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
# 4. Invoke the JIT-compiled kernel
c = jit_kernel(a, b)
ref_c = a @ b
# 5. Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 6. Inspect generated CUDA code (optional)
cuda_source = jit_kernel.get_kernel_source()
print("Generated CUDA kernel:\n", cuda_source)
# 7. Profile performance
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
```
### Key Concepts
1. **Kernel Context**:
```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
...
```
- This sets up the block grid dimensions based on N/block_N and M/block_M.
- `threads=128` specifies that each thread block uses 128 threads. The compiler will infer how loops map to these threads.
```{figure} ../_static/img/Parallel.png
:alt: Parallel
:align: center
```
2. **Shared & Fragment Memory**:
```python
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
```
- `T.alloc_shared` allocates shared memory across the entire thread block.
- `T.alloc_fragment` allocates register space for local accumulation. Though it is written as `(block_M, block_N)`, the compiler’s layout inference assigns slices of this space to each thread.
3. **Software Pipelining**:
```python
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
...
```
- `T.Pipelined` automatically arranges asynchronous copy and compute instructions to overlap memory operations with arithmetic.
- The argument `num_stages=3` indicates the pipeline depth.
```{figure} ../_static/img/software_pipeline_inference.png
:alt: Software Pipeline Inference
:align: center
```
4. **Parallel Copy**:
```python
for k, j in T.Parallel(block_K, block_N):
B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
```
- `T.Parallel` marks the loop for thread-level parallelization.
- The compiler will map these loops to the available threads in the block.
5. **Tile-Level GEMM**:
```python
T.gemm(A_shared, B_shared, C_local)
```
- A single call that performs a tile-level matrix multiplication using the specified buffers.
- Under the hood, for NVIDIA targets, it can use CUTLASS/Cute or WMMA instructions. On AMD GPUs, TileLang uses a separate HIP or composable kernel approach.
6. **Copying Back Results**:
```python
T.copy(C_local, C[by * block_M, bx * block_N])
```
- After computation, data in the local register fragment is written back to global memory.
## Comparison with Other DSLs
TileLang Level 2 is conceptually similar to Triton in that the user can control tiling and parallelization, while letting the compiler handle many low-level details. However, TileLang also:
- Allows explicit memory layout annotations (e.g. `make_mma_swizzle_layout`).
- Supports a flexible pipeline pass (`T.Pipelined`) that can be automatically inferred or manually defined.
- Enables mixing different levels in a single program—for example, you can write some parts of your kernel in Level 3 (thread primitives) for fine-grained PTX/inline-assembly and keep the rest in Level 2.
## Performance on Different Platforms
```{figure} ../_static/img/op_benchmark_consistent_gemm_fp16.png
:alt: Performance on Different Platforms
:align: center
```
When appropriately tuned (e.g., by using an auto-tuner), TileLang achieves performance comparable to or better than vendor libraries and Triton on various GPUs. In internal benchmarks, for an FP16 matrix multiply (e.g., 4090, A100, H100, MI300X), TileLang has shown:
- ~1.1x speedup over cuBLAS on RTX 4090
- ~0.97x on A100 (on par with cuBLAS)
- ~1.0x on H100
- ~1.04x on MI300X
- Compared to Triton, speedups range from 1.08x to 1.25x depending on the hardware.
These measurements will vary based on tile sizes, pipeline stages, and the hardware’s capabilities.
## Conclusion
This tutorial demonstrated a Level 2 TileLang kernel for matrix multiplication. With just a few lines of code:
1. We allocated shared memory and register fragments.
2. We pipelined the loading and computation along the K dimension.
3. We used parallel copying to efficiently load tiles from global memory.
4. We invoked `T.gemm` to dispatch a tile-level matrix multiply.
5. We verified correctness against PyTorch and examined performance.
By balancing high-level abstractions (like `T.copy`, `T.Pipelined`, `T.gemm`) with the ability to annotate layouts or drop to thread primitives (Level 3) when needed, TileLang can be both user-friendly and highly tunable. We encourage you to experiment with tile sizes, pipeline depths, or explicit scheduling to see how performance scales across different GPUs.
For more advanced usage—including partial lowering, explicitly controlling thread primitives, or using inline assembly—you can explore Level 3. Meanwhile, for purely functional expressions and high-level scheduling auto-tuning, consider Level 1.
## Further Resources
* [TileLang GitHub](https://github.com/tile-ai/tilelang)
* [BitBLAS](https://github.com/tile-ai/bitblas)
* [Triton](https://github.com/openai/triton)
* [Cutlass](https://github.com/NVIDIA/cutlass)
* [PyCUDA](https://documen.tician.de/pycuda/) <!-- codespell:ignore -->
# Installation Guide
## Installing with pip
**Prerequisites for installation via wheel or PyPI:**
- **glibc**: 2.28 (Ubuntu 20.04 or later)
- **Python Version**: >= 3.8
- **CUDA Version**: 12.0 <= CUDA < 13
The easiest way to install **tile-lang** is directly from PyPI using pip. To install the latest version, run the following command in your terminal:
```bash
pip install tilelang
```
Alternatively, you may choose to install **tile-lang** using prebuilt packages available on the Release Page:
```bash
pip install tilelang-0.0.0.dev0+ubuntu.20.4.cu120-py3-none-any.whl
```
To install the latest version of **tile-lang** from the GitHub repository, you can run the following command:
```bash
pip install git+https://github.com/tile-ai/tilelang.git
```
After installing **tile-lang**, you can verify the installation by running:
```bash
python -c "import tilelang; print(tilelang.__version__)"
```
## Building from Source
**Prerequisites for building from source:**
- **Operating System**: Linux
- **Python Version**: >= 3.8
- **CUDA Version**: >= 10.0
```bash
docker run -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.01-py3
```
To build and install **tile-lang** directly from source, follow these steps. This process requires certain prerequisites from Apache TVM, which can be installed on Ubuntu/Debian-based systems using the following commands:
```bash
apt-get update
apt-get install -y python3 python3-dev python3-setuptools gcc zlib1g-dev build-essential cmake libedit-dev
```
After installing the prerequisites, you can clone the **tile-lang** repository and install it using pip:
```bash
git clone --recursive https://github.com/tile-ai/tilelang.git
cd tilelang
pip install . -v
```
If you want to install **tile-lang** in development mode, you can run the following command:
```bash
pip install -e . -v
```
If you prefer to work directly from the source tree via `PYTHONPATH`, make sure the native extension is built first:
```bash
mkdir -p build
cd build
cmake .. -DUSE_CUDA=ON
make -j
```
Then add the repository root to `PYTHONPATH` before importing `tilelang`, for example:
```bash
export PYTHONPATH=/path/to/tilelang:$PYTHONPATH
python -c "import tilelang; print(tilelang.__version__)"
```
Some useful CMake options you can toggle while configuring:
- `-DUSE_CUDA=ON|OFF` builds against NVIDIA CUDA (default ON when CUDA headers are found).
- `-DUSE_ROCM=ON` selects ROCm support when building on AMD GPUs.
- `-DNO_VERSION_LABEL=ON` disables the backend/git suffix in `tilelang.__version__`.
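For example, configuring an out-of-tree ROCm build without the version label might look like this (paths and flags are illustrative; adjust them to your setup):
```bash
mkdir -p build && cd build
cmake .. -DUSE_CUDA=OFF -DUSE_ROCM=ON -DNO_VERSION_LABEL=ON
make -j
```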
We currently provide three methods to install **tile-lang**:
1. [Install Using Docker](#install-method-1) (Recommended)
2. [Install from Source (using the bundled TVM submodule)](#install-method-2)
3. [Install from Source (using your own TVM installation)](#install-method-3)
(install-method-1)=
### Method 1: Install Using Docker (Recommended)
For users who prefer a containerized environment with all dependencies pre-configured, **tile-lang** provides Docker images for different CUDA versions. This method is particularly useful for ensuring consistent environments across different systems and is the **recommended approach** for most users.
**Prerequisites:**
- Docker installed on your system
- An NVIDIA Docker runtime or GPU is not required for building tilelang; you can build the image on a host without a GPU and use it on another machine.
1. **Clone the Repository**:
```bash
git clone --recursive https://github.com/tile-ai/tilelang
cd tilelang
```
2. **Build Docker Image**:
Navigate to the docker directory and build the image for your desired CUDA version:
```bash
cd docker
docker build -f Dockerfile.cu120 -t tilelang-cu120 .
```
Available Dockerfiles:
- `Dockerfile.cu120` - For CUDA 12.0
- Other CUDA versions may be available in the docker directory
3. **Run Docker Container**:
Start the container with GPU access and volume mounting:
```bash
docker run -itd \
--shm-size 32g \
--gpus all \
-v /home/tilelang:/home/tilelang \
--name tilelang_b200 \
tilelang-cu120 \
/bin/zsh
```
**Command Parameters Explanation:**
- `--shm-size 32g`: Increases shared memory size for better performance
- `--gpus all`: Enables access to all available GPUs
- `-v /home/tilelang:/home/tilelang`: Mounts host directory to container (adjust path as needed)
- `--name tilelang_b200`: Assigns a name to the container for easy management
- `/bin/zsh`: Uses zsh as the default shell
4. **Access the Container**:
```bash
docker exec -it tilelang_b200 /bin/zsh
```
5. **Verify Installation**:
Once inside the container, verify that **tile-lang** is working correctly:
```bash
python -c "import tilelang; print(tilelang.__version__)"
```
You can now run TileLang examples and develop your applications within the containerized environment. The Docker image comes with all necessary dependencies pre-installed, including CUDA toolkit, TVM, and TileLang itself.
**Example Usage:**
After accessing the container, you can run TileLang examples:
```bash
cd /home/tilelang/examples
python elementwise/test_example_elementwise.py
```
This Docker-based installation method provides a complete, isolated environment that works seamlessly on systems with compatible NVIDIA GPUs like the B200, ensuring optimal performance for TileLang applications.
(install-method-2)=
### Method 2: Install from Source (Using the Bundled TVM Submodule)
If you want to use the TVM submodule bundled with tile-lang, follow these steps:
1. **Clone the Repository**:
```bash
git clone --recursive https://github.com/tile-ai/tilelang
cd tilelang
```
**Note**: Use the `--recursive` flag to include necessary submodules.
2. **Build and Install**:
Install with pip; the bundled TVM submodule is built automatically:
```bash
pip install . -v
```
(install-method-3)=
### Method 3: Install from Source (Using Your Own TVM Installation)
If you already have a compatible TVM installation and prefer to use it, follow these instructions:
1. **Clone the Repository**:
```bash
git clone --recursive https://github.com/tile-ai/tilelang
cd tilelang
```
**Note**: Ensure the `--recursive` flag is included to fetch submodules.
2. **Configure Build Options**:
Point the build at your existing TVM checkout by setting `TVM_ROOT`:
```bash
TVM_ROOT=<your-tvm-repo> pip install . -v
```
## Install with Nightly Version
For users who want access to the latest features and improvements before official releases, we provide nightly builds of **tile-lang**.
```bash
pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu121/
# or pip install tilelang --find-links https://tile-ai.github.io/whl/nightly/cu121/
```
> **Note:** Nightly builds contain the most recent code changes but may be less stable than official releases. They're ideal for testing new features or if you need a specific bugfix that hasn't been released yet.
## Install Configs
### Build-time environment variables
`USE_CUDA`: Whether to enable CUDA support; default: `ON` on Linux. Set to `OFF` to build a CPU-only version. By default, `/usr/local/cuda` is used for building tilelang; set `CUDAToolkit_ROOT` to use a different CUDA toolkit.
`USE_ROCM`: Whether to enable ROCm support; default: `OFF`. If your ROCm SDK is not located in `/opt/rocm`, set `USE_ROCM=<rocm_sdk>` to build against a custom SDK path.
`USE_METAL`: Whether to enable Metal support; default: `ON` on Darwin.
`TVM_ROOT`: TVM source root to use.
`NO_VERSION_LABEL` and `NO_TOOLCHAIN_VERSION`:
When building tilelang, SDK and version information is embedded into the package version as shown below,
where the local version label looks like `<sdk>.git<git_hash>`. Set `NO_VERSION_LABEL=ON` to disable this behavior.
```
$ python -mbuild -w
...
Successfully built tilelang-0.1.6.post1+cu116.git0d4a74be-cp38-abi3-linux_x86_64.whl
```
Here `<sdk>` is one of `cuda`, `rocm`, or `metal`. Specifically, when `<sdk>=cuda` and `CUDA_VERSION` is provided via the environment,
the label becomes `cu<cuda_major><cuda_minor>`, similar to PyTorch's convention.
Set `NO_TOOLCHAIN_VERSION=ON` to disable this.
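For instance, to build a wheel without the local version label (a sketch based on the build command above, assuming the `build` package is installed):
```bash
NO_VERSION_LABEL=ON python -m build -w
```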
### Run-time environment variables
<!-- TODO: tvm -->
## IDE Configs
Building tilelang locally will automatically generate a `compile_commands.json` file in the `build` directory.
VSCode with clangd and [clangd extension](https://marketplace.visualstudio.com/items?itemName=llvm-vs-code-extensions.vscode-clangd) should be able to index that without extra configuration.
## Compile cache
`ccache` will be automatically used if found.
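If you want to confirm that the cache is actually being hit, inspecting `ccache` statistics after a rebuild is a quick, optional check:
```bash
# Show ccache hit/miss statistics accumulated across builds
ccache -s
```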
## Repairing wheels
If you plan to use your wheels in other environments,
it is recommended to repair them with `auditwheel` (on Linux) or `delocate` (on Darwin).
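A minimal sketch of what this can look like (wheel filenames and output directories are illustrative):
```bash
# Linux: vendor external shared libraries into the wheel
auditwheel repair dist/tilelang-*.whl -w wheelhouse/

# macOS (Darwin): copy external dylibs into the wheel
delocate-wheel -w wheelhouse/ dist/tilelang-*.whl
```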
## Faster rebuild for developers
`pip install` introduces extra packaging and unpacking steps and takes roughly 30 seconds to complete,
even when no sources have changed.
Developers who need to recompile frequently can use:
```bash
pip install -r requirements-dev.txt
pip install -e . -v --no-build-isolation
cd build; ninja
```
When running in editable/developer mode, you'll see logs like the following:
```console
$ python -c 'import tilelang'
2025-10-14 11:11:29 [TileLang:tilelang.env:WARNING]: Loading tilelang libs from dev root: /Users/yyc/repo/tilelang/build
```
# The Tile Language: A Brief Introduction
## Programming Interface
The figure below depicts how **TileLang** programs are progressively lowered from a high-level description to hardware-specific executables. We provide three different programming interfaces—targeted at **Beginner**, **Developer**, and **Expert** users—that each reside at different levels in this lowering pipeline. The **Tile Language** also allows mixing these interfaces within the same kernel, enabling users to work at whichever level of abstraction best suits their needs.
```{figure} ../_static/img/overview.png
:width: 50%
:alt: Overview
:align: center
Figure 1: High-level overview of the TileLang compilation flow.
```
## Programming Interfaces
1. **Beginner Level (Hardware-Unaware)**
- Intended for users who need to write code that is independent of specific hardware details.
- The goal is to let developers focus on the basic logic without worrying about memory hierarchies or hardware-specific optimizations.
- *Note:* This interface is not yet fully implemented.
2. **Developer Level (Hardware-Aware with Tile Library)**
- Designed for developers who have a basic understanding of GPU memory hierarchies and performance considerations.
- Provides a **Tile Library**, containing predefined operations and patterns optimized for various hardware architectures.
- Users at this level can leverage these ready-made primitives without diving into low-level threading details.
3. **Expert Level (Hardware-Aware with Thread Primitives)**
- For highly experienced users who have an in-depth understanding of low-level hardware characteristics (e.g., threading models, memory coalescing).
- Offers direct access to **thread primitives** and other low-level constructs, allowing for fine-grained control of performance-critical kernels.
- This level grants maximum flexibility for specialized optimizations tailored to specific GPU or multi-core architectures.
## Compilation Flow
1. **Tile Program**
A high-level specification of the computation. Depending on the user’s expertise, they may write a purely hardware-unaware tile program or incorporate constructs from the Tile Library or thread primitives.
2. **Tile Program with Tile Library**
When developers choose from the Tile Library, the original Tile Program is expanded with specialized library calls. These calls encapsulate efficient implementation patterns for different operations.
3. **Tile Program with Thread Primitives**
Expert-level developers can explicitly use low-level threading constructs to hand-optimize data layout, synchronization, and memory usage.
4. **IRModule**
After the program is composed with libraries or thread primitives, it is lowered to an intermediate representation (IR) that captures the necessary hardware details.
5. **Source Code Generation (C/CUDA/HIP/LLVM/…)**
From the IR, the system generates target-specific source code. This source code is tuned for the desired backends or GPU architectures (e.g., NVIDIA, AMD).
6. **Hardware-Specific Executable/Runtime**
Finally, the generated source is compiled into hardware-specific executables, ready to run on the corresponding devices. The pipeline supports multiple GPU backends and can be extended to additional architectures.
## Tile-based Programming Model
[Figure 2](#fig-overview-gemm) provides a concise matrix multiplication (GEMM) example in ``TileLang``,
illustrating how developers can employ high-level constructs such as tiles, memory placement, pipelining,
and operator calls to manage data movement and computation with fine-grained control.
In particular, this snippet ([Figure 2](#fig-overview-gemm) (a)) demonstrates how multi-level tiling
leverages different memory hierarchies (global, shared, and registers) to optimize bandwidth utilization
and reduce latency.
Overall, [Figure 2](#fig-overview-gemm) (b) showcases how the Python-like syntax of ``TileLang``
allows developers to reason about performance-critical optimizations within a user-friendly programming model.
```{figure} ../_static/img/MatmulExample.png
:align: center
:width: 100%
:alt: GEMM with Multi-Level Tiling on GPUs
:name: fig-overview-gemm
Figure 2: Optimizing GEMM with Multi-Level Tiling on GPUs via ``TileLang``.
```
### Tile declarations
At the heart of our approach is the notion of *tiles* as first-class objects in the programming model. A tile represents a shaped portion of data, which can be owned and manipulated by a warp, thread block, or equivalent parallel unit. In the `Matmul` example, the `A` and `B` buffers are read in tiled chunks (determined by `block_M`, `block_N`, `block_K`) inside the kernel loop. With `T.Kernel`, `TileLang` defines the execution context, which includes the thread block index (`bx` and `by`) and the number of threads. These contexts can help compute the index for each thread block and make it easier for `TileLang` to automatically infer and optimize memory access and computation. Additionally, these contexts allow users to manually control the behavior of each independent thread within a thread block.
### Explicit Hardware Memory Allocation
A hallmark of `TileLang` is the ability to explicitly place these tile buffers in the hardware memory hierarchy. Rather than leaving it to a compiler's opaque optimization passes, `TileLang` exposes user-facing intrinsics that map directly to physical memory spaces or accelerator-specific constructs. In particular:
- `T.alloc_shared`: Allocates memory in a fast, on-chip storage space, which corresponds to shared memory on NVIDIA GPUs. Shared memory is ideal for caching intermediate data during computations, as it is significantly faster than global memory and allows for efficient data sharing between threads in the same thread block. For example, in matrix multiplication, tiles of matrices can be loaded into shared memory to reduce global memory bandwidth demands and improve performance.
- `T.alloc_fragment`: Allocates accumulators in fragment memory, which corresponds to register files on NVIDIA GPUs. By keeping inputs and partial sums in registers or hardware-level caches, latency is further minimized. Note that in this tile program, the fragment buffer is declared with the same block-level shape as the shared-memory buffers, which might seem counterintuitive: register files are even faster than shared memory, but their capacity per thread is far more limited. This is because the allocation here refers to the register files for the entire thread block. `TileLang` uses a Layout Inference Pass during compilation to derive a Layout object `T.Fragment`, which determines how to allocate the corresponding register files for each thread. This process is discussed in detail in subsequent sections.
Data transfer between global memory and hardware-specific memory can be managed using `T.copy`. Furthermore, hardware-specific buffers can be initialized using `T.clear` or `T.fill`. For data assignments, operations can also be performed in parallel using `T.Parallel`, as demonstrated in the Layout Inference Pass discussion in the following sections. A compact sketch combining these constructs follows the figure below.
```{figure} ../_static/img/LayoutInference.png
:align: center
:width: 100%
:alt: Layout Inference Pass
```
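To tie these pieces together, below is a minimal GEMM sketch in the spirit of Figure 2. It assumes only the constructs discussed above plus `T.prim_func`, `T.Pipelined`, and `T.gemm`; exact names and signatures (for example `T.Tensor` vs. `T.Buffer` annotations) can differ between TileLang releases, so treat this as an illustrative sketch rather than canonical API usage.
```python
import tilelang.language as T


def matmul(M, N, K, block_M, block_N, block_K,
           dtype="float16", accum_dtype="float"):

    @T.prim_func
    def gemm_kernel(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        # Launch context: a grid of thread blocks with indices (bx, by),
        # as described under "Tile declarations".
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            # Tiles staged in fast on-chip shared memory.
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            # Block-level accumulator in fragment (register) memory; the
            # layout inference pass decides how it is split across threads.
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            # Walk the K dimension in block_K chunks with software pipelining.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            # Write the accumulated tile back to global memory.
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm_kernel
```
The point of the sketch is that data movement (`T.copy`), accumulator initialization (`T.clear`), and the pipelined reduction loop are all expressed at the tile level, leaving per-thread layout and scheduling decisions to the compiler passes described above.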