# Understanding Targets
TileLang is built on top of TVM, which relies on **targets** to describe the device you want to compile for.
The target determines which code generator is used (CUDA, HIP, Metal, LLVM, …) and allows you to pass
device-specific options such as GPU architecture flags. This page summarises how to pick and customise a target
when compiling TileLang programs.
## Common target strings
TileLang ships with a small set of common targets; each accepts the full range of TVM options so you can fine-tune
the generated code. The most frequent choices are listed below:
| Base name | Description |
| --------- | ----------- |
| `auto` | Detects CUDA → HIP → Metal in that order. Useful when running the same script across machines. |
| `cuda` | NVIDIA GPUs. Supports options such as `-arch=sm_80`, `-max_num_threads=1024`, etc. |
| `hip` | AMD GPUs via ROCm. Options like `-mcpu=gfx90a` can be appended. |
| `metal` | Apple Silicon GPUs (arm64 Macs). |
| `llvm` | CPU execution; accepts the standard TVM LLVM switches. |
| `webgpu` | Browser / WebGPU runtimes. |
| `c` | Emit plain C source for inspection or custom toolchains. |
To add options, append them after the base name, separated by spaces. For example:
```python
target = "cuda -arch=sm_90"
kernel = tilelang.compile(func, target=target, execution_backend="cython")
# or
@tilelang.jit(target=target)
def compiled_kernel(*args):
return func(*args)
```
The same convention works for HIP or LLVM (e.g. `hip -mcpu=gfx940`, `llvm -mtriple=x86_64-linux-gnu`).
### Advanced: Specify Exact Hardware
When you already know the precise GPU model, you can encode it in the target string—either via `-arch=sm_XX` or by
using one of TVM’s pre-defined target tags such as `nvidia/nvidia-h100`. Supplying this detail is optional for
TileLang in general use, but it becomes valuable when the TVM cost model is enabled (e.g. during autotuning). The
cost model uses the extra attributes to make better scheduling predictions. If you skip this step (or do not use the
cost model), generic targets like `cuda` or `auto` are perfectly fine.
All CUDA compute capabilities recognised by TVM’s target registry are listed below. Pick the one that matches your
GPU and append it to the target string or use the corresponding target tag—for example `nvidia/nvidia-a100`.
| Architecture | GPUs (examples) |
| ------------ | ---------------- |
| `sm_20` | `nvidia/tesla-c2050`, `nvidia/tesla-c2070` |
| `sm_21` | `nvidia/nvs-5400m`, `nvidia/geforce-gt-520` |
| `sm_30` | `nvidia/quadro-k5000`, `nvidia/geforce-gtx-780m` |
| `sm_35` | `nvidia/tesla-k40`, `nvidia/quadro-k6000` |
| `sm_37` | `nvidia/tesla-k80` |
| `sm_50` | `nvidia/quadro-k2200`, `nvidia/geforce-gtx-950m` |
| `sm_52` | `nvidia/tesla-m40`, `nvidia/geforce-gtx-980` |
| `sm_53` | `nvidia/jetson-tx1`, `nvidia/jetson-nano` |
| `sm_60` | `nvidia/tesla-p100`, `nvidia/quadro-gp100` |
| `sm_61` | `nvidia/tesla-p4`, `nvidia/quadro-p6000`, `nvidia/geforce-gtx-1080` |
| `sm_62` | `nvidia/jetson-tx2` |
| `sm_70` | `nvidia/nvidia-v100`, `nvidia/quadro-gv100` |
| `sm_72` | `nvidia/jetson-agx-xavier` |
| `sm_75` | `nvidia/nvidia-t4`, `nvidia/quadro-rtx-8000`, `nvidia/geforce-rtx-2080` |
| `sm_80` | `nvidia/nvidia-a100`, `nvidia/nvidia-a30` |
| `sm_86` | `nvidia/nvidia-a40`, `nvidia/nvidia-a10`, `nvidia/geforce-rtx-3090` |
| `sm_87` | `nvidia/jetson-agx-orin-32gb`, `nvidia/jetson-agx-orin-64gb` |
| `sm_89` | `nvidia/geforce-rtx-4090` |
| `sm_90a` | `nvidia/nvidia-h100` (DPX profile) |
| `sm_100a` | `nvidia/nvidia-b100` |
Refer to NVIDIA’s [CUDA GPUs](https://developer.nvidia.com/cuda-gpus) page or the TVM source
(`3rdparty/tvm/src/target/tag.cc`) for the latest mapping between devices and compute capabilities.
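For instance, here is a minimal sketch of both ways to pin an A100, assuming `func` is a Tile Language program defined elsewhere:
```python
import tilelang

# Pre-defined TVM target tag: carries the compute capability plus extra
# attributes (max threads, registers, shared memory) useful to the cost model.
kernel_tagged = tilelang.compile(func, target="nvidia/nvidia-a100")

# Equivalent minimal form that only pins the compute capability.
kernel_arch = tilelang.compile(func, target="cuda -arch=sm_80")
```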
## Creating targets programmatically
If you prefer working with TVM’s `Target` objects, TileLang exposes the helper
`tilelang.utils.target.determine_target` (returns a canonical target string by default, or the `Target`
object when `return_object=True`):
```python
from tilelang.utils.target import determine_target
tvm_target = determine_target("cuda -arch=sm_80", return_object=True)
kernel = tilelang.compile(func, target=tvm_target)
```
You can also build targets directly through TVM:
```python
from tvm.target import Target
target = Target("cuda", host="llvm")
target = target.with_host(Target("llvm -mcpu=skylake"))
```
TileLang accepts either `str` or `Target` inputs; internally they are normalised and cached using the canonical
string representation. **In user code we strongly recommend passing target strings rather than
`tvm.target.Target` instances—strings keep cache keys compact and deterministic across runs, whereas constructing
fresh `Target` objects may lead to slightly higher hashing overhead or inconsistent identity semantics.**
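A short sketch of the two accepted forms, assuming `func` is already defined; both compile the same kernel, but the string form keeps cache keys stable:
```python
import tilelang
from tvm.target import Target

kernel_from_str = tilelang.compile(func, target="cuda -arch=sm_80")          # recommended
kernel_from_obj = tilelang.compile(func, target=Target("cuda -arch=sm_80"))  # also accepted
```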
## Discovering supported targets in code
Looking for a quick reminder of the built-in base names and their descriptions? Use:
```python
from tilelang.utils.target import describe_supported_targets
for name, doc in describe_supported_targets().items():
print(f"{name:>6}: {doc}")
```
This helper mirrors the table above and is safe to call at runtime (for example when validating CLI arguments).
## Troubleshooting tips
- If you see `Target cuda -arch=sm_80 is not supported`, double-check the spellings and that the option is valid for
TVM. Any invalid switch will surface as a target-construction error.
- Runtime errors such as “no kernel image is available” usually mean the `-arch` flag does not match the GPU you are
running on. Try dropping the flag or switching to the correct compute capability.
- When targeting multiple environments, use `auto` for convenience and override with an explicit string only when
you need architecture-specific tuning.
# 👋 Welcome to Tile Language
[GitHub](https://github.com/tile-ai/tilelang)
Tile Language (tile-lang) is a concise domain-specific language designed to streamline
the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention).
By employing a Pythonic syntax with an underlying compiler infrastructure on top of TVM,
tile-lang allows developers to focus on productivity without sacrificing the
low-level optimizations necessary for state-of-the-art performance.
:::{toctree}
:maxdepth: 2
:caption: GET STARTED
get_started/Installation
get_started/overview
get_started/targets
:::
:::{toctree}
:maxdepth: 1
:caption: TUTORIALS
tutorials/debug_tools_for_tilelang
tutorials/auto_tuning
:::
:::{toctree}
:maxdepth: 1
:caption: DEEP LEARNING OPERATORS
deeplearning_operators/elementwise
deeplearning_operators/gemv
deeplearning_operators/matmul
deeplearning_operators/deepseek_mla
:::
:::{toctree}
:maxdepth: 1
:caption: COMPILER INTERNALS
compiler_internals/letstmt_inline
compiler_internals/inject_fence_proxy
:::
:::{toctree}
:maxdepth: 1
:caption: API Reference
autoapi/tilelang/index
:::
:::{toctree}
:maxdepth: 1
:caption: Privacy
privacy
:::
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd
# Privacy
All data stays on the user's device and is not collected by the app.
fastapi
pydantic
sphinx
sphinx-reredirects
sphinx-tabs
sphinx-toolbox
sphinxcontrib-napoleon
sphinxcontrib_httpdomain
furo
uvicorn
myst-parser
sphinx-autoapi == 3.6.0
astroid < 4
cancelled
hsa
ist
LOD
nd
NotIn
offen
te
Auto-Tuning Techniques for Performance Optimization
===================================================
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/yyttt6">yyttt6</a>
</div>
## Overview
Auto-tuning a Tile Language program involves three main steps:
1. Implement the target program using Tile Language with reserved optimization parameters
2. Provide candidate configurations through manual search or [auto-generation using Carver](#using-carver-to-auto-generate-candidate-configurations)
3. Compile and benchmark the candidate configurations in parallel to identify the best-performing one
## Matrix Multiplication Example
The following example demonstrates auto-tuning matrix multiplication. The code has been simplified for readability; see `examples/gemm/example_gemm.py` for the complete implementation.
### Step 1: Implement with Reserved Parameters
Users can implement matrix multiplication in Tile Language while reserving parameters for optimization:
```python
# Reserved parameters for optimization
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
enable_rasteration=None,
):
dtype = "float16"
accum_dtype = "float"
# Matrix multiplication implementation
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((M, N), dtype),
):
# ...existing code...
return main
```
### Step 2: Generate Candidate Configurations
Manually define configurations or use combinatorial generation:
```python
configs = [
{
"block_M": 128,
"block_N": 128,
"block_K": 128,
"num_stages": 3,
"thread_num": 128,
"enable_rasteration": True
},
{
"block_M": 32,
"block_N": 32,
"block_K": 32,
"num_stages": 0,
"thread_num": 32,
"enable_rasteration": False
},
# ...additional configurations...
]
```
Configurations can also be generated by a combinatorial traversal of the parameter space:
```python
import itertools
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5]
} for c in _configs
]
```
### Step 3: Compile and Benchmark
Configure JIT compilation and benchmarking settings:
```python
autotuner = AutoTuner.from_kernel(
kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
out_idx=[-1],
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=False,
target="auto",
)
result = autotuner.run(warmup=3, rep=20)
out_c = result.kernel(a, b)
```
The result object contains the optimized kernel implementation, which can be used directly, as shown below.
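A minimal usage sketch of the returned object, reusing `result.kernel` and `ref_program` from the snippet above; the tensor shapes and dtypes are assumptions matching the float16 GEMM example:
```python
import torch

# Inputs for the float16 GEMM (B stored as (N, K), matching the kernel signature).
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(N, K, device="cuda", dtype=torch.float16)

out_c = result.kernel(a, b)   # the tuned kernel is a plain callable
ref_c = ref_program(a, b)     # reference implementation used during tuning
torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
```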
## Using Carver to Auto-Generate Candidate Configurations
Carver is a lightweight framework for generating and ranking tile configurations (also known as tiling strategies, blocking schemes, or scheduling hints) for common GPU, CPU, and accelerator backends. It helps you explore efficient mappings of loops for operations such as matrix multiplication, elementwise transforms, and other reduction-oriented kernels.
For common operators, Carver provides pre-built templates (e.g., `MatmulTemplate`):
```python
# Configure Matmul template
arch = CUDA("cuda")
carve_template = MatmulTemplate(
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float",
).with_arch(arch)
# Generate top-k optimization hints (topk=10 recommended)
roller_hints = carve_template.recommend_hints(topk=10)
# Configure candidate parameters
for hint in roller_hints:
# ...existing code...
config["block_M"] = block_m
config["block_N"] = block_n
config["block_K"] = hint.rstep[0]
config["num_stages"] = hint.pipeline_stage
config["thread_num"] = block_rows * block_cols * 32
config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
```
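For completeness, here is a hedged sketch of assembling the full config list from the hints, continuing the snippet above. The `hint.block` and `hint.warp` attributes (and the derived warp layout) are assumptions about the Carver hint object; consult `examples/gemm/example_gemm.py` for the authoritative version.
```python
configs = []
for hint in roller_hints:
    config = {}
    # Assumed layout: hint.block = (block_m, block_n), hint.warp = (warp_m, warp_n)
    block_m, block_n = hint.block
    warp_m, warp_n = hint.warp
    block_rows, block_cols = block_m // warp_m, block_n // warp_n
    config["block_M"] = block_m
    config["block_N"] = block_n
    config["block_K"] = hint.rstep[0]
    config["num_stages"] = hint.pipeline_stage
    config["thread_num"] = block_rows * block_cols * 32
    config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
    configs.append(config)

# `configs` can now be passed to AutoTuner.from_kernel(...) exactly as in Step 3.
```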
# Debugging Tile Language Programs
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/LeiWang1999">Lei Wang</a>
</div>
## Overview
A Tile Language program (hereafter referred to as a *program*) is transformed into a hardware-executable file through several stages:
1. The user writes a Tile Language program.
2. The program undergoes multiple *Passes* for transformation and optimization (the *lower* stage, see `tilelang/engine/lower.py`), finally producing an intermediate representation (e.g., LLVM or C for CPU, CUDA for NVIDIA GPUs, etc.).
3. The generated code is compiled by the respective compiler (e.g., nvcc) into a hardware-executable file.
```{figure} ../_static/img/overview.png
:width: 300
:alt: Overview of the compilation process
:align: center
```
During this process, users may encounter roughly three categories of issues:
* **Generation issues**: The Tile Language program fails to generate a valid hardware-executable file (i.e., errors during the lowering process).
* **Correctness issues**: The resulting executable runs, but produces incorrect results.
* **Performance issues**: The executable runs with performance significantly below the expected theoretical hardware limits.
This tutorial focuses on the first two issues—how to debug generation and correctness problems. Performance tuning often requires using vendor-provided profiling tools (e.g., **Nsight Compute**, **rocProf**, etc.) for further hardware-level analysis, which we will address in future materials.
Below, we take matrix multiplication (GEMM) as an example to demonstrate how to write and debug a Tile Language program.
## Matrix Multiplication Example
In **Tile Language**, you can use the **Tile Library** to implement matrix multiplication. Here's a complete example:
```python
import tilelang
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
# ...existing code...
# 1. Define the kernel (matmul) with the desired dimensions
func = matmul(1024, 1024, 1024, 128, 128, 32)
# 2. Compile the kernel into a torch function
# ...existing code...
```
## Debugging Generation Issues
TileLang essentially performs *progressive lowering*. For example, a `T.copy` may first be expanded into `T.Parallel` (see the pass `LowerTileOP`), which is then expanded again, eventually resulting in lower-level statements that can be translated to CUDA C code.
```{figure} ../_static/img/ir_transform_diagram.png
:width: 400
:alt: IR transformation diagram
:align: center
```
When the code fails to generate (for instance, a compilation error occurs), you do **not** necessarily need to jump directly into C++ passes to debug. Instead, you can first inspect the intermediate representations (IR) in Python by printing them.
For example, consider a case where a simple `T.copy` in 1D causes the lowering process to fail. The snippet below illustrates a simplified version of the problem (based on community Issue #35):
```python
@T.prim_func
def main(Q: T.Tensor(shape_q, dtype)):
# ...existing code...
```
The TileLang lowering process might yield an error such as:
```text
File "/root/TileLang/src/target/codegen_cuda.cc", line 1257
ValueError: Check failed: lanes <= 4 (8 vs. 4) : Ramp of more than 4 lanes is not allowed.
```
This indicates that somewhere during code generation, an unsupported vectorization pattern was introduced (a ramp of 8 lanes). Before diving into the underlying C++ code, it is helpful to print the IR right before code generation. For instance:
```python
device_mod = tir.transform.Filter(is_device_call)(mod)
# ...existing code...
```
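Complementing the filtered dump above, two quick checks that do not require touching the pass pipeline are printing the Tile Language function itself and, once lowering succeeds, printing the generated source. A minimal sketch, assuming the `matmul` example from earlier in this tutorial:
```python
import tilelang

func = matmul(1024, 1024, 1024, 128, 128, 32)
print(func)  # TIR of the Tile Language program before lowering

# Once lowering succeeds, the generated CUDA source can be inspected directly:
kernel = tilelang.compile(func, target="cuda")
print(kernel.get_kernel_source())
```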
## Debugging Correctness Issues
Sometimes, the kernel compiles and runs but produces incorrect results. In such cases, there are two main strategies to help debug:
1. **Use post-processing callbacks to inspect or modify the generated CUDA code.**
2. **Use the built-in `T.print` debugging primitive to inspect values at runtime.**
### Post-Processing Callbacks for Generated Source
After code generation (in the codegen pass), TileLang calls a callback function (if registered) to allow post-processing of the generated source code. In `src/target/rt_mod_cuda.cc`:
```cpp
std::string code = cg.Finish();
if (const auto *f = Registry::Get("tilelang_callback_cuda_postproc")) {
code = (*f)(code, target).operator std::string();
}
```
Hence, by registering a Python function named `tilelang_callback_cuda_postproc`, you can intercept the final CUDA code string. For example:
```python
import tilelang
import tilelang.language as T
from tilelang import tvm
from tilelang.engine.callback import register_cuda_postproc_callback
@register_cuda_postproc_callback
def tilelang_callback_cuda_postproc(code, _):
print(code) # print the final CUDA code
code = "// modified by tilelang_callback_cuda_postproc\n" + code
return code
kernel = tilelang.compile(matmul, target="cuda")
kernel_source = kernel.get_kernel_source()
print(kernel_source)
'''
// modified by tilelang_callback_cuda_postproc
#include "cuda_runtime.h"
...
'''
```
### Runtime Debug Prints with `T.print`
TileLang provides a built-in debugging primitive called `T.print` for printing within kernels. Be mindful of concurrency and thread synchronization when using it in GPU code. Below are some examples showing how to print buffers, variables, and other data inside TileLang programs.
1. **Printing an Entire Buffer**
```python
def debug_print_buffer(M=16, N=16):
# ...existing code...
```
2. **Conditional Printing**
```python
def debug_print_buffer_conditional(M=16, N=16):
# ...existing code...
```
3. **Printing Thread Indices or Scalar Values**
```python
def debug_print_value_conditional(M=16, N=16):
# ...existing code...
```
4. **Printing Fragment (Register File) Contents**
```python
def debug_print_register_files(M=16, N=16):
# ...existing code...
```
5. **Adding a Message Prefix**
```python
def debug_print_msg(M=16, N=16):
# ...existing code...
```
The output messages will include something like:
```text
msg='hello world' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): 0
```
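As a concrete, self-contained sketch of the first case above (printing an entire buffer), using only primitives shown in this tutorial; the kernel body is illustrative rather than taken from the repository:
```python
import torch
import tilelang
import tilelang.language as T

def debug_print_buffer(M=16, N=16):
    dtype = "float16"

    @T.prim_func
    def program(A: T.Tensor((M, N), dtype)):
        with T.Kernel(1, 1, threads=128) as (bx, by):
            shared = T.alloc_shared((M, N), dtype)
            T.copy(A, shared)
            T.print(shared)  # dump the whole shared buffer at runtime

    return program

kernel = tilelang.compile(debug_print_buffer(), target="cuda")
kernel(torch.randn(16, 16, dtype=torch.float16, device="cuda"))  # triggers the print
```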
## Conclusion
By carefully examining intermediate representations (IR) before final code generation—and by leveraging runtime printing through `T.print`—one can quickly diagnose where index calculations, copy logic, or other kernel operations deviate from the intended behavior. This two-pronged approach (inspecting IR transformations and using runtime prints) is often sufficient for resolving generation and correctness issues in TileLang programs.
For advanced performance tuning (e.g., analyzing memory bandwidth or occupancy), more specialized profiling tools such as **Nsight Compute**, **rocProf**, or vendor-specific profilers may be required. Those aspects will be covered in future documents.
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.primitives.gemm.base import GemmWarpPolicy
import itertools
import argparse
from functools import partial
import numpy as np
import time
def ref_program(Q, K, V, is_causal, groups=1):
assert Q.size(
2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
assert Q.size(
2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
dim = Q.size(-1)
K_ref = K.repeat_interleave(groups, dim=2)
V_ref = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K_ref)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V_ref)
lse = torch.logsumexp(scores, dim=-1).float()
return output, lse
def get_fwd_configs():
block_M = [32, 64, 128, 256]
block_N = [32, 64, 128, 256]
threads = [128, 256, 512]
num_split_q = [64, 128, 256]
num_stages = [0, 1]
enable_rasterization = [True]
k_pack = [2]
panel_size = [7, 8, 9, 10]
qk_coalesced_width = [8]
v_coalesced_width = [4]
valid_configs = []
for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q,
threads, num_stages,
enable_rasterization, k_pack,
panel_size, qk_coalesced_width,
v_coalesced_width):
valid_configs.append({
"block_M": m,
"block_N": n,
"num_split_q": s,
"threads": t,
"num_stages": stages,
"enable_rasterization": r,
"k_pack": k,
"panel_size": p,
"qk_coalesced_width": qkw,
"v_coalesced_width": vw,
})
return valid_configs
@tilelang.autotune(configs=get_fwd_configs(), cache_input_tensors=True)
@tilelang.jit(out_idx=[3, 4])
def fast_flashattn(
batch,
heads,
seq_len,
dim,
is_causal,
groups,
block_M: int,
block_N: int,
num_split_q: int,
threads: int,
num_stages: int,
enable_rasterization: bool,
k_pack: int,
panel_size: int,
qk_coalesced_width: int,
v_coalesced_width: int,
):
scale = (1.0 / dim)**0.5
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"
vec_size = qk_coalesced_width
v_vec_size = v_coalesced_width
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
LSE: T.Tensor([batch, heads, seq_len], accum_dtype),
):
with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
T.use_swizzle(panel_size, enable=enable_rasterization)
bz = byz_combined // heads
by = byz_combined % heads
num_q_blocks = T.ceildiv(seq_len, block_M)
bx_loop_var = T.alloc_var("int32")
bx_loop_var = b_split
with T.While(bx_loop_var < num_q_blocks):
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
m_i = T.alloc_fragment([block_M], accum_dtype)
l_i = T.alloc_fragment([block_M], accum_dtype)
T.fill(acc_o, 0)
T.fill(m_i, -T.infinity(accum_dtype))
T.fill(l_i, 0)
current_bx = bx_loop_var
q_block_offset = current_bx * block_M
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
m_prev = T.alloc_fragment([block_M], accum_dtype)
scale_factor = T.alloc_fragment([block_M], accum_dtype)
T.copy(
Q[bz, q_block_offset:q_block_offset + block_M, by, :],
Q_shared,
coalesced_width=vec_size)
loop_end_k = (
T.ceildiv(q_block_offset +
block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
row_sum = T.alloc_fragment([block_M], accum_dtype)
for k in T.Pipelined(loop_end_k, num_stages=num_stages):
kv_idx = k * block_N
T.copy(
K[bz, kv_idx:kv_idx + block_N, by // groups, :],
K_shared,
coalesced_width=vec_size)
T.copy(
V[bz, kv_idx:kv_idx + block_N, by // groups, :],
V_shared,
coalesced_width=v_vec_size)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
k_pack=k_pack,
policy=GemmWarpPolicy.FullRow,
)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = acc_s[i, j] * scale
T.copy(m_i, m_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for i in T.Parallel(block_M):
if m_prev[i] == -T.infinity(accum_dtype):
scale_factor[i] = 0.0
else:
scale_factor[i] = T.exp(m_prev[i] - m_i[i])
l_i[i] *= scale_factor[i]
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scale_factor[i]
for i, j in T.Parallel(block_M, block_N):
if acc_s[i, j] == -T.infinity(acc_s.dtype):
acc_s[i, j] = 0.0
else:
acc_s[i, j] = T.exp(acc_s[i, j] - m_i[i])
T.reduce_sum(acc_s, row_sum, dim=1)
for i in T.Parallel(block_M):
l_i[i] += row_sum[i]
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=GemmWarpPolicy.FullRow)
l_inv = T.alloc_fragment([block_M], accum_dtype)
for i in T.Parallel(block_M):
safe_l = T.if_then_else(l_i[i] > 1e-6, l_i[i], 1.0)
l_inv[i] = 1.0 / safe_l
for i, j in T.Parallel(block_M, dim):
Output[bz, q_block_offset + i, by, j] = acc_o[i, j] * l_inv[i]
for i in T.Parallel(block_M):
if q_block_offset + i < seq_len:
lse_val = T.if_then_else(l_i[i] > 0,
T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype))
LSE[bz, by, q_block_offset + i] = lse_val
bx_loop_var = current_bx + num_split_q
return main
def get_bwd_configs():
block_M = [16, 32, 64, 128, 256]
block_N = [16, 32, 64, 128, 256]
threads = [64, 128, 256, 512, 1024]
num_stages = [0, 1, 2]
enable_rasterization = [True]
panel_size = [7, 8, 9, 10]
configs = []
for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads,
enable_rasterization, panel_size):
configs.append({
"block_M": m,
"block_N": n,
"num_stages": stages,
"threads": t,
"enable_rasterization": r,
"panel_size": p,
})
return configs
@tilelang.jit(out_idx=[2])
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
shape = [batch, seq_len, heads, dim]
blk = 32
@T.prim_func
def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype),
Delta: T.Tensor([batch, heads, seq_len], accum_dtype)):
with T.Kernel(batch, heads, T.ceildiv(seq_len, blk)) as (bz, bx, by):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
return flash_bwd_prep
@tilelang.autotune(configs=get_bwd_configs(), cache_input_tensors=True)
@tilelang.jit
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, block_N: int,
num_stages: int, threads: int, enable_rasterization: bool, panel_size: int):
sm_scale = (1.0 / dim)**0.5
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd_kernel(Q: T.Tensor(q_shape,
dtype), K: T.Tensor(kv_shape,
dtype), V: T.Tensor(kv_shape, dtype),
dO: T.Tensor(q_shape, dtype), lse: T.Tensor([batch, heads, seq_len],
accum_dtype),
Delta: T.Tensor([batch, heads, seq_len],
accum_dtype), dQ: T.Tensor(q_shape, accum_dtype),
dK: T.Tensor(kv_shape, accum_dtype), dV: T.Tensor(kv_shape, accum_dtype)):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
T.use_swizzle(panel_size, enable=enable_rasterization)
K_shared = T.alloc_shared([block_M, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
q_shared = T.alloc_shared([block_N, dim], dtype)
do_shared = T.alloc_shared([block_N, dim], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta_shared = T.alloc_shared([block_N], accum_dtype)
ds_shared = T.alloc_shared([block_M, block_N], dtype)
p_cast = T.alloc_fragment([block_M, block_N], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
P_acc = T.alloc_fragment([block_M, block_N], accum_dtype)
dP = T.alloc_fragment([block_M, block_N], accum_dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q_shared)
T.clear(qkT)
T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
P_acc[i, j] = T.exp(qkT[i, j] * sm_scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j,
P_acc[i, j], 0.0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do_shared)
T.clear(dP)
T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(P_acc, p_cast)
T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta_shared)
for i, j in T.Parallel(block_M, block_N):
p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale
T.gemm(p_cast, q_shared, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(p_cast, ds_shared)
T.clear(dq)
T.gemm(ds_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim):
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
for i, j in T.Parallel(block_M, dim):
T.atomic_add(dV[bz, by * block_M + i, bx // groups, j], dv[i, j])
T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk[i, j])
return flash_bwd_kernel
@tilelang.jit(out_idx=[1])
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
shape = [batch, seq_len, heads, dim]
blk = 64
@T.prim_func
def flash_bwd_post(dQ_in: T.Tensor(shape, accum_dtype), dQ_out: T.Tensor(shape, dtype)):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.copy(
dQ_in[bz, bx * blk:(bx + 1) * blk, by, :],
dQ_out[bz, bx * blk:(bx + 1) * blk, by, :],
)
return flash_bwd_post
def debug_tensor_comparison(tensor1, tensor2, name, rtol=1e-3, atol=1e-3):
print(f"\n=== {name} Comparison ===")
print(f"Shape: {tensor1.shape} vs {tensor2.shape}")
print(f"Data type: {tensor1.dtype} vs {tensor2.dtype}")
print(f"Device: {tensor1.device} vs {tensor2.device}")
diff = torch.abs(tensor1 - tensor2)
max_diff = diff.max().item()
mean_diff = diff.mean().item()
std_diff = diff.std().item()
print(f"Max difference: {max_diff:.6f}")
print(f"Mean difference: {mean_diff:.6f}")
print(f"Difference std: {std_diff:.6f}")
if max_diff > atol:
max_idx = torch.argmax(diff)
max_idx = np.unravel_index(max_idx.cpu().numpy(), tensor1.shape)
print(f"Max difference position: {max_idx}")
print(f"Value1: {tensor1[max_idx].item():.6f}, Value2: {tensor2[max_idx].item():.6f}")
nan_count1 = torch.isnan(tensor1).sum().item()
nan_count2 = torch.isnan(tensor2).sum().item()
inf_count1 = torch.isinf(tensor1).sum().item()
inf_count2 = torch.isinf(tensor2).sum().item()
print(f"NaN count: {nan_count1} vs {nan_count2}")
print(f"Inf count: {inf_count1} vs {inf_count2}")
relative_diff = diff / (torch.abs(tensor2) + 1e-8)
max_relative_diff = relative_diff.max().item()
mean_relative_diff = relative_diff.mean().item()
print(f"Max relative difference: {max_relative_diff:.6f}")
print(f"Mean relative difference: {mean_relative_diff:.6f}")
close = torch.allclose(tensor1, tensor2, rtol=rtol, atol=atol)
print(f"Within tolerance (rtol={rtol}, atol={atol}): {close}")
return close, max_diff, mean_diff
def benchmark_function(func, *args, warmup=10, repeat=100):
for _ in range(warmup):
func(*args)
if torch.cuda.is_available():
torch.cuda.synchronize()
times = []
for _ in range(repeat):
start = time.time()
func(*args)
if torch.cuda.is_available():
torch.cuda.synchronize()
end = time.time()
times.append((end - start) * 1000)
return np.median(times)
def main(batch: int = 1,
heads: int = 8,
seq_len: int = 4096,
dim: int = 128,
is_causal: bool = False,
groups: int = 1):
device = "cuda"
dtype = torch.float16
torch.manual_seed(42)
torch.cuda.manual_seed(42)
print(
f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}"
)
flops_per_gemm = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 5 * flops_per_gemm
print(f"Total FLOPs: {total_flops / 1e12:.2f} TFlops")
q = torch.randn(batch, seq_len, heads, dim, device=device, dtype=dtype)
k = torch.randn(batch, seq_len, heads // groups, dim, device=device, dtype=dtype)
v = torch.randn(batch, seq_len, heads // groups, dim, device=device, dtype=dtype)
dO = torch.randn_like(q)
print("Starting autotuning for Fast FlashAttention-V2 Forward Pass...")
fwd_kernel = fast_flashattn(batch, heads, seq_len, dim, is_causal, groups)
if fwd_kernel is None or fwd_kernel.config is None:
print("Forward pass auto-tuning failed.")
return
print(f"Autotuning finished. Best Forward Configuration: {fwd_kernel.config}")
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = fwd_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
print("Verifying correctness...")
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("Forward pass is correct.")
o_tl, lse_tl = fwd_kernel(q, k, v)
bwd_prep = flashattn_bwd_preprocess(batch, heads, seq_len, dim)
delta_tl = bwd_prep(o_tl, dO)
print("\nStarting FlashAttention-V2 backward pass autotuning...")
bwd_kernel = flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups)
if bwd_kernel is None or bwd_kernel.config is None:
print("Backward pass autotuning failed.")
return
print(f"Autotuning completed. Best backward pass configuration: {bwd_kernel.config}")
dQ_accum = torch.zeros_like(q, dtype=torch.float32)
dK_tl = torch.zeros_like(k, dtype=torch.float32)
dV_tl = torch.zeros_like(v, dtype=torch.float32)
bwd_kernel(q, k, v, dO, lse_tl, delta_tl, dQ_accum, dK_tl, dV_tl)
post_kernel = flashattn_bwd_postprocess(batch, heads, seq_len, dim)
dQ_tl = post_kernel(dQ_accum)
q_ref = q.clone().detach().requires_grad_()
k_ref = k.clone().detach().requires_grad_()
v_ref = v.clone().detach().requires_grad_()
o_ref, _ = ref_program(q_ref, k_ref, v_ref, is_causal, groups)
o_ref.backward(dO)
print("Verifying backward pass correctness...")
dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(
dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05)
if dq_close:
print("dQ is correct.")
else:
print("dQ mismatch detected.")
dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(
dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05)
if dk_close:
print("dK is correct.")
else:
print("dK mismatch detected.")
dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(
dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05)
if dv_close:
print("dV is correct.")
else:
print("dV mismatch detected.")
print("\n=== Performance Benchmarking ===")
def run_reference_fwd_bwd():
q_ref_bench = q.clone().detach().requires_grad_()
k_ref_bench = k.clone().detach().requires_grad_()
v_ref_bench = v.clone().detach().requires_grad_()
o_ref_bench, _ = ref_program(q_ref_bench, k_ref_bench, v_ref_bench, is_causal, groups)
o_ref_bench.backward(dO)
if torch.cuda.is_available():
torch.cuda.synchronize()
ref_latency = benchmark_function(run_reference_fwd_bwd, warmup=10, repeat=100)
print(
f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops"
)
def run_complete_fwd_bwd():
o_tl_bench, lse_tl_bench = fwd_kernel(q, k, v)
delta_tl_bench = bwd_prep(o_tl_bench, dO)
dQ_bench = torch.zeros_like(q, dtype=torch.float32)
dK_bench = torch.zeros_like(k, dtype=torch.float32)
dV_bench = torch.zeros_like(v, dtype=torch.float32)
bwd_kernel(q, k, v, dO, lse_tl_bench, delta_tl_bench, dQ_bench, dK_bench, dV_bench)
post_kernel(dQ_bench)
if torch.cuda.is_available():
torch.cuda.synchronize()
tile_latency = benchmark_function(run_complete_fwd_bwd, warmup=10, repeat=100)
print(
f"Complete Flash Attention V2 Forward+Backward (Tile-lang): {tile_latency:.2f} ms | {total_flops / tile_latency * 1e-9:.2f} TFlops"
)
speedup = ref_latency / tile_latency
print(f"Speedup: {speedup:.2f}x")
print("Forward output: Passed")
print(f"dQ: {'Passed' if dq_close else 'Failed'} (Max diff: {dq_max_diff:.6f})")
print(f"dK: {'Passed' if dk_close else 'Failed'} (Max diff: {dk_max_diff:.6f})")
print(f"dV: {'Passed' if dv_close else 'Failed'} (Max diff: {dv_max_diff:.6f})")
if all([dq_close, dk_close, dv_close]):
print("All checks passed!")
else:
print("Some checks failed, may need further debugging.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=8, help='heads')
parser.add_argument('--seq_len', type=int, default=1024, help='sequence length')
parser.add_argument('--dim', type=int, default=64, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--groups', type=int, default=1, help='groups')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.primitives.gemm.base import GemmWarpPolicy
import itertools
import argparse
from functools import partial
def ref_program(Q, K, V, is_causal, groups=1):
assert Q.size(
2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
assert Q.size(
2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
dim = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
def get_configs():
"""Generates configurations for the autotuner, tailored for FA-2 style parallelism."""
block_M = [32, 64, 128, 256]
block_N = [32, 64, 128, 256]
threads = [128, 256, 512]
num_split_q = [64, 128, 256]
num_stages = [0, 1]
enable_rasterization = [True]
k_pack = [2]
panel_size = [7, 8]
qk_coalesced_width = [8]
v_coalesced_width = [4]
valid_configs = []
for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q,
threads, num_stages,
enable_rasterization, k_pack,
panel_size, qk_coalesced_width,
v_coalesced_width):
valid_configs.append({
"block_M": m,
"block_N": n,
"num_split_q": s,
"threads": t,
"num_stages": stages,
"enable_rasterization": r,
"k_pack": k,
"panel_size": p,
"qk_coalesced_width": qkw,
"v_coalesced_width": vw,
})
return valid_configs
@tilelang.autotune(configs=get_configs(), cache_input_tensors=True)
@tilelang.jit(out_idx=[3])
def fast_flashattn(
batch,
heads,
seq_len,
dim,
is_causal,
groups,
block_M: int,
block_N: int,
num_split_q: int,
threads: int,
num_stages: int,
enable_rasterization: bool,
k_pack: int,
panel_size: int,
qk_coalesced_width: int,
v_coalesced_width: int,
):
scale = (1.0 / dim)**0.5
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"
vec_size = qk_coalesced_width
v_vec_size = v_coalesced_width
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
T.use_swizzle(panel_size, enable=enable_rasterization)
bz = byz_combined // heads
by = byz_combined % heads
num_q_blocks = T.ceildiv(seq_len, block_M)
bx = T.alloc_var("int32")
bx = b_split
with T.While(bx < num_q_blocks):
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
m_i = T.alloc_fragment([block_M], accum_dtype)
l_i = T.alloc_fragment([block_M], accum_dtype)
T.fill(acc_o, 0)
T.fill(m_i, -T.infinity(accum_dtype))
T.fill(l_i, 0)
current_bx = bx
q_block_offset = current_bx * block_M
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
# Use register fragment for P instead of shared memory to reduce LDS usage
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
m_prev = T.alloc_fragment([block_M], accum_dtype)
scale_factor = T.alloc_fragment([block_M], accum_dtype)
T.copy(
Q[bz, q_block_offset:q_block_offset + block_M, by, :],
Q_shared,
coalesced_width=vec_size)
loop_end_k = T.ceildiv(q_block_offset + block_M,
block_N) if is_causal else T.ceildiv(seq_len, block_N)
row_sum = T.alloc_fragment([block_M], accum_dtype)
for k in T.Pipelined(loop_end_k, num_stages=num_stages):
kv_idx = k * block_N
T.copy(
K[bz, kv_idx:kv_idx + block_N, by // groups, :],
K_shared,
coalesced_width=vec_size)
T.copy(
V[bz, kv_idx:kv_idx + block_N, by // groups, :],
V_shared,
coalesced_width=v_vec_size)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
k_pack=k_pack,
policy=GemmWarpPolicy.FullRow,
)
T.copy(m_i, m_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for i in T.Parallel(block_M):
sf = T.exp(m_prev[i] * scale - m_i[i] * scale)
l_i[i] *= sf
scale_factor[i] = sf
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scale_factor[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp(acc_s[i, j] * scale - m_i[i] * scale)
T.reduce_sum(acc_s, row_sum, dim=1)
for i in T.Parallel(block_M):
l_i[i] += row_sum[i]
# Cast acc_s (accum_dtype) to dtype in registers and directly GEMM with V
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=GemmWarpPolicy.FullRow)
l_inv = T.alloc_fragment([block_M], accum_dtype)
for i in T.Parallel(block_M):
safe_l = T.if_then_else(l_i[i] > 1e-6, l_i[i], 1.0)
l_inv[i] = 1.0 / safe_l
for i, j in T.Parallel(block_M, dim):
Output[bz, q_block_offset + i, by, j] = acc_o[i, j] * l_inv[i]
bx = current_bx + num_split_q
return main
def main(batch: int = 1,
heads: int = 8,
seq_len: int = 4096,
dim: int = 128,
is_causal: bool = False,
groups: int = 1):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
print("Starting autotuning for FlashAttention-V2...")
kernel = fast_flashattn(batch, heads, seq_len, dim, is_causal, groups=groups)
print(f"Autotuning finished. Best Configuration: {kernel.config}")
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
print("Verifying correctness...")
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=100)
print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")
latency = profiler.do_bench(warmup=100)
print(
f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops"
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=8, help='heads')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--groups', type=int, default=1, help='groups')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
# TVM IR Performance Analyzer
A performance analysis toolkit for TVM IR modules. It provides hardware-aware performance metrics, including FLOPs, memory bandwidth utilization, and execution time estimation.
## Features
- **Operation Analysis**: Supports arbitrary operations expressed in TVM IR (including GEMM and convolution)
- **Memory Traffic Calculation**: Tracks global memory transfers
- **Architecture-aware Metrics**: Pre-configured with NVIDIA GPU architectures (Ampere, Ada Lovelace)
- **Performance Estimation**: Predicts execution time using a roofline model
- **TVM Integration**: Works with TVM IRModule and PrimFunc
## Quick Start
### GEMM Analysis Example
```python
import tilelang.language as T
from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA
M = N = K = 1024
def kernel(block_M=128, block_N=128, block_K=32, num_stages=3, thread_num=128):
@T.prim_func
def main(A: T.Tensor((M, K), "float16"),
B: T.Tensor((N, K), "float16"),
C: T.Tensor((M, N), "float")):
# ... (kernel definition)
return main
cuda_device = CUDA("cuda")
result = Analyzer.analysis(kernel(), cuda_device)
print(result)
```
### Convolution Analysis Example
```python
import tilelang.language as T
from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA
def kernel(N=64, C=256, H=512, W=512, F=512, K=3, block_M=64, block_N=128):
@T.prim_func
def main(data: T.Tensor((N, H, W, C), "float16"),
kernel: T.Tensor((K, K, C, F), "float16"),
out: T.Tensor((N, (H-K+1), (W-K+1), F), "float")):
# ... (convolution kernel definition)
return main
cuda_device = CUDA("cuda")
result = Analyzer.analysis(kernel(), cuda_device)
print(result)
```
## API Documentation
### `AnalysisResult` Class
```python
@dataclass(frozen=True)
class AnalysisResult:
total_flops: int # Total floating-point operations
total_global_bytes: int # Global memory traffic in bytes
estimated_time: float # Predicted execution time (seconds)
tflops: float # Achieved TFLOPS
bandwidth_GBps: float # Memory bandwidth utilization
```
### `Analyzer` Class Methods
#### `analysis(fn, device)`
* Parameters:
  * `fn`: TVM IRModule or PrimFunc
  * `device`: Device configuration object
* Returns: `AnalysisResult`
#### Supported Architectures
```python
# Extendable to custom hardware via: "compute_capability": (cores_per_SM, clock_GHz, flops_per_cycle, max_SM_count)
ARCH_CONFIGS = {
"80": (128, 1.41, 2, 108), # A100
"86": (128, 1.70, 2, 84), # RTX 3080
"89": (128, 2.52, 2, 128) # RTX 4090
}
```
## Implementation Details
### Performance Model
Uses a roofline model with two constraints; a sketch of the resulting estimate follows the list:
1. **Compute Bound**: `Time = Total FLOPs / (SM Count × Cores/SM × Clock × FLOPs/Cycle)`
2. **Memory Bound**: `Time = Memory Bytes / (Bandwidth × Utilization)`
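For intuition, here is a hedged sketch of that estimate in plain Python, using the A100 entry from `ARCH_CONFIGS` above; the bandwidth and utilization figures are illustrative assumptions, not values read from the analyzer:
```python
def estimate_time(total_flops, total_bytes,
                  cores_per_sm=128, clock_ghz=1.41, flops_per_cycle=2, sm_count=108,
                  bandwidth_gbps=2039.0, utilization=0.8):
    # Compute bound: Total FLOPs / (SM Count × Cores/SM × Clock × FLOPs/Cycle)
    compute_time = total_flops / (sm_count * cores_per_sm * clock_ghz * 1e9 * flops_per_cycle)
    # Memory bound: Memory Bytes / (Bandwidth × Utilization)
    memory_time = total_bytes / (bandwidth_gbps * 1e9 * utilization)
    return max(compute_time, memory_time)  # the slower bound dominates

# 1024^3 float16 GEMM: 2*M*N*K FLOPs, A + B + C global traffic (2 bytes/element)
M = N = K = 1024
print(f"{estimate_time(2 * M * N * K, (M * K + K * N + M * N) * 2) * 1e3:.3f} ms")
```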
### IR Analysis Pass
1. **Traversal**: Walks through TVM IR using `ir_transform`
2. **Operation Detection**:
   - Counts FLOPs for all compute operations
   - Calculates memory traffic for all memory operations
3. **Loop Handling**:
   - Tracks nested loops for operation scaling
   - Accounts for block/grid dimensions
## Key Metrics Calculation
| Metric | Formula |
|-------------------------|-----------------------------------------|
| FLOPs per GEMM | `2 × M × N × K` |
| Memory Traffic per Copy | `elements × dtype_size × loop_product` |
| Achieved TFLOPS | `total_flops / estimated_time / 1e12` |
| Memory Bandwidth | `total_global_bytes / estimated_time` |
## Limitations
1. Requires memory operations to be properly annotated in the IR
2. Assumes perfect memory coalescing and no bank conflicts
## Supported Operations
Any operation expressed in TVM IR
import tilelang.language as T
from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.layout import make_swizzled_layout
import torch
N = 64
C = 256
H = 512
W = 512
F = 512
K = 3
S = 1
D = 1
P = 1
def check_hopper():
# if not torch.cuda.is_available():
# return None
# props = torch.cuda.get_device_properties(0)
# compute_capability = props.major, props.minor
# return compute_capability == (9, 0)
return False
def kernel(N,
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
threads,
dtype="float16",
accum_dtype="float"):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
dtype = "float16"
accum_dtype = "float"
is_hopper = check_hopper()
@T.prim_func
def conv(
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
out_shared = T.alloc_shared((block_M, block_N), dtype)
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout({
out_shared: make_swizzled_layout(out_shared),
data_shared: make_swizzled_layout(data_shared),
kernel_shared: make_swizzled_layout(kernel_shared),
})
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
if is_hopper:
T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
else:
for i, j in T.Parallel(block_M, block_K):
k = k_iter * block_K + j
m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
(access_w < W))
data_shared[i, j] = T.if_then_else(
in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
T.copy(out_local, out_shared)
T.copy(out_shared, out_flat[by * block_M, bx * block_N])
return conv
def main():
my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
result = Analyzer.analysis(my_func, cuda_device)
print(result)
print(f"Analyzed FLOPs: {result.total_flops}")
if __name__ == "__main__":
main()
import tilelang.language as T
from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
import torch
M = N = K = 1024
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
enable_rasteration=None,
):
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def matmul(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return matmul
def main():
my_func = kernel(128, 128, 32, 3, 128, True)
cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
result = Analyzer.analysis(my_func, cuda_device)
print(f"Analyzed FLOPs: {result.total_flops}")
print(f"Expected FLOPs: {2 * M * N * K}")
if __name__ == "__main__":
main()
import tilelang.testing
import example_gemm_analyze
import example_conv_analyze
def test_example_gemm_analyze():
example_gemm_analyze.main()
def test_example_conv_analyze():
example_conv_analyze.main()
if __name__ == "__main__":
tilelang.testing.main()
# Attention Sink
We compare with an optimized version of the official Triton implementation [here](https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py).
## Algorithm
### Forward
The only change from vanilla FlashAttention is that `sinks` should be taken into consideration in the softmax, which requires an extra rescaling at the epilogue stage.
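A hedged PyTorch reference for that change (not the kernel itself): the sink logit enters the softmax normalizer but contributes no value row, which is exactly the extra rescaling performed in the epilogue. Tensor layouts and names here are illustrative.
```python
import torch

def attention_with_sinks(Q, K, V, sinks, scale):
    # Q, K, V: [batch, heads, seq, dim]; sinks: [heads]
    scores = torch.einsum("bhqd,bhkd->bhqk", Q, K) * scale
    m = torch.maximum(scores.amax(dim=-1), sinks[None, :, None])   # running max includes the sink
    p = torch.exp(scores - m[..., None])
    denom = p.sum(dim=-1) + torch.exp(sinks[None, :, None] - m)    # sink enters the normalizer only
    return torch.einsum("bhqk,bhkd->bhqd", p, V) / denom[..., None]
```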
### Backward
Based on a detailed mathematical derivation, the backward computation of `dQ`, `dK`, and `dV` is, interestingly, almost identical to that of vanilla FlashAttention, except that the meaning of `lse` differs. The only extra quantity to compute is `dsinks`, which is given by:
$$
dsink_h = -\sum_{b}\sum_{q} P_{b,h,q}\,\Delta_{b,h,q}
$$
where $P_{b,h,q}$ is the proportion of $sink_h$ in the softmax for the $b$-th block, $h$-th head, and $q$-th query (row).
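A hedged PyTorch sketch of that sum, assuming `lse` already includes the sink contribution (as in the forward pass) and `Delta` is the usual rowsum of `dO * O` from the backward preprocessing; tensor layouts are illustrative.
```python
import torch

def dsinks(sinks, lse, delta):
    # sinks: [heads]; lse, delta: [batch, heads, seq_q]
    P = torch.exp(sinks[None, :, None] - lse)   # proportion of the sink in each row's softmax
    return -(P * delta).sum(dim=(0, 2))         # accumulate over batches and query rows
```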
## Benchmark of forward process
### Benchmark Environment
- **Hardware**: NVIDIA H800
- **CUDA version**: 12.9
- **Triton Version**: 3.4.0
### Results
- dtype=bfloat16
- batch_size=1, heads=64, kv_heads=8 (the setting of GPT-OSS-120B)
- Full attention is adopted.
| SEQ_LEN | headdim | Triton TFLOPs | TileLang TFLOPs | Speedup |
|---------|---------|---------------|----------------------|---------|
| 2048 | 64 | 232.98 | **281.89** | 1.21x |
| 2048 | 128 | 321.55 | **417.98** | 1.30x |
| | | | | |
| 4096 | 64 | 280.70 | **349.47** | 1.25x |
| 4096 | 128 | 369.61 | **497.13** | 1.35x |
| | | | | |
| 8192 | 64 | 299.04 | **385.56** | 1.29x |
| 8192 | 128 | 399.39 | **507.93** | 1.27x |
| | | | | |
| 16384 | 64 | 309.46 | **400.62** | 1.29x |
| 16384 | 128 | 418.99 | **549.11** | 1.31x |
> The backward performance will be further optimized in the future.
import torch
import argparse
from tilelang.profiler import do_bench
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
@triton.jit
def triton_kernel(
Q,
K,
V,
Sinks,
sm_scale,
Out,
Z,
H,
N_Q_CTX,
N_KV_CTX,
HEAD_DIM: tl.constexpr,
groups: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BANDWIDTH: tl.constexpr,
start_q: tl.constexpr,
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
# load attention sinks
if Sinks is not None: # noqa: SIM108
sink = tl.load(Sinks + off_h).to(tl.float32)
else:
sink = 0
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# load scales
qk_scale = sm_scale
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M -
BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None]
if BANDWIDTH:
too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1)
mask = mask | too_old
k = K.load([off_z, off_h // groups, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T
qk = tl.dot(q, k, allow_tf32=False)
qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
p = tl.math.exp(qk)
alpha = tl.math.exp(m_i - m_ij)
l_ij = tl.sum(p, 1)
acc = acc * alpha[:, None]
v = V.load([off_z, off_h // groups, start_n, 0]).reshape([BLOCK_N, HEAD_DIM])
# v = v.to(tl.float32)
p = p.to(v.dtype) # We perform fp16 gemm to utilize tensor core
acc = tl.dot(p, v, acc, allow_tf32=False)
l_i = l_i * alpha + l_ij
m_i = m_ij
sink = tl.math.exp(sink - m_i)
z = l_i + sink
acc = acc / z[:, None]
# m_i += tl.math.log(l_i)
# m_ptrs = M + off_hz * N_Q_CTX + offs_m
# tl.store(m_ptrs, m_i)
acc = acc.to(Out.dtype)[None, None, :, :]
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
_, n_heads_kv, seq_kv, _ = K.shape
BLOCK_M = 64
BLOCK_N = 64
groups = n_heads // n_heads_kv
o = torch.empty_like(Q)
grid = (triton.cdiv(seq_q, BLOCK_M), bs * n_heads, 1)
triton_kernel[grid](
TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
Sinks,
1.0 / head_dim**0.5,
TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
bs,
n_heads,
N_Q_CTX=seq_q,
N_KV_CTX=seq_kv,
HEAD_DIM=head_dim,
groups=groups,
BANDWIDTH=window_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q)
return o
def main(
batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
groups: int = 8,
window_size: int | None = None,
dtype: str = "float16",
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, groups, window_size, dtype=dtype)
print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}")
else:
block_M = 128
block_N = 128
num_stages = 2
threads = 256
print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
groups,
window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
if torch.allclose(
triton_program(Q, K, V, sinks, window_size),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2):
print("Checks for triton passed.✅")
else:
print("Checks for triton failed.❌")
# Benchmark triton
latency_triton = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
print("Triton: {:.2f} ms".format(latency_triton))
print("Triton: {:.2f} TFlops".format(total_flops / latency_triton * 1e-9))
# Benchmark tilelang
latency_tilelang = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency_tilelang))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9))
print("Speedup: {:.2f}x".format(latency_triton / latency_tilelang))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=64, help='heads')
parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--groups', type=int, default=8, help='groups')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size,
args.dtype, args.tune)
import torch
import argparse
from tilelang.profiler import do_bench
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
@triton.jit
def triton_kernel(
Q,
K,
V,
Sinks,
sm_scale,
Out,
Z,
H,
N_Q_CTX,
N_KV_CTX,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BANDWIDTH: tl.constexpr,
start_q: tl.constexpr,
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
# load attention sinks
if Sinks is not None: # noqa: SIM108
sink = tl.load(Sinks + off_h).to(tl.float32)
else:
sink = 0
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
    # initialize the online-softmax state: running max m_i (seeded with the sink),
    # running sum l_i, and the output accumulator acc
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# load scales
qk_scale = sm_scale
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M -
BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None]
if BANDWIDTH:
too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1)
mask = mask | too_old
k = K.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T
qk = tl.dot(q, k, allow_tf32=False)
qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
p = tl.math.exp(qk)
alpha = tl.math.exp(m_i - m_ij)
l_ij = tl.sum(p, 1)
acc = acc * alpha[:, None]
v = V.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM])
# v = v.to(tl.float32)
        p = p.to(v.dtype)  # perform the fp16/bf16 gemm to utilize tensor cores
acc = tl.dot(p, v, acc, allow_tf32=False)
l_i = l_i * alpha + l_ij
m_i = m_ij
sink = tl.math.exp(sink - m_i)
z = l_i + sink
acc = acc / z[:, None]
# m_i += tl.math.log(l_i)
# m_ptrs = M + off_hz * N_Q_CTX + offs_m
# tl.store(m_ptrs, m_i)
acc = acc.to(Out.dtype)[None, None, :, :]
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
seq_kv = K.shape[2]
BLOCK_M = 64
BLOCK_N = 64
o = torch.empty_like(Q)
grid = (triton.cdiv(seq_q, BLOCK_M), bs * n_heads, 1)
triton_kernel[grid](
TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
Sinks,
1.0 / head_dim**0.5,
TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
bs,
n_heads,
N_Q_CTX=seq_q,
N_KV_CTX=seq_kv,
HEAD_DIM=head_dim,
BANDWIDTH=window_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q)
return o
def main(batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: int | None = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size, dtype=dtype)
print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}")
else:
block_M = 128
block_N = 128
num_stages = 2
threads = 256
print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2)
print("All checks passed.✅")
latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
print("Triton: {:.2f} ms".format(latency))
print("Triton: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype,
args.tune)
# Adapted from tilelang/examples/flash_attention/example_gqa_bwd.py
import torch
import tilelang
from tilelang.profiler import do_bench
import tilelang.language as T
import argparse
from typing import Optional
def get_bwd_configs():
sm_major, sm_minor = torch.cuda.get_device_capability()
sm_version = sm_major * 10 + sm_minor
if sm_version == 80:
return 64, 32, 1, 128
elif sm_version == 90:
return 128, 32, 2, 256
else:
raise ValueError(f"Unsupported SM version: {sm_version}")
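# The tuple returned above is (block_M, block_N, num_stages, threads) for the
# backward kernel; sm_90 (Hopper) gets larger tiles, deeper pipelining and more
# threads than sm_80 (Ampere).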
@tilelang.jit(
out_idx=[3, 4], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_fwd(
batch,
heads,
seq_len,
dim,
groups=1,
window_size=None, # None for full attention
sm_scale=None,
block_M=64,
block_N=64,
num_stages=1,
threads=128,
dtype: str = "float16"):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
scale = sm_scale * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, heads, seq_len, dim]
kv_shape = [batch, head_kv, seq_len, dim]
accum_dtype = "float"
@T.prim_func
def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(kv_shape, dtype), # type: ignore
V: T.Tensor(kv_shape, dtype), # type: ignore
Output: T.Tensor(q_shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Sinks: T.Tensor([heads], dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([heads], dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M - window_size) // block_N)
else:
start[0] = 0
for k in T.Pipelined(start[0], end, num_stages=num_stages):
T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size,
0, -T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                # To do causal softmax, we need to set scores_max to 0 if it is -inf,
                # i.e. when a row is fully masked. This step is called Check_inf in the
                # FlashAttention-3 code.
                # NOTE(wt): check_inf is necessary for sliding window attention, which is
                # why it is only applied when window_size is set.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0,
scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 -
scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
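            # lse is kept in base 2 (log2 of the softmax normalizer plus the scaled row
            # max), so the backward kernels recover probabilities with a single exp2:
            # p = exp2(s * scale - lse).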
return flash_fwd
@tilelang.jit(
out_idx=[2], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
accum_dtype = "float"
shape = [batch, heads, seq_len, dim]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o)
T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
return flash_bwd_prep
def make_dq_layout(dQ):
    # atomicAdd cannot be vectorized, so we reorder dQ to match the 8x8 gemm fragment
return T.Layout(dQ.shape,
lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
@tilelang.jit(
out_idx=[1], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
accum_dtype = "float"
shape = [batch, heads, seq_len, dim]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(
dQ[bz, by, bx * blk:(bx + 1) * blk, :],
dQ_out[bz, by, bx * blk:(bx + 1) * blk, :],
)
return flash_bwd_post
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(batch,
heads,
seq_len,
dim,
groups,
                  window_size=None,  # None for full attention
                  sm_scale=None,
                  dtype="float16"):
if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
scale = sm_scale * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, heads, seq_len, dim]
kv_shape = [batch, head_kv, seq_len, dim]
accum_dtype = "float"
block_M, block_N, num_stages, threads = get_bwd_configs()
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(kv_shape, dtype), # type: ignore
V: T.Tensor(kv_shape, dtype), # type: ignore
dO: T.Tensor(q_shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(kv_shape, accum_dtype), # type: ignore
dV: T.Tensor(kv_shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim], dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim], accum_dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
})
T.copy(K[bz, bx // groups, by * block_M:(by + 1) * block_M, :], K_shared)
T.copy(V[bz, bx // groups, by * block_M:(by + 1) * block_M, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N)
loop_ed = T.alloc_local([1], 'int32')
if window_size is not None:
loop_ed[0] = T.min(
T.ceildiv((by + 1) * block_M + window_size, block_N),
T.ceildiv(seq_len, block_N))
else:
loop_ed[0] = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
for i, j in T.Parallel(block_M, block_N):
if window_size is not None:
qkT[i, j] = T.if_then_else(
by * block_M + i <= k * block_N + j and
by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0)
else:
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], dst=do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
T.atomic_add(dQ[bz, bx, k * block_N:(k + 1) * block_N, :], dq)
T.copy(dv, dv_shared)
T.atomic_add(dV[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dv_shared)
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dk_shared)
return flash_bwd
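# Gradient of the attention sink: with p_sink = exp(sink - m) / Z, the chain rule
# gives dL/d(sink) = -p_sink * sum_d(dO * O) per (batch, head, row). The kernel
# below evaluates this as -exp2(sink * log2(e) - lse) * Delta; the caller then
# sums over batch and sequence to obtain one scalar gradient per head.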
@tilelang.jit(out_idx=-1)
def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16"):
accum_dtype = "float"
shape = [batch, heads, seq_len]
@T.prim_func
def flash_bwd_dsink(
Sinks: T.Tensor([heads], dtype), # type: ignore
Delta: T.Tensor(shape, accum_dtype), # type: ignore
lse: T.Tensor(shape, accum_dtype), # type: ignore
dsinks: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=256) as (bx, by, bz):
sink = T.alloc_local([1], dtype)
lse_fragment = T.alloc_fragment([block], accum_dtype)
delta_fragment = T.alloc_fragment([block], accum_dtype)
dsink_fragment = T.alloc_fragment([block], dtype)
sink[0] = Sinks[bx]
T.copy(lse[bz, bx, by * block:(by + 1) * block], lse_fragment)
T.copy(Delta[bz, bx, by * block:(by + 1) * block], delta_fragment)
for i in T.Parallel(block):
dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 -
lse_fragment[i]) * delta_fragment[i]
T.copy(dsink_fragment, dsinks[bz, bx, by * block:(by + 1) * block])
return flash_bwd_dsink
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sinks, window_size, groups):
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
q, k, v, sinks = [maybe_contiguous(x) for x in (q, k, v, sinks)]
BATCH, H, N_CTX, D_HEAD = q.shape
dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype)
o, lse = kernel(q, k, v, sinks)
ctx.save_for_backward(q, k, v, sinks, o, lse)
ctx.window_size = window_size
ctx.groups = groups
return o
@staticmethod
def backward(ctx, do):
q, k, v, sinks, o, lse = ctx.saved_tensors
BATCH, H, N_CTX, D_HEAD = q.shape
groups = ctx.groups
dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
delta = kernel_prep(o, do)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, groups, ctx.window_size, dtype=dtype)
q_shape = [BATCH, H, N_CTX, D_HEAD]
head_kv = H // groups
kv_shape = [BATCH, head_kv, N_CTX, D_HEAD]
dq = torch.zeros(q_shape, dtype=torch.float32, device=q.device) # acc for atomicAdd
dk = torch.zeros(kv_shape, dtype=torch.float32, device=q.device)
dv = torch.zeros(kv_shape, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = kernel_post(dq)
kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
return dq, dk, dv, dsinks, None, None
attention = _attention.apply
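# Usage sketch (GQA layout, shapes taken from main() below):
#   O = attention(q, k, v, sinks, window_size, groups)
# with q: [B, H, S, D], k/v: [B, H // groups, S, D], sinks: [H];
# window_size may be None for full causal attention.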
# Adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16) -> torch.Tensor:
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
batch_size, num_keys, num_key_value_heads, head_dim = key.shape
query = query.transpose(1, 2).contiguous()
query = query.view(batch_size, query.shape[1], num_key_value_heads, -1, head_dim)
batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
start_q = num_keys - num_queries
sm_scale: float = 1.0 / head_dim**0.5
sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float()
key = key.unsqueeze(3)
value = value.unsqueeze(3)
pos_keys = torch.arange(num_keys, device=query.device)
pos_queries = torch.arange(num_queries, device=query.device) + start_q
mask = pos_keys[None, :] > pos_queries[:, None]
mask = mask.float().masked_fill(mask, float("-inf"))
if sliding_window:
too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
mask.masked_fill_(too_old, float("-inf"))
logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
logits = logits + mask[None, None, None, :, :]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(sinks, logits_max)
sinks = torch.exp(sinks - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
scores = unnormalized_scores / normalizer
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(dtype)
return output.transpose(1, 2).contiguous()
def main(BATCH: int = 1,
H: int = 8,
N_CTX: int = 512,
D_HEAD: int = 64,
groups: int = 2,
window_size: int | None = None,
dtype: str = "float16"):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
assert window_size <= N_CTX
flops_per_matmul = 2.0 * BATCH * H * min(
window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation
else:
print('Using full attention.')
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 5 * flops_per_matmul
Q = (torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_())
K = torch.randn(
BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()
V = torch.randn_like(K).requires_grad_()
sinks = torch.randn(H, dtype=torch_dtype, device="cuda").requires_grad_()
dO = torch.randn_like(Q)
O = attention(Q, K, V, sinks, window_size, groups)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
dsinks, sinks.grad = sinks.grad.clone(), None
O_ref = ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
dsinks_ref, sinks.grad = sinks.grad.clone(), None
# Checks
rtol, atol = {
"float16": (1e-2, 1e-2),
"bfloat16": (2e-2, 2e-2),
}[dtype]
assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f'O max err: {(O-O_ref).abs().max()}'
assert torch.allclose(
dV, dV_ref, rtol=rtol, atol=atol), f'dV max err: {(dV-dV_ref).abs().max()}'
assert torch.allclose(
dK, dK_ref, rtol=rtol, atol=atol), f'dK max err: {(dK-dK_ref).abs().max()}'
assert torch.allclose(
dQ, dQ_ref, rtol=rtol, atol=atol), f'dq max err: {(dQ-dQ_ref).abs().max()}'
assert torch.allclose(
dsinks, dsinks_ref, rtol=rtol,
atol=atol), f'dsinks max err: {(dsinks-dsinks_ref).abs().max()}'
print("All checks passed for tilelang kernels.✅")
# Only benchmark backward here
def torch_bwd():
O_ref.backward(dO, retain_graph=True)
def tl_bwd():
O.backward(dO, retain_graph=True)
latency = do_bench(torch_bwd, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(tl_bwd, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='Batch size')
parser.add_argument('--h', type=int, default=64, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=4096, help='Context size')
parser.add_argument('--d_head', type=int, default=128, help='Head dimension')
parser.add_argument('--groups', type=int, default=8, help='Groups')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.groups, args.window_size, args.dtype)
# Modified from tilelang/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py
# Optimized for Hopper architecture, with a benchmark to compare with official Triton impl
import torch
import tilelang
from tilelang.autotuner import autotune
from tilelang.profiler import do_bench
import tilelang.language as T
from tilelang.layout import make_swizzled_layout
import itertools
import argparse
from typing import Optional
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
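# itertools.product over the lists above yields 1 * 1 * 3 * 2 = 6 candidate
# configurations; @autotune(warmup=500, rep=100) benchmarks each one and keeps
# the fastest for the given problem shape.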
@autotune(
configs=get_configs(),
warmup=500,
rep=100,
)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
groups=1,
window_size=None, # None for full attention
sm_scale=None,
block_M=128,
block_N=128,
num_stages=2,
threads=256,
dtype: str = "float16",
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
scale = sm_scale * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, head_kv, seq_kv, dim]
accum_dtype = "float"
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
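    # past_len shifts query rows to absolute positions when seq_kv > seq_q (e.g.
    # appending new tokens to an existing KV cache): the causal and sliding-window
    # masks below compare q_idx = bx * block_M + i + past_len against key indices.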
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0,
-T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
        # To do causal softmax, we need to set scores_max to 0 if it is -inf,
        # i.e. when a row is fully masked. This step is called Check_inf in the
        # FlashAttention-3 code.
        # NOTE(wt): check_inf is necessary for sliding window attention, which is
        # why it is only applied when window_size is set.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0,
scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
            # Instead of computing exp(x - max), we compute
            # exp2(x * log2(e) - max * log2(e)). This allows the compiler to use the
            # ffma instruction instead of separate fadd and fmul instructions.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
T.annotate_layout({
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
})
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
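            # Rather than relying on the default automatic pipelining, the explicit
            # order/stage/group lists below hand-schedule the software pipeline across
            # the copy, MMA and softmax pieces of the loop body (the annotation format
            # follows TileLang's pipelined flash-attention examples).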
for k in T.Pipelined(
start[0],
end,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 -
scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return main
# Following functions are adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16) -> torch.Tensor:
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
batch_size, num_keys, num_key_value_heads, head_dim = key.shape
query = query.transpose(1, 2).contiguous()
query = query.view(batch_size, query.shape[1], num_key_value_heads, -1, head_dim)
batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
start_q = num_keys - num_queries
sm_scale: float = 1.0 / head_dim**0.5
sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float()
key = key.unsqueeze(3)
value = value.unsqueeze(3)
pos_keys = torch.arange(num_keys, device=query.device)
pos_queries = torch.arange(num_queries, device=query.device) + start_q
mask = pos_keys[None, :] > pos_queries[:, None]
mask = mask.float().masked_fill(mask, float("-inf"))
if sliding_window:
too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
mask.masked_fill_(too_old, float("-inf"))
logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
logits = logits + mask[None, None, None, :, :]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(sinks, logits_max)
sinks = torch.exp(sinks - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
scores = unnormalized_scores / normalizer
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(dtype)
return output.transpose(1, 2).contiguous()
def gen_inputs(
B,
H,
Sq,
Skv,
D,
groups,
dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda')
key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda')
value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda')
sinks = torch.randn([H], dtype=dtype, device='cuda')
return query, key, value, sinks
def main(
batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
groups: int = 8,
window_size: int | None = None,
dtype: str = "float16",
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, groups, window_size, dtype=dtype)
print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}")
else:
block_M = 128
block_N = 128
num_stages = 2
threads = 256
print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
groups,
window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2)
print("All checks passed.✅")
# Benchmark tilelang
latency_tilelang = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency_tilelang))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=64, help='heads')
parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--groups', type=int, default=8, help='groups')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size,
args.dtype, args.tune)
# Adapted from tilelang/examples/flash_attention/example_mha_bwd_bhsd.py
import torch
import tilelang
from tilelang.profiler import do_bench
import tilelang.language as T
import argparse
from typing import Optional
def get_bwd_configs():
sm_major, sm_minor = torch.cuda.get_device_capability()
sm_version = sm_major * 10 + sm_minor
if sm_version == 80:
return 64, 32, 1, 128
elif sm_version == 90:
return 128, 32, 2, 256
else:
raise ValueError(f"Unsupported SM version: {sm_version}")
@tilelang.jit(
out_idx=[3, 4], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_fwd(
batch,
heads,
seq_len,
dim,
        window_size=None,  # None for full attention
sm_scale=None,
block_M=64,
block_N=64,
num_stages=1,
threads=128,
dtype: str = "float16"):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
scale = sm_scale * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim]
accum_dtype = "float"
@T.prim_func
def flash_fwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
Output: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Sinks: T.Tensor([heads], dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([heads], dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M - window_size) // block_N)
else:
start[0] = 0
for k in T.Pipelined(start[0], end, num_stages=num_stages):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size,
0, -T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                # To do causal softmax, we need to set scores_max to 0 if it is -inf,
                # i.e. when a row is fully masked. This step is called Check_inf in the
                # FlashAttention-3 code.
                # NOTE(wt): check_inf is necessary for sliding window attention, which is
                # why it is only applied when window_size is set.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0,
scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 -
scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
return flash_fwd
@tilelang.jit(
out_idx=[2], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
accum_dtype = "float"
shape = [batch, heads, seq_len, dim]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o)
T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
return flash_bwd_prep
def make_dq_layout(dQ):
    # atomicAdd cannot be vectorized, so we reorder dQ to match the 8x8 gemm fragment
return T.Layout(dQ.shape,
lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
@tilelang.jit(
out_idx=[1], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
accum_dtype = "float"
shape = [batch, heads, seq_len, dim]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(
dQ[bz, by, bx * blk:(bx + 1) * blk, :],
dQ_out[bz, by, bx * blk:(bx + 1) * blk, :],
)
return flash_bwd_post
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(
batch,
heads,
seq_len,
dim,
window_size=None, # None for full attention
sm_scale=None,
dtype: str = "float16",
):
block_M, block_N, num_stages, threads = get_bwd_configs()
if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
scale = sm_scale * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim]
accum_dtype = "float"
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
@T.prim_func
def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
# should not store K to local if dim is large
# K_local = T.alloc_fragment([block_M, dim], dtype)
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
# V_local = T.alloc_fragment([block_M, dim], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim], dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim], dtype)
dk_shared = T.alloc_shared([block_M, dim], dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
})
T.copy(K[bz, bx, by * block_M:(by + 1) * block_M, :], K_shared)
T.copy(V[bz, bx, by * block_M:(by + 1) * block_M, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N)
loop_ed = T.alloc_local([1], 'int32')
if window_size is not None:
loop_ed[0] = T.min(
T.ceildiv((by + 1) * block_M + window_size, block_N),
T.ceildiv(seq_len, block_N))
else:
loop_ed[0] = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
for i, j in T.Parallel(block_M, block_N):
if window_size is not None:
qkT[i, j] = T.if_then_else(
by * block_M + i <= k * block_N + j and
by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0)
else:
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], dst=do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
                T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
T.atomic_add(dQ[bz, bx, k * block_N:(k + 1) * block_N, :], dq)
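            # Unlike the GQA backward kernel, each (head, kv block) is processed by
            # exactly one program here, so dK/dV can be written with plain copies
            # below instead of atomic adds; only dQ still needs atomic accumulation
            # across kv blocks.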
T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :])
T.copy(dk_shared, dK[bz, bx, by * block_M:(by + 1) * block_M, :])
return flash_bwd
@tilelang.jit(out_idx=-1)
def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: str = "float16"):
accum_dtype = "float"
shape = [batch, heads, seq_len]
@T.prim_func
def flash_bwd_dsink(
Sinks: T.Tensor([heads], dtype), # type: ignore
Delta: T.Tensor(shape, accum_dtype), # type: ignore
lse: T.Tensor(shape, accum_dtype), # type: ignore
dsinks: T.Tensor(shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=128) as (bx, by, bz):
sink = T.alloc_local([1], dtype)
lse_fragment = T.alloc_fragment([block], accum_dtype)
delta_fragment = T.alloc_fragment([block], accum_dtype)
dsink_fragment = T.alloc_fragment([block], accum_dtype)
sink[0] = Sinks[bx]
T.copy(lse[bz, bx, by * block:(by + 1) * block], lse_fragment)
T.copy(Delta[bz, bx, by * block:(by + 1) * block], delta_fragment)
for i in T.Parallel(block):
dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 -
lse_fragment[i]) * delta_fragment[i]
T.copy(dsink_fragment, dsinks[bz, bx, by * block:(by + 1) * block])
return flash_bwd_dsink
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sinks, window_size):
BATCH, H, N_CTX, D_HEAD = q.shape
dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size, dtype=dtype)
o, lse = kernel(q, k, v, sinks)
ctx.save_for_backward(q, k, v, sinks, o, lse)
ctx.window_size = window_size
return o
@staticmethod
def backward(ctx, do):
q, k, v, sinks, o, lse = ctx.saved_tensors
BATCH, H, N_CTX, D_HEAD = q.shape
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, sinks, o = [maybe_contiguous(x) for x in (do, q, k, v, sinks, o)]
dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
delta = kernel_prep(o, do)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.window_size, dtype=dtype)
shape = [BATCH, H, N_CTX, D_HEAD]
dq = torch.zeros(shape, dtype=torch.float32, device=q.device) # acc for atomicAdd
dk = torch.empty(shape, dtype=q.dtype, device=q.device)
dv = torch.empty(shape, dtype=q.dtype, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = kernel_post(dq)
kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
return dq, dk, dv, dsinks, None
attention = _attention.apply
# Adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(
3) # align with the original function's interface
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
batch_size, num_keys, num_key_value_heads, head_dim = key.shape
start_q = num_keys - num_queries
sm_scale: float = 1.0 / head_dim**0.5
sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1)
key = key.unsqueeze(3)
value = value.unsqueeze(3)
pos_keys = torch.arange(num_keys, device=query.device)
pos_queries = torch.arange(num_queries, device=query.device) + start_q
mask = pos_keys[None, :] > pos_queries[:, None]
mask = mask.float().masked_fill(mask, float("-inf"))
if sliding_window:
too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
mask.masked_fill_(too_old, float("-inf"))
logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
logits = logits + mask[None, None, None, :, :]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(sinks, logits_max)
sinks = torch.exp(sinks - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
scores = unnormalized_scores / normalizer
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(dtype)
return output.transpose(1, 2).contiguous()
def main(BATCH: int = 1,
H: int = 1,
N_CTX: int = 512,
D_HEAD: int = 128,
window_size: int | None = None,
dtype: str = "float16"):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
assert window_size <= N_CTX
flops_per_matmul = 2.0 * BATCH * H * min(
window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation
else:
print('Using full attention.')
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 5 * flops_per_matmul
Q = (torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_())
K = torch.randn_like(Q).requires_grad_()
V = torch.randn_like(Q).requires_grad_()
sinks = torch.randn(H, dtype=torch_dtype, device=Q.device).requires_grad_()
dO = torch.randn_like(Q)
O = attention(Q, K, V, sinks, window_size)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
dsinks, sinks.grad = sinks.grad.clone(), None
O_ref = ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
dsinks_ref, sinks.grad = sinks.grad.clone(), None
# Checks
rtol, atol = {
"float16": (1e-2, 1e-2),
"bfloat16": (2e-2, 2e-2),
}[dtype]
assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f'O max err: {(O-O_ref).abs().max()}'
assert torch.allclose(
dV, dV_ref, rtol=rtol, atol=atol), f'dV max err: {(dV-dV_ref).abs().max()}'
assert torch.allclose(
dK, dK_ref, rtol=rtol, atol=atol), f'dK max err: {(dK-dK_ref).abs().max()}'
assert torch.allclose(
dQ, dQ_ref, rtol=rtol, atol=atol), f'dq max err: {(dQ-dQ_ref).abs().max()}'
assert torch.allclose(
dsinks, dsinks_ref, rtol=rtol,
atol=atol), f'dsinks max err: {(dsinks-dsinks_ref).abs().max()}'
print("All checks passed for tilelang kernels.✅")
# Only benchmark backward here
def torch_bwd():
O_ref.backward(dO, retain_graph=True)
def tl_bwd():
O.backward(dO, retain_graph=True)
latency = do_bench(torch_bwd, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(tl_bwd, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='Batch size')
parser.add_argument('--h', type=int, default=64, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=4096, help='Context size')
parser.add_argument('--d_head', type=int, default=128, help='Head dimension')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.window_size, args.dtype)