# Debugging Tile Language Programs
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/LeiWang1999">Lei Wang</a>
</div>
## Overview
A Tile Language program (hereafter referred to as a *program*) is transformed into a hardware-executable file through several stages:
1. The user writes a Tile Language program.
2. The program undergoes multiple *Passes* for transformation and optimization (the *lower* stage, see `tilelang/engine/lower.py`), finally producing target code or an intermediate representation (e.g., LLVM IR or C for CPUs, CUDA C for NVIDIA GPUs).
3. The generated code is compiled by the respective compiler (e.g., nvcc) into a hardware-executable file.
```{figure} ../_static/img/overview.png
:width: 300
:alt: Overview of the compilation process
:align: center
```
During this process, users may encounter roughly three categories of issues:
* **Generation issues**: The Tile Language program fails to generate a valid hardware-executable file (i.e., errors during the lowering process).
* **Correctness issues**: The resulting executable runs, but produces incorrect results.
* **Performance issues**: The executable runs with performance significantly below the expected theoretical hardware limits.
This tutorial focuses on the first two issues—how to debug generation and correctness problems. Performance tuning often requires using vendor-provided profiling tools (e.g., **Nsight Compute**, **rocProf**, etc.) for further hardware-level analysis, which we will address in future materials.
Below, we take matrix multiplication (GEMM) as an example to demonstrate how to write and debug a Tile Language program.
## Matrix Multiplication Example
In **Tile Language**, you can use the **Tile Library** to implement matrix multiplication. Here's a complete example:
```python
import tilelang
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
# ...existing code...
# 1. Define the kernel (matmul) with the desired dimensions
func = matmul(1024, 1024, 1024, 128, 128, 32)
# 2. Compile the kernel into a torch function
# ...existing code...
```
## Debugging Generation Issues
TileLang essentially performs *progressive lowering*. For example, a `T.copy` may first be expanded into a `T.Parallel` loop (see the `LowerTileOP` pass), which is then expanded further, eventually yielding low-level statements that can be translated into CUDA C code.
```{figure} ../_static/img/ir_transform_diagram.png
:width: 400
:alt: IR transformation diagram
:align: center
```
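Conceptually, the first expansion turns a tile-level operator into element-wise loops. The snippet below is a simplified sketch of what this looks like for a 2D `T.copy` (the actual `LowerTileOP` output also handles vectorization, predication, and layouts), not the literal IR produced by the pass:
```python
# Inside a kernel such as the matmul example above (A, A_shared, by, k,
# block_M, block_K as defined there), the user writes a single tile-level copy:
T.copy(A[by * block_M, k * block_K], A_shared)

# After expansion (simplified sketch), it behaves like an element-wise parallel loop:
for i, j in T.Parallel(block_M, block_K):
    A_shared[i, j] = A[by * block_M + i, k * block_K + j]
```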
When the code fails to generate (for instance, a compilation error occurs), you do **not** necessarily need to jump directly into C++ passes to debug. Instead, you can first inspect the intermediate representations (IR) in Python by printing them.
For example, consider a case where a simple `T.copy` in 1D causes the lowering process to fail. The snippet below illustrates a simplified version of the problem (based on community Issue #35):
```python
@T.prim_func
def main(Q: T.Tensor(shape_q, dtype)):
# ...existing code...
```
The TileLang lowering process might yield an error such as:
```text
File "/root/TileLang/src/target/codegen_cuda.cc", line 1257
ValueError: Check failed: lanes <= 4 (8 vs. 4) : Ramp of more than 4 lanes is not allowed.
```
This indicates that somewhere during code generation, an unsupported vectorization pattern was introduced (a ramp of 8 lanes). Before diving into the underlying C++ code, it is helpful to print the IR right before code generation. For instance:
```python
device_mod = tir.transform.Filter(is_device_call)(mod)
# ...existing code...
```
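If you only need a quick look rather than instrumenting `lower.py`, you can also print the TVMScript of the program itself and, once lowering succeeds, the generated CUDA source. A minimal sketch, reusing the `matmul` example above:
```python
import tilelang

# Reuse the matmul example above (any @T.prim_func works here).
func = matmul(1024, 1024, 1024, 128, 128, 32)

# TVMScript of the Tile Language program before lowering.
print(func)

# Run the full lowering pipeline; if it succeeds, inspect the final CUDA source.
kernel = tilelang.compile(func, target="cuda")
print(kernel.get_kernel_source())
```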
## Debugging Correctness Issues
Sometimes, the kernel compiles and runs but produces incorrect results. In such cases, there are two main strategies to help debug:
1. **Use post-processing callbacks to inspect or modify the generated CUDA code.**
2. **Use the built-in `T.print` debugging primitive to inspect values at runtime.**
### Post-Processing Callbacks for Generated Source
After code generation (in the codegen pass), TileLang calls a callback function (if registered) to allow post-processing of the generated source code. In `src/target/rt_mod_cuda.cc`:
```cpp
std::string code = cg.Finish();
if (const auto *f = Registry::Get("tilelang_callback_cuda_postproc")) {
code = (*f)(code, target).operator std::string();
}
```
Hence, by registering a Python function named `tilelang_callback_cuda_postproc`, you can intercept the final CUDA code string. For example:
```python
import tilelang
import tilelang.language as T
from tilelang import tvm
from tilelang.engine.callback import register_cuda_postproc_callback
@register_cuda_postproc_callback
def tilelang_callback_cuda_postproc(code, _):
print(code) # print the final CUDA code
code = "// modified by tilelang_callback_cuda_postproc\n" + code
return code
kernel = tilelang.compile(matmul, target="cuda")
kernel_source = kernel.get_kernel_source()
print(kernel_source)
'''
// modified by tilelang_callback_cuda_postproc
#include "cuda_runtime.h"
...
'''
```
### Runtime Debug Prints with `T.print`
TileLang provides a built-in debugging primitive called `T.print` for printing within kernels. Be mindful of concurrency and thread synchronization when using it in GPU code. Below are some examples showing how to print buffers, variables, and other data inside TileLang programs.
1. **Printing an Entire Buffer**
```python
def debug_print_buffer(M=16, N=16):
# ...existing code...
```
2. **Conditional Printing**
```python
def debug_print_buffer_conditional(M=16, N=16):
# ...existing code...
```
3. **Printing Thread Indices or Scalar Values**
```python
def debug_print_value_conditional(M=16, N=16):
# ...existing code...
```
4. **Printing Fragment (Register File) Contents**
```python
def debug_print_register_files(M=16, N=16):
# ...existing code...
```
5. **Adding a Message Prefix**
```python
def debug_print_msg(M=16, N=16):
# ...existing code...
```
The output messages will include something like:
```text
msg='hello world' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): 0
```
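As a self-contained illustration (a sketch; the exact `T.print` keywords may differ slightly across versions), the kernel below copies a tile into shared memory and prints it with a message prefix:
```python
import torch
import tilelang
import tilelang.language as T

def debug_print_shared(M=16, N=16, dtype="float16"):

    @T.prim_func
    def program(A: T.Tensor((M, N), dtype)):
        with T.Kernel(1, threads=128) as bx:
            A_shared = T.alloc_shared((M, N), dtype)
            T.copy(A, A_shared)
            # Print every element of the shared buffer, tagged with a message prefix.
            T.print(A_shared, msg="A_shared")

    return program

kernel = tilelang.compile(debug_print_shared(), target="cuda")
kernel(torch.randn(16, 16, dtype=torch.float16, device="cuda"))
```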
### Visual Layout Inference for TileLang
The **Visual Layout Inference** tool automatically generates visual diagrams that illustrate the mapping between logical indices, thread IDs, and register file locations.
When TileLang performs layout inference, it determines how fragment buffers are distributed across threads. The visual layout tool captures this information and generates:
1. **Textual output**: A human-readable description of the layout mapping
2. **Visual diagrams**: Color-coded plots showing the thread-to-data mapping
The visual layout inference tool is controlled through the `TL_LAYOUT_VISUALIZATION_ENABLE` and `TL_LAYOUT_VISUALIZATION_FORMATS` pass configurations. By default, `TL_LAYOUT_VISUALIZATION_ENABLE` is **disabled** to avoid performance overhead during compilation.
When enabled, `TL_LAYOUT_VISUALIZATION_FORMATS` accepts string values to control output formats:
- "txt": Text output only (same as default)
- "all": Generates all formats (TXT, PDF, PNG, SVG)
- "png": Generate PNG format only
- "pdf": Generate PDF format only
- "svg": Generate SVG format only
- "txt,svg": Generate multiple formats (comma-separated) in addition to text output
The "txt" output will include something like:
```text
C_local inferenced layout:
Shape: [32, 32] -> [8]
Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2
Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
```
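For example, the options can be passed per compilation through `pass_configs`. This is a hedged sketch: the key spellings are assumed to be exposed via `tilelang.PassConfigKey` under the names used above, so check your version before relying on it.
```python
import tilelang

kernel = tilelang.compile(
    func,  # any Tile Language prim_func, e.g. the matmul example above
    target="cuda",
    pass_configs={
        tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE: True,        # assumed key name
        tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "txt,svg",  # assumed key name
    },
)
```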
## Conclusion
By carefully examining intermediate representations (IR) before final code generation—and by leveraging runtime printing through `T.print`—one can quickly diagnose where index calculations, copy logic, or other kernel operations deviate from the intended behavior. This two-pronged approach (inspecting IR transformations and using runtime prints) is often sufficient for resolving generation and correctness issues in TileLang programs.
For advanced performance tuning (e.g., analyzing memory bandwidth or occupancy), more specialized profiling tools such as **Nsight Compute**, **rocProf**, or vendor-specific profilers may be required. Those aspects will be covered in future documents.
Logging in Tilelang/TVM
===================================================
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/SiriusNEO">SiriusNEO</a>
</div>
## TVM Logging Overview
Tilelang currently utilizes the logging system from TVM. The implementation can be found in:
- [include/tvm/runtime/logging.h](https://github.com/apache/tvm/blob/main/include/tvm/runtime/logging.h): Macro definitions
- [src/runtime/logging.cc](https://github.com/apache/tvm/blob/main/src/runtime/logging.cc): Logging logic implementation
The design style is inspired by [Google's glog](https://google.github.io/glog/stable/).
## Logging Categories
There are three primary macro types:
```c++
LOG(INFO) << "aaa";
DLOG(INFO) << "aaa";
VLOG(1) << "aaa";
```
- **LOG**: Standard logging preserved in code for displaying necessary information at different levels during runtime. Most Tilelang C++ error reporting is implemented via `LOG(FATAL) << "error msg"`.
- **DLOG**: Debug logging for developer debugging output. DLOG is controlled at build time by the `TVM_LOG_DEBUG` compile-time macro and is **eliminated in Release builds through dead code elimination**.
- The key difference between LOG(DEBUG) and DLOG is this build-time elimination. We recommend using DLOG over LOG(DEBUG), as the latter has overlapping functionality and gets compiled into the release runtime.
- **VLOG**: [Verbose logging](https://google.github.io/glog/stable/logging/#verbose-logging), primarily for debugging. Its main feature is customizable verbosity levels: for example, `VLOG(n)` with n = 1, 2, 3, 4, 5, or 6 enables more complex tracing requirements. In contrast, LOG and DLOG typically use predefined levels like INFO and DEBUG.
- In practical Tilelang development, VLOG is used less frequently.
- TVM's VLOG is implemented using DLOG, thus inheriting DLOG's characteristics.
Additional useful macros include various **CHECK** variants:
```c++
CHECK(cond) << "error msg";
DCHECK(cond) << "error msg";
ICHECK(cond) << "error msg";
```
The implementation routes errors to LogFatal:
```c++
#define CHECK(x) \
if (!(x)) \
::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \
<< "Check failed: (" #x << ") is false: "
```
- **DCHECK**: Debug mode CHECK, only compiled in debug builds
- **ICHECK**: Internal Check that should exist in Release builds. When ICHECK fails, the entire system should report an error.
## Logging Verbose Levels
TVM defines 5 levels for LOG and DLOG (adding DEBUG compared to glog):
```c++
#define TVM_LOG_LEVEL_DEBUG 0
#define TVM_LOG_LEVEL_INFO 1
#define TVM_LOG_LEVEL_WARNING 2
#define TVM_LOG_LEVEL_ERROR 3
#define TVM_LOG_LEVEL_FATAL 4
```
## Using Logging in TileLang Development
### Guidelines
For temporary debugging output in your code, there are no restrictions (you can even use std::cout). Just remember to remove it before submitting a PR.
For meaningful logging that should remain in the Tilelang codebase:
- Critical correctness checks: Use ICHECK with sufficient error messages to facilitate debugging when issues arise.
- Complex Pass debugging: For passes requiring intermediate output that may need future review (e.g., LayoutInference), use DLOG.
- General INFO/WARNING messages: Use standard LOG.
### Enabling Log Output in Tilelang
To specify the log level at runtime, set the environment variable `TVM_LOG_DEBUG`. An example usage is:
```bash
TVM_LOG_DEBUG=1 python3 code.py
```
which enables all DEBUG/INFO (level <= 1) logs for all files.
#### Detailed Rules for TVM_LOG_DEBUG Specification
The parsing logic is in `logging.cc`. Reference: [HyperAI Zhihu Article](https://zhuanlan.zhihu.com/p/1933106843468665163).
Launch Python with `TVM_LOG_DEBUG=<spec>`, where `<spec>` is a comma-separated list of level assignments in the form `<file_name>=<level>`. Important notes:
- The special filename DEFAULT sets the LOG level for all files.
- `<level>` can be set to -1 to disable LOG for that file.
- `<file_name>` is the C++ source filename (e.g., .cc, not .h) relative to the `src/` directory in the TVM repository. The `src/` prefix is optional when specifying file paths.
### Enabling Debug Mode
To enable DLOG/DCHECK, developers need to first build Tilelang in Debug mode:
```bash
cmake .. -DCMAKE_BUILD_TYPE=Debug -DUSE_CUDA=ON
```
Tilelang's CMake logic automatically adds the `TVM_LOG_DEBUG` compile definition, so all DLOG statements are compiled in:
```cmake
target_compile_definitions(tilelang_objs PRIVATE "TVM_LOG_DEBUG")
```
Then you also need to set the runtime environment variable. For example, to use `DLOG(INFO) << "xxx"` for debugging, run your code with the INFO level (1): `TVM_LOG_DEBUG=1`.
:::{note}
**Important**: There are two `TVM_LOG_DEBUG` variables.
1. **Compile-time macro**: determines whether debug content (such as DLOG) is compiled into the `.so` file. It is referenced in the C++ source via `#ifdef TVM_LOG_DEBUG` and is enabled automatically by the Debug build mode in CMake.
2. **Runtime environment variable**: controls the logging level at runtime. TVM provides a specification for this variable, allowing per-file control of logging levels.

These two should ideally have different names, but TVM uses the same name for both, which can cause confusion.
:::
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.tileop.base import GemmWarpPolicy
import itertools
import argparse
from functools import partial
import numpy as np
import time
def ref_program(Q, K, V, is_causal, groups=1):
assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
dim = Q.size(-1)
K_ref = K.repeat_interleave(groups, dim=2)
V_ref = V.repeat_interleave(groups, dim=2)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K_ref)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V_ref)
lse = torch.logsumexp(scores, dim=-1).float()
return output, lse
def get_fwd_configs():
block_M = [32, 64, 128, 256]
block_N = [32, 64, 128, 256]
threads = [128, 256, 512]
num_split_q = [64, 128, 256]
num_stages = [0, 1]
enable_rasterization = [True]
k_pack = [2]
panel_size = [7, 8, 9, 10]
qk_coalesced_width = [8]
v_coalesced_width = [4]
valid_configs = []
for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(
block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width
):
valid_configs.append(
{
"block_M": m,
"block_N": n,
"num_split_q": s,
"threads": t,
"num_stages": stages,
"enable_rasterization": r,
"k_pack": k,
"panel_size": p,
"qk_coalesced_width": qkw,
"v_coalesced_width": vw,
}
)
return valid_configs
@tilelang.autotune(configs=get_fwd_configs(), cache_input_tensors=True)
@tilelang.jit(out_idx=[3, 4])
def fast_flashattn(
batch,
heads,
seq_len,
dim,
is_causal,
groups,
block_M: int,
block_N: int,
num_split_q: int,
threads: int,
num_stages: int,
enable_rasterization: bool,
k_pack: int,
panel_size: int,
qk_coalesced_width: int,
v_coalesced_width: int,
):
scale = (1.0 / dim) ** 0.5
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = T.float16
accum_dtype = T.float32
vec_size = qk_coalesced_width
v_vec_size = v_coalesced_width
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
LSE: T.Tensor([batch, heads, seq_len], accum_dtype),
):
with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
T.use_swizzle(panel_size, enable=enable_rasterization)
bz = byz_combined // heads
by = byz_combined % heads
num_q_blocks = T.ceildiv(seq_len, block_M)
bx_loop_var = T.alloc_var(T.int32)
bx_loop_var = b_split
with T.While(bx_loop_var < num_q_blocks):
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
m_i = T.alloc_fragment([block_M], accum_dtype)
l_i = T.alloc_fragment([block_M], accum_dtype)
T.fill(acc_o, 0)
T.fill(m_i, -T.infinity(accum_dtype))
T.fill(l_i, 0)
current_bx = bx_loop_var
q_block_offset = current_bx * block_M
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
m_prev = T.alloc_fragment([block_M], accum_dtype)
scale_factor = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size)
loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
row_sum = T.alloc_fragment([block_M], accum_dtype)
for k in T.Pipelined(loop_end_k, num_stages=num_stages):
kv_idx = k * block_N
T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size)
T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
k_pack=k_pack,
policy=GemmWarpPolicy.FullRow,
)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = acc_s[i, j] * scale
T.copy(m_i, m_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for i in T.Parallel(block_M):
m_i[i] = T.max(m_i[i], m_prev[i])
for i in T.Parallel(block_M):
if m_prev[i] == -T.infinity(accum_dtype):
scale_factor[i] = 0.0
else:
scale_factor[i] = T.exp(m_prev[i] - m_i[i])
l_i[i] *= scale_factor[i]
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scale_factor[i]
for i, j in T.Parallel(block_M, block_N):
if acc_s[i, j] == -T.infinity(acc_s.dtype):
acc_s[i, j] = 0.0
else:
acc_s[i, j] = T.exp(acc_s[i, j] - m_i[i])
T.reduce_sum(acc_s, row_sum, dim=1)
for i in T.Parallel(block_M):
l_i[i] += row_sum[i]
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=GemmWarpPolicy.FullRow)
l_inv = T.alloc_fragment([block_M], accum_dtype)
for i in T.Parallel(block_M):
safe_l = T.if_then_else(l_i[i] > 1e-6, l_i[i], 1.0)
l_inv[i] = 1.0 / safe_l
for i, j in T.Parallel(block_M, dim):
Output[bz, q_block_offset + i, by, j] = acc_o[i, j] * l_inv[i]
for i in T.Parallel(block_M):
if q_block_offset + i < seq_len:
lse_val = T.if_then_else(l_i[i] > 0, T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype))
LSE[bz, by, q_block_offset + i] = lse_val
bx_loop_var = current_bx + num_split_q
return main
def get_bwd_configs():
block_M = [16, 32, 64, 128, 256]
block_N = [16, 32, 64, 128, 256]
threads = [64, 128, 256, 512, 1024]
num_stages = [0, 1, 2]
enable_rasterization = [True]
panel_size = [7, 8, 9, 10]
configs = []
for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, enable_rasterization, panel_size):
configs.append(
{
"block_M": m,
"block_N": n,
"num_stages": stages,
"threads": t,
"enable_rasterization": r,
"panel_size": p,
}
)
return configs
@tilelang.jit(out_idx=[2])
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim]
blk = 32
@T.prim_func
def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), Delta: T.Tensor([batch, heads, seq_len], accum_dtype)):
with T.Kernel(batch, heads, T.ceildiv(seq_len, blk)) as (bz, bx, by):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep
@tilelang.autotune(configs=get_bwd_configs(), cache_input_tensors=True)
@tilelang.jit
def flashattn_bwd(
batch,
heads,
seq_len,
dim,
is_causal,
groups,
block_M: int,
block_N: int,
num_stages: int,
threads: int,
enable_rasterization: bool,
panel_size: int,
):
sm_scale = (1.0 / dim) ** 0.5
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd_kernel(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
dO: T.Tensor(q_shape, dtype),
lse: T.Tensor([batch, heads, seq_len], accum_dtype),
Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
dQ: T.Tensor(q_shape, accum_dtype),
dK: T.Tensor(kv_shape, accum_dtype),
dV: T.Tensor(kv_shape, accum_dtype),
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
T.use_swizzle(panel_size, enable=enable_rasterization)
K_shared = T.alloc_shared([block_M, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
q_shared = T.alloc_shared([block_N, dim], dtype)
do_shared = T.alloc_shared([block_N, dim], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta_shared = T.alloc_shared([block_N], accum_dtype)
ds_shared = T.alloc_shared([block_M, block_N], dtype)
p_cast = T.alloc_fragment([block_M, block_N], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
P_acc = T.alloc_fragment([block_M, block_N], accum_dtype)
dP = T.alloc_fragment([block_M, block_N], accum_dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q_shared)
T.clear(qkT)
T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
P_acc[i, j] = T.exp(qkT[i, j] * sm_scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, P_acc[i, j], 0.0)
T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do_shared)
T.clear(dP)
T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(P_acc, p_cast)
T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta_shared)
for i, j in T.Parallel(block_M, block_N):
p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale
T.gemm(p_cast, q_shared, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(p_cast, ds_shared)
T.clear(dq)
T.gemm(ds_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim):
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
for i, j in T.Parallel(block_M, dim):
T.atomic_add(dV[bz, by * block_M + i, bx // groups, j], dv[i, j])
T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk[i, j])
return flash_bwd_kernel
@tilelang.jit(out_idx=[1])
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim]
blk = 64
@T.prim_func
def flash_bwd_post(dQ_in: T.Tensor(shape, accum_dtype), dQ_out: T.Tensor(shape, dtype)):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.copy(
dQ_in[bz, bx * blk : (bx + 1) * blk, by, :],
dQ_out[bz, bx * blk : (bx + 1) * blk, by, :],
)
return flash_bwd_post
def debug_tensor_comparison(tensor1, tensor2, name, rtol=1e-3, atol=1e-3):
print(f"\n=== {name} Comparison ===")
print(f"Shape: {tensor1.shape} vs {tensor2.shape}")
print(f"Data type: {tensor1.dtype} vs {tensor2.dtype}")
print(f"Device: {tensor1.device} vs {tensor2.device}")
diff = torch.abs(tensor1 - tensor2)
max_diff = diff.max().item()
mean_diff = diff.mean().item()
std_diff = diff.std().item()
print(f"Max difference: {max_diff:.6f}")
print(f"Mean difference: {mean_diff:.6f}")
print(f"Difference std: {std_diff:.6f}")
if max_diff > atol:
max_idx = torch.argmax(diff)
max_idx = np.unravel_index(max_idx.cpu().numpy(), tensor1.shape)
print(f"Max difference position: {max_idx}")
print(f"Value1: {tensor1[max_idx].item():.6f}, Value2: {tensor2[max_idx].item():.6f}")
nan_count1 = torch.isnan(tensor1).sum().item()
nan_count2 = torch.isnan(tensor2).sum().item()
inf_count1 = torch.isinf(tensor1).sum().item()
inf_count2 = torch.isinf(tensor2).sum().item()
print(f"NaN count: {nan_count1} vs {nan_count2}")
print(f"Inf count: {inf_count1} vs {inf_count2}")
relative_diff = diff / (torch.abs(tensor2) + 1e-8)
max_relative_diff = relative_diff.max().item()
mean_relative_diff = relative_diff.mean().item()
print(f"Max relative difference: {max_relative_diff:.6f}")
print(f"Mean relative difference: {mean_relative_diff:.6f}")
close = torch.allclose(tensor1, tensor2, rtol=rtol, atol=atol)
print(f"Within tolerance (rtol={rtol}, atol={atol}): {close}")
return close, max_diff, mean_diff
def benchmark_function(func, *args, warmup=10, repeat=100):
for _ in range(warmup):
func(*args)
if torch.cuda.is_available():
torch.cuda.synchronize()
times = []
for _ in range(repeat):
start = time.time()
func(*args)
if torch.cuda.is_available():
torch.cuda.synchronize()
end = time.time()
times.append((end - start) * 1000)
return np.median(times)
def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1):
device = "cuda"
dtype = torch.float16
torch.manual_seed(42)
torch.cuda.manual_seed(42)
print(f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}")
flops_per_gemm = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 5 * flops_per_gemm
print(f"Total FLOPs: {total_flops / 1e12:.2f} TFlops")
q = torch.randn(batch, seq_len, heads, dim, device=device, dtype=dtype)
k = torch.randn(batch, seq_len, heads // groups, dim, device=device, dtype=dtype)
v = torch.randn(batch, seq_len, heads // groups, dim, device=device, dtype=dtype)
dO = torch.randn_like(q)
print("Starting autotuning for Fast FlashAttention-V2 Forward Pass...")
fwd_kernel = fast_flashattn(batch, heads, seq_len, dim, is_causal, groups)
if fwd_kernel is None or fwd_kernel.config is None:
print("Forward pass auto-tuning failed.")
return
print(f"Autotuning finished. Best Forward Configuration: {fwd_kernel.config}")
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = fwd_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
print("Verifying correctness...")
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("Forward pass is correct.")
o_tl, lse_tl = fwd_kernel(q, k, v)
bwd_prep = flashattn_bwd_preprocess(batch, heads, seq_len, dim)
delta_tl = bwd_prep(o_tl, dO)
print("\nStarting FlashAttention-V2 backward pass autotuning...")
bwd_kernel = flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups)
if bwd_kernel is None or bwd_kernel.config is None:
print("Backward pass autotuning failed.")
return
print(f"Autotuning completed. Best backward pass configuration: {bwd_kernel.config}")
dQ_accum = torch.zeros_like(q, dtype=torch.float32)
dK_tl = torch.zeros_like(k, dtype=torch.float32)
dV_tl = torch.zeros_like(v, dtype=torch.float32)
bwd_kernel(q, k, v, dO, lse_tl, delta_tl, dQ_accum, dK_tl, dV_tl)
post_kernel = flashattn_bwd_postprocess(batch, heads, seq_len, dim)
dQ_tl = post_kernel(dQ_accum)
q_ref = q.clone().detach().requires_grad_()
k_ref = k.clone().detach().requires_grad_()
v_ref = v.clone().detach().requires_grad_()
o_ref, _ = ref_program(q_ref, k_ref, v_ref, is_causal, groups)
o_ref.backward(dO)
print("Verifying backward pass correctness...")
dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05)
if dq_close:
print("dQ is correct.")
else:
print("dQ mismatch detected.")
dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05)
if dk_close:
print("dK is correct.")
else:
print("dK mismatch detected.")
dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05)
if dv_close:
print("dV is correct.")
else:
print("dV mismatch detected.")
print("\n=== Performance Benchmarking ===")
def run_reference_fwd_bwd():
q_ref_bench = q.clone().detach().requires_grad_()
k_ref_bench = k.clone().detach().requires_grad_()
v_ref_bench = v.clone().detach().requires_grad_()
o_ref_bench, _ = ref_program(q_ref_bench, k_ref_bench, v_ref_bench, is_causal, groups)
o_ref_bench.backward(dO)
if torch.cuda.is_available():
torch.cuda.synchronize()
ref_latency = benchmark_function(run_reference_fwd_bwd, warmup=10, repeat=100)
print(f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops")
def run_complete_fwd_bwd():
o_tl_bench, lse_tl_bench = fwd_kernel(q, k, v)
delta_tl_bench = bwd_prep(o_tl_bench, dO)
dQ_bench = torch.zeros_like(q, dtype=torch.float32)
dK_bench = torch.zeros_like(k, dtype=torch.float32)
dV_bench = torch.zeros_like(v, dtype=torch.float32)
bwd_kernel(q, k, v, dO, lse_tl_bench, delta_tl_bench, dQ_bench, dK_bench, dV_bench)
post_kernel(dQ_bench)
if torch.cuda.is_available():
torch.cuda.synchronize()
tile_latency = benchmark_function(run_complete_fwd_bwd, warmup=10, repeat=100)
print(
f"Complete Flash Attention V2 Forward+Backward (Tile-lang): {tile_latency:.2f} ms | {total_flops / tile_latency * 1e-9:.2f} TFlops"
)
speedup = ref_latency / tile_latency
print(f"Speedup: {speedup:.2f}x")
print("Forward output: Passed")
print(f"dQ: {'Passed' if dq_close else 'Failed'} (Max diff: {dq_max_diff:.6f})")
print(f"dK: {'Passed' if dk_close else 'Failed'} (Max diff: {dk_max_diff:.6f})")
print(f"dV: {'Passed' if dv_close else 'Failed'} (Max diff: {dv_max_diff:.6f})")
if all([dq_close, dk_close, dv_close]):
print("All checks passed!")
else:
print("Some checks failed, may need further debugging.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=8, help="heads")
parser.add_argument("--seq_len", type=int, default=1024, help="sequence length")
parser.add_argument("--dim", type=int, default=64, help="dim")
parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument("--groups", type=int, default=1, help="groups")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.tileop.base import GemmWarpPolicy
import itertools
import argparse
from functools import partial
# Custom supply function to ensure tensors are created on GPU
def supply_tensors_gpu(params):
"""Supply function that creates tensors on GPU for ROCm/HIP."""
tensors = []
for param in params:
if hasattr(param, "shape") and hasattr(param, "dtype"):
# Force creation on GPU device
shape = [int(s) for s in param.shape]
tensor = torch.randn(shape, dtype=param.dtype, device="cuda")
tensors.append(tensor)
else:
tensors.append(param)
return tensors
def ref_program(Q, K, V, is_causal, groups=1):
assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
dim = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output
def get_configs():
"""Generates configurations for the autotuner, tailored for FA-2 style parallelism."""
block_M = [32, 64, 128, 256]
block_N = [32, 64, 128, 256]
threads = [128, 256, 512]
num_split_q = [64, 128, 256]
num_stages = [0, 1]
enable_rasterization = [True]
k_pack = [2]
panel_size = [7, 8]
qk_coalesced_width = [8]
v_coalesced_width = [4]
valid_configs = []
for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(
block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width
):
valid_configs.append(
{
"block_M": m,
"block_N": n,
"num_split_q": s,
"threads": t,
"num_stages": stages,
"enable_rasterization": r,
"k_pack": k,
"panel_size": p,
"qk_coalesced_width": qkw,
"v_coalesced_width": vw,
}
)
return valid_configs
@tilelang.autotune(configs=get_configs(), cache_input_tensors=True, supply_prog=supply_tensors_gpu)
@tilelang.jit(out_idx=[3])
def fast_flashattn(
batch,
heads,
seq_len,
dim,
is_causal,
groups,
block_M: int,
block_N: int,
num_split_q: int,
threads: int,
num_stages: int,
enable_rasterization: bool,
k_pack: int,
panel_size: int,
qk_coalesced_width: int,
v_coalesced_width: int,
):
scale = (1.0 / dim) ** 0.5
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = T.float16
accum_dtype = T.float32
vec_size = qk_coalesced_width
v_vec_size = v_coalesced_width
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
T.use_swizzle(panel_size, enable=enable_rasterization)
bz = byz_combined // heads
by = byz_combined % heads
num_q_blocks = T.ceildiv(seq_len, block_M)
bx = T.alloc_var(T.int32)
bx = b_split
with T.While(bx < num_q_blocks):
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
m_i = T.alloc_fragment([block_M], accum_dtype)
l_i = T.alloc_fragment([block_M], accum_dtype)
T.fill(acc_o, 0)
T.fill(m_i, -T.infinity(accum_dtype))
T.fill(l_i, 0)
current_bx = bx
q_block_offset = current_bx * block_M
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
# Use register fragment for P instead of shared memory to reduce LDS usage
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
m_prev = T.alloc_fragment([block_M], accum_dtype)
scale_factor = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size)
loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
row_sum = T.alloc_fragment([block_M], accum_dtype)
for k in T.Pipelined(loop_end_k, num_stages=num_stages):
kv_idx = k * block_N
T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size)
T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
k_pack=k_pack,
policy=GemmWarpPolicy.FullRow,
)
T.copy(m_i, m_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for i in T.Parallel(block_M):
m_i[i] = T.max(m_i[i], m_prev[i])
for i in T.Parallel(block_M):
sf = T.exp(m_prev[i] * scale - m_i[i] * scale)
l_i[i] *= sf
scale_factor[i] = sf
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scale_factor[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp(acc_s[i, j] * scale - m_i[i] * scale)
T.reduce_sum(acc_s, row_sum, dim=1)
for i in T.Parallel(block_M):
l_i[i] += row_sum[i]
# Cast acc_s (accum_dtype) to dtype in registers and directly GEMM with V
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=GemmWarpPolicy.FullRow)
l_inv = T.alloc_fragment([block_M], accum_dtype)
for i in T.Parallel(block_M):
safe_l = T.if_then_else(l_i[i] > 1e-6, l_i[i], 1.0)
l_inv[i] = 1.0 / safe_l
for i, j in T.Parallel(block_M, dim):
Output[bz, q_block_offset + i, by, j] = acc_o[i, j] * l_inv[i]
bx = current_bx + num_split_q
return main
def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
print("Starting autotuning for FlashAttention-V2...")
kernel = fast_flashattn(batch, heads, seq_len, dim, is_causal, groups=groups)
print(f"Autotuning finished. Best Configuration: {kernel.config}")
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
print("Verifying correctness...")
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=100)
print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")
latency = profiler.do_bench(warmup=100)
print(f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=8, help="heads")
parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument("--groups", type=int, default=1, help="groups")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
# TVM IR Performance Analyzer
A performance analysis toolkit for TVM IR modules. It provides hardware-aware performance metrics including FLOPs, memory bandwidth utilization, and estimated execution time.
## Features
- **Operation Analysis**: Supports arbitrary operations expressed in TVM IR (including GEMM and convolution)
- **Memory Traffic Calculation**: Tracks global memory transfers
- **Architecture-aware Metrics**: Pre-configured with NVIDIA GPU architectures (Ampere, Ada Lovelace)
- **Performance Estimation**: Predicts execution time using a roofline model
- **TVM Integration**: Works with TVM IRModule and PrimFunc
## Quick Start
### GEMM Analysis Example
```python
import tilelang.language as T
from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA
M = N = K = 1024
def kernel(block_M=128, block_N=128, block_K=32, num_stages=3, thread_num=128):
@T.prim_func
def main(A: T.Tensor((M, K), T.float16),
B: T.Tensor((N, K), T.float16),
C: T.Tensor((M, N), T.float)):
# ... (kernel definition)
return main
cuda_device = CUDA("cuda")
result = Analyzer.analysis(kernel(), cuda_device)
print(result)
```
### Convolution Analysis Example
```python
import tilelang.language as T
from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA
def kernel(N=64, C=256, H=512, W=512, F=512, K=3, block_M=64, block_N=128):
@T.prim_func
def main(data: T.Tensor((N, H, W, C), T.float16),
kernel: T.Tensor((K, K, C, F), T.float16),
out: T.Tensor((N, (H-K+1), (W-K+1), F), T.float)):
# ... (convolution kernel definition)
return main
cuda_device = CUDA("cuda")
result = Analyzer.analysis(kernel(), cuda_device)
print(result)
```
## API Documentation
### `AnalysisResult` Class
```python
@dataclass(frozen=True)
class AnalysisResult:
total_flops: int # Total floating-point operations
total_global_bytes: int # Global memory traffic in bytes
estimated_time: float # Predicted execution time (seconds)
tflops: float # Achieved TFLOPS
bandwidth_GBps: float # Memory bandwidth utilization
```
### `Analyzer` Class Methods
#### `analysis(fn, device)`
* **Parameters**:
  * `fn`: TVM IRModule or PrimFunc
  * `device`: Device configuration object
* **Returns**: `AnalysisResult`
#### Supported Architectures
```python
# Extendable to custom hardware via: "compute_capability": (cores_per_SM, clock_GHz, flops_per_cycle, max_SM_count)
ARCH_CONFIGS = {
"80": (128, 1.41, 2, 108), # A100
"86": (128, 1.70, 2, 84), # RTX 3080
"89": (128, 2.52, 2, 128) # RTX 4090
}
```
## Implementation Details
### Performance Model
Uses a roofline model with two constraints; a small numeric sketch follows the list:
1. **Compute Bound**: `Time = Total FLOPs / (SM Count × Cores/SM × Clock × FLOPs/Cycle)`
2. **Memory Bound**: `Time = Memory Bytes / (Bandwidth × Utilization)`
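As a concrete illustration, the sketch below plugs the A100 entry of `ARCH_CONFIGS` into these two formulas for the 1024×1024×1024 GEMM from the quick-start example. The bandwidth value and its utilization factor are illustrative assumptions, not numbers taken from the analyzer.
```python
# Roofline sketch for a 1024^3 fp16 GEMM on the "80" (A100) entry of ARCH_CONFIGS.
M = N = K = 1024
total_flops = 2 * M * N * K                        # FLOPs per GEMM
total_bytes = (M * K + K * N + M * N) * 2          # ideal global traffic, 2 bytes per fp16 element

cores_per_sm, clock_ghz, flops_per_cycle, sm_count = 128, 1.41, 2, 108
peak_flops = sm_count * cores_per_sm * clock_ghz * 1e9 * flops_per_cycle  # FLOP/s

bandwidth = 2.0e12 * 0.8                           # assumed ~2 TB/s HBM at 80% utilization (illustrative)

compute_time = total_flops / peak_flops            # compute-bound estimate
memory_time = total_bytes / bandwidth              # memory-bound estimate
estimated_time = max(compute_time, memory_time)    # the binding constraint wins
print(f"Estimated time: {estimated_time * 1e6:.1f} us")
```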
### IR Analysis Pass
1. **Traversal**: Walks through the TVM IR using `ir_transform` (a rough sketch follows this list)
2. **Operation Detection**:
   - Counts FLOPs for all compute operations
   - Calculates memory traffic for all memory operations
3. **Loop Handling**:
   - Tracks nested loops for operation scaling
   - Accounts for block/grid dimensions
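A rough sketch of this traversal (not the analyzer's actual implementation, and ignoring block/grid scaling): walk a `PrimFunc` body with `ir_transform`, track the extents of the enclosing loops, and scale the arithmetic found in each buffer store by the loop product. It assumes constant loop extents.
```python
from tvm import tir
from tvm.tir import stmt_functor

def rough_flops(prim_func: tir.PrimFunc) -> int:
    """Very rough FLOP estimate: counts Add/Mul in stored values, scaled by loop extents."""
    total = 0
    extents = []  # extents of the loops enclosing the node currently being visited

    def preorder(node):
        if isinstance(node, tir.For):
            extents.append(int(node.extent))
        return None  # keep the node and keep recursing

    def postorder(node):
        nonlocal total
        if isinstance(node, tir.For):
            extents.pop()
        elif isinstance(node, tir.BufferStore):
            ops = 0

            def count(expr):
                nonlocal ops
                if isinstance(expr, (tir.Add, tir.Mul)):
                    ops += 1

            stmt_functor.post_order_visit(node.value, count)
            loop_prod = 1
            for extent in extents:
                loop_prod *= extent
            total += ops * loop_prod
        return None  # keep the (unchanged) node

    stmt_functor.ir_transform(prim_func.body, preorder, postorder)
    return total
```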
## Key Metrics Calculation
| Metric | Formula |
|-------------------------|-----------------------------------------|
| FLOPs per GEMM | `2 × M × N × K` |
| Memory Traffic per Copy | `elements × dtype_size × loop_product` |
| Achieved TFLOPS | `total_flops / estimated_time / 1e12` |
| Memory Bandwidth | `total_global_bytes / estimated_time` |
## Limitations
1. Requires memory operations to be properly annotated in the IR
2. Assumes perfect memory coalescing and no bank conflicts
## Supported Operations
Any operation expressed in TVM IR
import tilelang.language as T
from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.layout import make_swizzled_layout
import torch
N = 64
C = 256
H = 512
W = 512
F = 512
K = 3
S = 1
D = 1
P = 1
def check_hopper():
# if not torch.cuda.is_available():
# return None
# props = torch.cuda.get_device_properties(0)
# compute_capability = props.major, props.minor
# return compute_capability == (9, 0)
return False
def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
dtype = T.float16
accum_dtype = T.float32
is_hopper = check_hopper()
@T.prim_func
def conv(
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
out_shared = T.alloc_shared((block_M, block_N), dtype)
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout(
{
out_shared: make_swizzled_layout(out_shared),
data_shared: make_swizzled_layout(data_shared),
kernel_shared: make_swizzled_layout(kernel_shared),
}
)
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
if is_hopper:
T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
else:
for i, j in T.Parallel(block_M, block_K):
k = k_iter * block_K + j
m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P
in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
T.copy(out_local, out_shared)
T.copy(out_shared, out_flat[by * block_M, bx * block_N])
return conv
def main():
my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
result = Analyzer.analysis(my_func, cuda_device)
print(result)
print(f"Analyzed FLOPs: {result.total_flops}")
if __name__ == "__main__":
main()
import tilelang.language as T
from tilelang.tools import Analyzer
from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
import torch
M = N = K = 1024
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
enable_rasteration=None,
):
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def matmul(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return matmul
def main():
my_func = kernel(128, 128, 32, 3, 128, True)
cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
result = Analyzer.analysis(my_func, cuda_device)
print(f"Analyzed FLOPs: {result.total_flops}")
print(f"Expected FLOPs: {2 * M * N * K}")
if __name__ == "__main__":
main()
import tilelang.testing
import example_gemm_analyze
import example_conv_analyze
def test_example_gemm_analyze():
example_gemm_analyze.main()
def test_example_conv_analyze():
example_conv_analyze.main()
if __name__ == "__main__":
tilelang.testing.main()
# Attention Sink
We compare with an optimized version of the official Triton implementation [here](https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py).
## Algorithm
### Forward
The only change from vanilla FlashAttention is that `sinks` should be taken into consideration in the softmax, which requires an extra rescaling at the epilogue stage.
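In code, the epilogue change looks roughly like the sketch below (mirroring the Triton kernel used in the benchmark script): the sink contributes to the softmax denominator but never to the weighted sum of values.
```python
import torch

def finalize_with_sink(acc: torch.Tensor, m_i: torch.Tensor, l_i: torch.Tensor,
                       sink: torch.Tensor) -> torch.Tensor:
    """acc: [M, D] unnormalized output; m_i / l_i: [M] running max / running sum;
    sink: per-head attention-sink logit."""
    denom = l_i + torch.exp(sink - m_i)  # the sink only enlarges the denominator
    return acc / denom[:, None]
```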
### Backward
Based on a detailed mathematical derivation, the backward computation of `dQ`, `dK`, and `dV` is, interestingly, almost identical to that in vanilla FlashAttention, except that the specific meaning of `lse` differs. The only additional quantity to compute is `dsinks`, which is given by:
$$
d\mathrm{sink}_h = -\sum_{b}\sum_{q} P_{b, h, q}\,\Delta_{b, h, q}
$$
where $P_{b, h, q}$ is the proportion assigned to $\mathrm{sink}_h$ by the softmax for the $b$-th batch, $h$-th head, and $q$-th query (row), and $\Delta_{b, h, q}$ is the per-row sum of $O \odot dO$ computed in the backward preprocessing step.
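A PyTorch sketch of this reduction, assuming `p_sink[b, h, q]` holds the sink's softmax proportion for each query and `delta[b, h, q]` is the usual row-wise sum of `O * dO` from the backward preprocessing step:
```python
import torch

def dsinks(p_sink: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
    """p_sink, delta: [batch, heads, seq_q]; returns dsinks of shape [heads]."""
    return -(p_sink * delta).sum(dim=(0, 2))
```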
## Benchmark of forward process
### Benchmark Environment
- **Hardware**: NVIDIA H800
- **CUDA version**: 12.9
- **Triton Version**: 3.4.0
### Results
- dtype=bfloat16
- batch_size=1, heads=64, kv_heads=8 (the setting of GPT-OSS-120B)
- Full attention is adopted.
| SEQ_LEN | headdim | Triton TFLOPs | TileLang TFLOPs | Speedup |
|---------|---------|---------------|----------------------|---------|
| 2048 | 64 | 232.98 | **281.89** | 1.21x |
| 2048 | 128 | 321.55 | **417.98** | 1.30x |
| | | | | |
| 4096 | 64 | 280.70 | **349.47** | 1.25x |
| 4096 | 128 | 369.61 | **497.13** | 1.35x |
| | | | | |
| 8192 | 64 | 299.04 | **385.56** | 1.29x |
| 8192 | 128 | 399.39 | **507.93** | 1.27x |
| | | | | |
| 16384 | 64 | 309.46 | **400.62** | 1.29x |
| 16384 | 128 | 418.99 | **549.11** | 1.31x |
> The backward performance will be further optimized in the future.
import torch
import argparse
from tilelang.profiler import do_bench
from tilelang import language as T
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
from typing import Optional
@triton.jit
def triton_kernel(
Q,
K,
V,
Sinks,
sm_scale,
Out,
Z,
H,
N_Q_CTX,
N_KV_CTX,
HEAD_DIM: tl.constexpr,
groups: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BANDWIDTH: tl.constexpr,
start_q: tl.constexpr,
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
# load attention sinks
if Sinks is not None: # noqa: SIM108
sink = tl.load(Sinks + off_h).to(tl.float32)
else:
sink = 0
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# load scales
qk_scale = sm_scale
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None]
if BANDWIDTH:
too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1)
mask = mask | too_old
k = K.load([off_z, off_h // groups, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T
qk = tl.dot(q, k, allow_tf32=False)
qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
p = tl.math.exp(qk)
alpha = tl.math.exp(m_i - m_ij)
l_ij = tl.sum(p, 1)
acc = acc * alpha[:, None]
v = V.load([off_z, off_h // groups, start_n, 0]).reshape([BLOCK_N, HEAD_DIM])
# v = v.to(tl.float32)
p = p.to(v.dtype) # We perform fp16 gemm to utilize tensor core
acc = tl.dot(p, v, acc, allow_tf32=False)
l_i = l_i * alpha + l_ij
m_i = m_ij
sink = tl.math.exp(sink - m_i)
z = l_i + sink
acc = acc / z[:, None]
# m_i += tl.math.log(l_i)
# m_ptrs = M + off_hz * N_Q_CTX + offs_m
# tl.store(m_ptrs, m_i)
acc = acc.to(Out.dtype)[None, None, :, :]
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
_, n_heads_kv, seq_kv, _ = K.shape
BLOCK_M = 64
BLOCK_N = 64
groups = n_heads // n_heads_kv
o = torch.empty_like(Q)
grid = (triton.cdiv(seq_q, BLOCK_M), bs * n_heads, 1)
triton_kernel[grid](
TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
Sinks,
1.0 / head_dim**0.5,
TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
bs,
n_heads,
N_Q_CTX=seq_q,
N_KV_CTX=seq_kv,
HEAD_DIM=head_dim,
groups=groups,
BANDWIDTH=window_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q,
)
return o
def main(
batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
groups: int = 8,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False,
):
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print("Using sliding window attention.")
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, groups, window_size, dtype=dtype)
print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}")
else:
block_M = 128
block_N = 128
num_stages = 2
threads = 256
print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
groups,
window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
if torch.allclose(
triton_program(Q, K, V, sinks, window_size), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
):
print("Checks for triton passed.✅")
else:
print("Checks for triton failed.❌")
# Benchmark triton
latency_triton = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
print("Triton: {:.2f} ms".format(latency_triton))
print("Triton: {:.2f} TFlops".format(total_flops / latency_triton * 1e-9))
# Benchmark tilelang
latency_tilelang = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency_tilelang))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9))
print("Speedup: {:.2f}x".format(latency_triton / latency_tilelang))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=64, help="heads")
parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query")
parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--groups", type=int, default=8, help="groups")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune)
import torch
import argparse
from tilelang.profiler import do_bench
from tilelang import language as T
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
from typing import Optional
@triton.jit
def triton_kernel(
Q,
K,
V,
Sinks,
sm_scale,
Out,
Z,
H,
N_Q_CTX,
N_KV_CTX,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BANDWIDTH: tl.constexpr,
start_q: tl.constexpr,
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
# load attention sinks
if Sinks is not None: # noqa: SIM108
sink = tl.load(Sinks + off_h).to(tl.float32)
else:
sink = 0
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# load scales
qk_scale = sm_scale
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None]
if BANDWIDTH:
too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1)
mask = mask | too_old
k = K.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T
qk = tl.dot(q, k, allow_tf32=False)
qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
p = tl.math.exp(qk)
alpha = tl.math.exp(m_i - m_ij)
l_ij = tl.sum(p, 1)
acc = acc * alpha[:, None]
v = V.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM])
# v = v.to(tl.float32)
p = p.to(v.dtype) # We perform fp16 gemm to utilize tensor core
acc = tl.dot(p, v, acc, allow_tf32=False)
l_i = l_i * alpha + l_ij
m_i = m_ij
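# The sink contributes exp(sink - m_i) to the softmax normalizer but nothing
# to the weighted sum of values.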
sink = tl.math.exp(sink - m_i)
z = l_i + sink
acc = acc / z[:, None]
# m_i += tl.math.log(l_i)
# m_ptrs = M + off_hz * N_Q_CTX + offs_m
# tl.store(m_ptrs, m_i)
acc = acc.to(Out.dtype)[None, None, :, :]
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
seq_kv = K.shape[2]
BLOCK_M = 64
BLOCK_N = 64
o = torch.empty_like(Q)
grid = (triton.cdiv(seq_q, BLOCK_M), bs * n_heads, 1)
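# Wrap Q/K/V/O in tensor descriptors so the kernel can load/store
# (BLOCK, head_dim) tiles directly (TMA-backed on supported GPUs).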
triton_kernel[grid](
TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
Sinks,
1.0 / head_dim**0.5,
TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
bs,
n_heads,
N_Q_CTX=seq_q,
N_KV_CTX=seq_kv,
HEAD_DIM=head_dim,
BANDWIDTH=window_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q,
)
return o
def main(
batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False,
):
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print("Using sliding window attention.")
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size, dtype=dtype)
print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}")
else:
block_M = 128
block_N = 128
num_stages = 2
threads = 256
print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
)
print("All checks passed.✅")
latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
print("Triton: {:.2f} ms".format(latency))
print("Triton: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
# Adapted from tilelang/examples/flash_attention/example_gqa_bwd.py
import torch
import tilelang
from tilelang.profiler import do_bench
import tilelang.language as T
import argparse
from typing import Optional
def get_bwd_configs():
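# Returns (block_M, block_N, num_stages, threads) for the backward kernel,
# chosen per GPU architecture (sm80 = Ampere, sm90 = Hopper).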
sm_major, sm_minor = torch.cuda.get_device_capability()
sm_version = sm_major * 10 + sm_minor
if sm_version == 80:
return 64, 32, 1, 128
elif sm_version == 90:
return 128, 32, 2, 256
else:
raise ValueError(f"Unsupported SM version: {sm_version}")
@tilelang.jit(
out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_fwd(
batch,
heads,
seq_len,
dim,
groups=1,
window_size=None, # None for full attention
sm_scale=None,
block_M=64,
block_N=64,
num_stages=1,
threads=128,
dtype: T.dtype = T.float16,
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None:
sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, heads, seq_len, dim]
kv_shape = [batch, head_kv, seq_len, dim]
accum_dtype = T.float32
@T.prim_func
def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(kv_shape, dtype), # type: ignore
V: T.Tensor(kv_shape, dtype), # type: ignore
Output: T.Tensor(q_shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Sinks: T.Tensor([heads], dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([heads], dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined(start, end, num_stages=num_stages):
T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# For causal softmax, scores_max must be reset to 0 when it is still -inf
# (i.e. the whole row of scores was masked out). This step is called check_inf
# in the FlashAttention-3 code. NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd
@tilelang.jit(
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16):
accum_dtype = T.float32
shape = [batch, heads, seq_len, dim]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep
def make_dq_layout(dQ):
# atomicAdd cannot be vectorized, so dQ is reordered to match the 8x8 GEMM fragment layout
return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
@tilelang.jit(
out_idx=[1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16):
accum_dtype = T.float32
shape = [batch, heads, seq_len, dim]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(
dQ[bz, by, bx * blk : (bx + 1) * blk, :],
dQ_out[bz, by, bx * blk : (bx + 1) * blk, :],
)
return flash_bwd_post
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale=None, dtype=T.float16): # None for full attention
if sm_scale is None:
sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, heads, seq_len, dim]
kv_shape = [batch, head_kv, seq_len, dim]
accum_dtype = T.float32
block_M, block_N, num_stages, threads = get_bwd_configs()
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(kv_shape, dtype), # type: ignore
V: T.Tensor(kv_shape, dtype), # type: ignore
dO: T.Tensor(q_shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(kv_shape, accum_dtype), # type: ignore
dV: T.Tensor(kv_shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim], dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim], accum_dtype)
T.annotate_layout(
{
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}
)
T.copy(K[bz, bx // groups, by * block_M : (by + 1) * block_M, :], K_shared)
T.copy(V[bz, bx // groups, by * block_M : (by + 1) * block_M, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N)
loop_ed = (
T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N))
if window_size is not None
else T.ceildiv(seq_len, block_N)
)
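# Backward sweep over query blocks for this (head, K/V block). The score tile
# is stored transposed (KV rows x Q columns): recompute P^T from lse,
# accumulate dV += P^T @ dO, form dS^T = P^T * (dP^T - Delta) * sm_scale,
# accumulate dK += dS^T @ Q, and atomically accumulate dQ += dS @ K.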
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
for i, j in T.Parallel(block_M, block_N):
if window_size is not None:
qkT[i, j] = T.if_then_else(
by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0
)
else:
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq)
T.copy(dv, dv_shared)
T.atomic_add(dV[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dv_shared)
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dk_shared)
return flash_bwd
@tilelang.jit(out_idx=-1)
def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: T.dtype = T.float16):
accum_dtype = T.float32
shape = [batch, heads, seq_len]
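# Gradient of the per-head sink logit. Per (batch, head, position) this is
# -exp2(sink * log2(e) - lse) * Delta, i.e. minus the sink's softmax weight
# times Delta; the host-side backward() then sums over batch and sequence to
# obtain one gradient value per head.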
@T.prim_func
def flash_bwd_dsink(
Sinks: T.Tensor([heads], dtype), # type: ignore
Delta: T.Tensor(shape, accum_dtype), # type: ignore
lse: T.Tensor(shape, accum_dtype), # type: ignore
dsinks: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=256) as (bx, by, bz):
sink = T.alloc_local([1], dtype)
lse_fragment = T.alloc_fragment([block], accum_dtype)
delta_fragment = T.alloc_fragment([block], accum_dtype)
dsink_fragment = T.alloc_fragment([block], dtype)
sink[0] = Sinks[bx]
T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment)
T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment)
for i in T.Parallel(block):
dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i]
T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block])
return flash_bwd_dsink
class _attention(torch.autograd.Function):
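# Autograd wrapper: forward runs the TileLang forward kernel and saves
# (q, k, v, sinks, o, lse); backward runs the preprocess kernel
# (Delta = rowsum(dO * O)), the fused backward kernel (dq/dk/dv), the dQ
# postprocess (undo the atomic-add layout and cast back to fp16/bf16), and the
# sink-gradient kernel.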
@staticmethod
def forward(ctx, q, k, v, sinks, window_size, groups):
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
q, k, v, sinks = [maybe_contiguous(x) for x in (q, k, v, sinks)]
BATCH, H, N_CTX, D_HEAD = q.shape
dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16
kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype)
o, lse = kernel(q, k, v, sinks)
ctx.save_for_backward(q, k, v, sinks, o, lse)
ctx.window_size = window_size
ctx.groups = groups
return o
@staticmethod
def backward(ctx, do):
q, k, v, sinks, o, lse = ctx.saved_tensors
BATCH, H, N_CTX, D_HEAD = q.shape
groups = ctx.groups
dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
delta = kernel_prep(o, do)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, groups, ctx.window_size, dtype=dtype)
q_shape = [BATCH, H, N_CTX, D_HEAD]
head_kv = H // groups
kv_shape = [BATCH, head_kv, N_CTX, D_HEAD]
dq = torch.zeros(q_shape, dtype=torch.float32, device=q.device) # acc for atomicAdd
dk = torch.zeros(kv_shape, dtype=torch.float32, device=q.device)
dv = torch.zeros(kv_shape, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = kernel_post(dq)
kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
return dq, dk, dv, dsinks, None, None
attention = _attention.apply
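# Example usage (a minimal sketch mirroring main() below; GQA shapes with
# `groups` query heads per KV head):
#   q = torch.randn(B, H, S, D, dtype=torch.float16, device="cuda", requires_grad=True)
#   k = torch.randn(B, H // groups, S, D, dtype=torch.float16, device="cuda", requires_grad=True)
#   v = torch.randn_like(k).requires_grad_()
#   sinks = torch.randn(H, dtype=torch.float16, device="cuda", requires_grad=True)
#   out = attention(q, k, v, sinks, window_size, groups)  # window_size=None => full attention
#   out.backward(torch.randn_like(out))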
# Adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
batch_size, num_keys, num_key_value_heads, head_dim = key.shape
query = query.transpose(1, 2).contiguous()
query = query.view(batch_size, query.shape[1], num_key_value_heads, -1, head_dim)
batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
start_q = num_keys - num_queries
sm_scale: float = 1.0 / head_dim**0.5
sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float()
key = key.unsqueeze(3)
value = value.unsqueeze(3)
pos_keys = torch.arange(num_keys, device=query.device)
pos_queries = torch.arange(num_queries, device=query.device) + start_q
mask = pos_keys[None, :] > pos_queries[:, None]
mask = mask.float().masked_fill(mask, float("-inf"))
if sliding_window:
too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
mask.masked_fill_(too_old, float("-inf"))
logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
logits = logits + mask[None, None, None, :, :]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(sinks, logits_max)
sinks = torch.exp(sinks - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
scores = unnormalized_scores / normalizer
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
return output.transpose(1, 2).contiguous()
def main(
BATCH: int = 1,
H: int = 8,
N_CTX: int = 512,
D_HEAD: int = 64,
groups: int = 2,
window_size: Optional[int] = None,
dtype: str = "float16",
):
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print("Using sliding window attention.")
assert window_size <= N_CTX
flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation
else:
print("Using full attention.")
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 5 * flops_per_matmul
Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()
K = torch.randn(BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()
V = torch.randn_like(K).requires_grad_()
sinks = torch.randn(H, dtype=torch_dtype, device="cuda").requires_grad_()
dO = torch.randn_like(Q)
O = attention(Q, K, V, sinks, window_size, groups)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
dsinks, sinks.grad = sinks.grad.clone(), None
O_ref = ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
dsinks_ref, sinks.grad = sinks.grad.clone(), None
# Checks
rtol, atol = {
T.float16: (1e-2, 1e-2),
T.bfloat16: (2e-2, 2e-2),
}[dtype]
assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}"
assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}"
assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}"
assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}"
assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}"
print("All checks passed for tilelang kernels.✅")
# Only benchmark backward here
def torch_bwd():
O_ref.backward(dO, retain_graph=True)
def tl_bwd():
O.backward(dO, retain_graph=True)
latency = do_bench(torch_bwd, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(tl_bwd, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=1, help="Batch size")
parser.add_argument("--h", type=int, default=64, help="Number of heads")
parser.add_argument("--n_ctx", type=int, default=4096, help="Context size")
parser.add_argument("--d_head", type=int, default=128, help="Head dimension")
parser.add_argument("--groups", type=int, default=8, help="Groups")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.groups, args.window_size, args.dtype)
# Modified from tilelang/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py
# Optimized for Hopper architecture, with a benchmark to compare with official Triton impl
import torch
import tilelang
from tilelang.autotuner import autotune
from tilelang.profiler import do_bench
import tilelang.language as T
from tilelang.layout import make_swizzled_layout
import itertools
import argparse
from typing import Optional
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(
configs=get_configs(),
warmup=500,
rep=100,
)
@tilelang.jit(
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
groups=1,
window_size=None, # None for full attention
sm_scale=None,
block_M=128,
block_N=128,
num_stages=2,
threads=256,
dtype: T.dtype = T.float16,
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None:
sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, head_kv, seq_kv, dim]
accum_dtype = T.float32
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
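# The kernel body below is split into T.macro blocks so the pipelined loop can
# schedule them explicitly: MMA0 = QK^T with the causal / sliding-window mask,
# Softmax = online-softmax update, Rescale = rescale the output accumulator,
# MMA1 = P @ V accumulation.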
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# For causal softmax, scores_max must be reset to 0 when it is still -inf
# (i.e. the whole row of scores was masked out). This step is called check_inf
# in the FlashAttention-3 code. NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) - max * log_2(e)).
# This lets the compiler emit a fused ffma instruction instead of separate
# fadd and fmul instructions.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
T.annotate_layout(
{
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
}
)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0
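# Manually annotated software pipeline (a hand-written schedule for the Hopper
# path, as specified via the order/stage/group arguments of T.Pipelined):
# `group` partitions the loop-body statements, and `order`/`stage` assign each
# group an issue order and pipeline stage so the K/V copies overlap with the
# GEMMs and the softmax.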
for k in T.Pipelined(
start,
end,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
# Following functions are adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
batch_size, num_keys, num_key_value_heads, head_dim = key.shape
query = query.transpose(1, 2).contiguous()
query = query.view(batch_size, query.shape[1], num_key_value_heads, -1, head_dim)
batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
start_q = num_keys - num_queries
sm_scale: float = 1.0 / head_dim**0.5
sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float()
key = key.unsqueeze(3)
value = value.unsqueeze(3)
pos_keys = torch.arange(num_keys, device=query.device)
pos_queries = torch.arange(num_queries, device=query.device) + start_q
mask = pos_keys[None, :] > pos_queries[:, None]
mask = mask.float().masked_fill(mask, float("-inf"))
if sliding_window:
too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
mask.masked_fill_(too_old, float("-inf"))
logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
logits = logits + mask[None, None, None, :, :]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(sinks, logits_max)
sinks = torch.exp(sinks - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
scores = unnormalized_scores / normalizer
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
return output.transpose(1, 2).contiguous()
def gen_inputs(B, H, Sq, Skv, D, groups, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda")
key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda")
value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda")
sinks = torch.randn([H], dtype=dtype, device="cuda")
return query, key, value, sinks
def main(
batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
groups: int = 8,
window_size: Optional[int] = None,
dtype: T.dtype = T.float16,
tune: bool = False,
):
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print("Using sliding window attention.")
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, groups, window_size, dtype=dtype)
print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}")
else:
block_M = 128
block_N = 128
num_stages = 2
threads = 256
print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
groups,
window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
)
print("All checks passed.✅")
# Benchmark tilelang
latency_tilelang = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency_tilelang))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=64, help="heads")
parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query")
parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--groups", type=int, default=8, help="groups")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune)
# Adapted from tilelang/examples/flash_attention/example_mha_bwd_bhsd.py
import torch
import tilelang
from tilelang.profiler import do_bench
import tilelang.language as T
import argparse
from typing import Optional
def get_bwd_configs():
sm_major, sm_minor = torch.cuda.get_device_capability()
sm_version = sm_major * 10 + sm_minor
if sm_version == 80:
return 64, 32, 1, 128
elif sm_version == 90:
return 128, 32, 2, 256
else:
raise ValueError(f"Unsupported SM version: {sm_version}")
@tilelang.jit(
out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_fwd(
batch,
heads,
seq_len,
dim,
window_size=None, # None for full attention,
sm_scale=None,
block_M=64,
block_N=64,
num_stages=1,
threads=128,
dtype: T.dtype = T.float16,
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None:
sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim]
accum_dtype = T.float32
@T.prim_func
def flash_fwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
Output: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Sinks: T.Tensor([heads], dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([heads], dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined(start, end, num_stages=num_stages):
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# For causal softmax, scores_max must be reset to 0 when it is still -inf
# (i.e. the whole row of scores was masked out). This step is called check_inf
# in the FlashAttention-3 code. NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd
@tilelang.jit(
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16):
accum_dtype = T.float32
shape = [batch, heads, seq_len, dim]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep
def make_dq_layout(dQ):
# atomicAdd cannot be vectorized, so dQ is reordered to match the 8x8 GEMM fragment layout
return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
@tilelang.jit(
out_idx=[1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16):
accum_dtype = T.float32
shape = [batch, heads, seq_len, dim]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(
dQ[bz, by, bx * blk : (bx + 1) * blk, :],
dQ_out[bz, by, bx * blk : (bx + 1) * blk, :],
)
return flash_bwd_post
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd(
batch,
heads,
seq_len,
dim,
window_size=None, # None for full attention
sm_scale=None,
dtype: T.dtype = T.float16,
):
block_M, block_N, num_stages, threads = get_bwd_configs()
if sm_scale is None:
sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim]
accum_dtype = T.float32
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
@T.prim_func
def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
# Do not stage K/V in register fragments when dim is large; keep them in shared memory instead:
# K_local = T.alloc_fragment([block_M, dim], dtype)
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
# V_local = T.alloc_fragment([block_M, dim], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim], dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim], dtype)
dk_shared = T.alloc_shared([block_M, dim], dtype)
T.annotate_layout(
{
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}
)
T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared)
T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N)
loop_ed = (
T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N))
if window_size is not None
else T.ceildiv(seq_len, block_N)
)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
for i, j in T.Parallel(block_M, block_N):
if window_size is not None:
qkT[i, j] = T.if_then_else(
by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0
)
else:
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, B=do, C=dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq)
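# Unlike the GQA variant, each (head, K/V block) here is owned by exactly one
# thread block, so dK and dV can be written back with plain copies instead of
# atomic adds (only dQ, which is shared across K/V blocks, needs T.atomic_add).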
T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * block_M, :])
T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :])
return flash_bwd
@tilelang.jit(out_idx=-1)
def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: T.dtype = T.float16):
accum_dtype = T.float32
shape = [batch, heads, seq_len]
@T.prim_func
def flash_bwd_dsink(
Sinks: T.Tensor([heads], dtype), # type: ignore
Delta: T.Tensor(shape, accum_dtype), # type: ignore
lse: T.Tensor(shape, accum_dtype), # type: ignore
dsinks: T.Tensor(shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=128) as (bx, by, bz):
sink = T.alloc_local([1], dtype)
lse_fragment = T.alloc_fragment([block], accum_dtype)
delta_fragment = T.alloc_fragment([block], accum_dtype)
dsink_fragment = T.alloc_fragment([block], accum_dtype)
sink[0] = Sinks[bx]
T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment)
T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment)
for i in T.Parallel(block):
dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i]
T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block])
return flash_bwd_dsink
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sinks, window_size):
BATCH, H, N_CTX, D_HEAD = q.shape
dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16
kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size, dtype=dtype)
o, lse = kernel(q, k, v, sinks)
ctx.save_for_backward(q, k, v, sinks, o, lse)
ctx.window_size = window_size
return o
@staticmethod
def backward(ctx, do):
q, k, v, sinks, o, lse = ctx.saved_tensors
BATCH, H, N_CTX, D_HEAD = q.shape
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, sinks, o = [maybe_contiguous(x) for x in (do, q, k, v, sinks, o)]
dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
delta = kernel_prep(o, do)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.window_size, dtype=dtype)
shape = [BATCH, H, N_CTX, D_HEAD]
dq = torch.zeros(shape, dtype=torch.float32, device=q.device) # acc for atomicAdd
dk = torch.empty(shape, dtype=q.dtype, device=q.device)
dv = torch.empty(shape, dtype=q.dtype, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = kernel_post(dq)
kernel_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype)
dsinks = kernel_dsink(sinks, delta, lse).sum(0).sum(1)
return dq, dk, dv, dsinks, None
attention = _attention.apply
# Adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
batch_size, num_keys, num_key_value_heads, head_dim = key.shape
start_q = num_keys - num_queries
sm_scale: float = 1.0 / head_dim**0.5
sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1)
key = key.unsqueeze(3)
value = value.unsqueeze(3)
pos_keys = torch.arange(num_keys, device=query.device)
pos_queries = torch.arange(num_queries, device=query.device) + start_q
mask = pos_keys[None, :] > pos_queries[:, None]
mask = mask.float().masked_fill(mask, float("-inf"))
if sliding_window:
too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
mask.masked_fill_(too_old, float("-inf"))
logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
logits = logits + mask[None, None, None, :, :]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(sinks, logits_max)
sinks = torch.exp(sinks - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
scores = unnormalized_scores / normalizer
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
return output.transpose(1, 2).contiguous()
def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window_size: Optional[int] = None, dtype: T.dtype = T.float16):
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print("Using sliding window attention.")
assert window_size <= N_CTX
flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation
else:
print("Using full attention.")
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 5 * flops_per_matmul
Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()
K = torch.randn_like(Q).requires_grad_()
V = torch.randn_like(Q).requires_grad_()
sinks = torch.randn(H, dtype=torch_dtype, device=Q.device).requires_grad_()
dO = torch.randn_like(Q)
O = attention(Q, K, V, sinks, window_size)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
dsinks, sinks.grad = sinks.grad.clone(), None
O_ref = ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
dsinks_ref, sinks.grad = sinks.grad.clone(), None
# Checks
rtol, atol = {
T.float16: (1e-2, 1e-2),
T.bfloat16: (2e-2, 2e-2),
}[dtype]
assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}"
assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}"
assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}"
assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}"
assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}"
print("All checks passed for tilelang kernels.✅")
# Only benchmark backward here
def torch_bwd():
O_ref.backward(dO, retain_graph=True)
def tl_bwd():
O.backward(dO, retain_graph=True)
latency = do_bench(torch_bwd, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(tl_bwd, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=1, help="Batch size")
parser.add_argument("--h", type=int, default=64, help="Number of heads")
parser.add_argument("--n_ctx", type=int, default=4096, help="Context size")
parser.add_argument("--d_head", type=int, default=128, help="Head dimension")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.window_size, args.dtype)
# Modified from tilelang/examples/flash_attention/example_mha_fwd_bhsd.py
import torch
import tilelang
from tilelang.autotuner import autotune
from tilelang.profiler import do_bench
import tilelang.language as T
from tilelang.layout import make_swizzled_layout
import itertools
import argparse
from typing import Optional
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=500, rep=100)
@tilelang.jit(
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size=None, # None for full attention
sm_scale=None,
block_M=64,
block_N=64,
num_stages=1,
threads=128,
dtype: T.dtype = T.float16,
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None:
sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
accum_dtype = T.float32
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# For causal softmax, scores_max must be reset to 0 when it is still -inf
# (i.e. the whole row of scores was masked out). This step is called check_inf
# in the FlashAttention-3 code. NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) - max * log_2(e)).
# This lets the compiler emit a fused ffma instruction instead of separate
# fadd and fmul instructions.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
T.annotate_layout(
{
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
}
)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0
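# `end` stops at the causal boundary (no KV block beyond the last query row of this tile);
# with a sliding window, `start` additionally skips KV blocks that lie entirely before the
# window of the first query row in the tile.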
for k in T.Pipelined(start, end, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i in T.Parallel(block_M):
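# Attention sink: each row contributes exp(sink - row_max * sm_scale) to the softmax
# denominator, computed in base 2 (1.44269504 ≈ log2(e)); this plays the role of the
# `+ sinks` term in ref_program's normalizer.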
logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
# Modified from https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
batch_size, num_keys, num_key_value_heads, head_dim = key.shape
start_q = num_keys - num_queries
sm_scale: float = 1.0 / head_dim**0.5
sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float()
key = key.unsqueeze(3)
value = value.unsqueeze(3)
pos_keys = torch.arange(num_keys, device=query.device)
pos_queries = torch.arange(num_queries, device=query.device) + start_q
mask = pos_keys[None, :] > pos_queries[:, None]
mask = mask.float().masked_fill(mask, float("-inf"))
if sliding_window:
too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
mask.masked_fill_(too_old, float("-inf"))
logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
logits = logits + mask[None, None, None, :, :]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(sinks, logits_max)
sinks = torch.exp(sinks - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
scores = unnormalized_scores / normalizer
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
return output.transpose(1, 2).contiguous()
def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda")
key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
sinks = torch.randn([H], dtype=dtype, device="cuda")
return query, key, value, sinks
def main(
batch: int = 1,
heads: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: T.dtype = T.float16,
tune: bool = False,
):
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print("Using sliding window attention.")
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size, dtype=dtype)
print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}")
else:
block_M = 128
block_N = 128
num_stages = 2
threads = 256
print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
)
print("All checks passed.✅")
latency = do_bench(lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default=T.float16, help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
# Modified from tilelang/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
# Optimized for Hopper architecture, with a benchmark to compare with official Triton impl
import torch
import tilelang
from tilelang.autotuner import autotune
from tilelang.profiler import do_bench
import tilelang.language as T
from tilelang.layout import make_swizzled_layout
import itertools
import argparse
from typing import Optional
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=500, rep=100)
@tilelang.jit(
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size=None, # None for full attention
sm_scale=None,
block_M=128,
block_N=128,
num_stages=2,
threads=256,
dtype: T.dtype = T.float16,
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None:
sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
accum_dtype = T.float32
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
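# Pre-fill acc_s with the additive mask: 0 where query q_idx may attend to key k_idx
# (causal, optionally restricted to a sliding window), -inf elsewhere. The GEMM below
# then accumulates Q @ K^T on top of this mask.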
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# For the causal softmax, scores_max must be reset to 0 when it is -inf (i.e. every score in the row is masked out).
# This step is called Check_inf in the FlashAttention3 code.
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
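# scores_scale = exp2(scale * (m_prev - m_new)) is the standard online-softmax correction
# factor; it rescales the running output (see Rescale) and the running logsum.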
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
T.annotate_layout(
{
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
}
)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0
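# Same causal/sliding-window block range as in the previous example. The explicit
# order/stage/group annotations on T.Pipelined below specify a manual software-pipeline
# schedule, grouping the statements of the loop body and assigning each group to a stage.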
for k in T.Pipelined(
start,
end,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
# The following functions are adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(3)  # align with the original function's interface
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
batch_size, num_keys, num_key_value_heads, head_dim = key.shape
start_q = num_keys - num_queries
sm_scale: float = 1.0 / head_dim**0.5
sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float()
key = key.unsqueeze(3)
value = value.unsqueeze(3)
pos_keys = torch.arange(num_keys, device=query.device)
pos_queries = torch.arange(num_queries, device=query.device) + start_q
mask = pos_keys[None, :] > pos_queries[:, None]
mask = mask.float().masked_fill(mask, float("-inf"))
if sliding_window:
too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
mask.masked_fill_(too_old, float("-inf"))
logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
logits = logits + mask[None, None, None, :, :]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(sinks, logits_max)
sinks = torch.exp(sinks - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
scores = unnormalized_scores / normalizer
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
return output.transpose(1, 2).contiguous()
def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda")
key = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
value = torch.randn([B, H, Skv, D], dtype=dtype, device="cuda")
sinks = torch.randn([H], dtype=dtype, device="cuda")
return query, key, value, sinks
def main(
batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: T.dtype = T.float16,
tune: bool = False,
):
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print("Using sliding window attention.")
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size, dtype=dtype)
print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}")
else:
block_M = 128
block_N = 128
num_stages = 2
threads = 256
print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
)
print("All checks passed.✅")
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default=T.float16, help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
import tilelang.testing
import example_mha_sink_fwd_bhsd
import example_mha_sink_fwd_bhsd_wgmma_pipelined
import example_gqa_sink_fwd_bhsd_wgmma_pipelined
import example_mha_sink_bwd_bhsd
import example_gqa_sink_bwd_bhsd
@tilelang.testing.requires_cuda
def test_example_mha_sink_fwd_bhsd_full_attn():
example_mha_sink_fwd_bhsd.main()
@tilelang.testing.requires_cuda
def test_example_mha_sink_fwd_bhsd_sliding_window():
example_mha_sink_fwd_bhsd.main(window_size=128)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_sink_fwd_bhsd_wgmma_pipelined_full_attn():
example_mha_sink_fwd_bhsd_wgmma_pipelined.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window():
example_mha_sink_fwd_bhsd_wgmma_pipelined.main(window_size=128)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_gqa_sink_fwd_bhsd_wgmma_pipelined_full_attn():
example_gqa_sink_fwd_bhsd_wgmma_pipelined.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window():
example_gqa_sink_fwd_bhsd_wgmma_pipelined.main(window_size=128)
@tilelang.testing.requires_cuda
def test_example_mha_sink_bwd_bhsd():
example_mha_sink_bwd_bhsd.main()
@tilelang.testing.requires_cuda
def test_example_mha_sink_bwd_bhsd_sliding_window():
example_mha_sink_bwd_bhsd.main(window_size=128)
@tilelang.testing.requires_cuda
def test_example_gqa_sink_bwd_bhsd():
example_gqa_sink_bwd_bhsd.main()
@tilelang.testing.requires_cuda
def test_example_gqa_sink_bwd_bhsd_sliding_window():
example_gqa_sink_bwd_bhsd.main(window_size=128)
if __name__ == "__main__":
tilelang.testing.main()
models/
---
license: mit
---
This is a Tilelang implementation of the reproduced 1.58-bit model from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit quantized inference kernel with an INT8xINT2 kernel. The model's correctness and performance can be evaluated with `eval_correctness.py` and `benchmark_inference_latency.py`, as sketched below.
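For reference, the two scripts can be invoked directly; the bare invocations below are only a sketch, and any flags are defined inside the scripts themselves:

```bash
# check numerical correctness of the INT8xINT2 kernel against the reference model
python eval_correctness.py

# measure end-to-end inference latency
python benchmark_inference_latency.py
```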
## Make Checkpoints for vLLM
We provide two scripts to make the checkpoints for vLLM. The first script is `generate_bitnet_model_native_format.sh`, which creates a checkpoint with fp16 uncompressed metadata. The main difference from the original checkpoint is the `quant_config.json`, which allows vLLM to load the model and execute it with a quant extension.
```bash
# move to the integration directory
cd /root/to/BitBLAS/integration/BitNet
# make the checkpoint
./maint/generate_bitnet_model_native_format.sh
# the output checkpoint will be saved in the `./models/ckpt_bitnet_b1_58-3B` directory
```
The second script is `generate_bitnet_model_bitblas_format.sh`, which creates a checkpoint with BitBLAS compressed metadata. This avoids the online dequantization stage during vLLM profiling and leads to more efficient memory utilization.
```bash
./maint/generate_bitnet_model_bitblas_format.sh ./models/ckpt_bitnet_b1_58-3B ./models/ckpt_bitnet_b1_58-3B_bitblas
# the output checkpoint will be saved in the `./models/ckpt_bitnet_b1_58-3B_bitblas` directory
```
Finally, you can use the checkpoint in vLLM with:
```bash
cd vllm_workspace
# inference with the ckpt with fp16 uncompressed metadata
python3 inference_with_native_format.py
# inference with the ckpt with BitBLAS compressed metadata
python3 inference_with_bitblas_format.py
```
**Benchmark results of vLLM**
| Model | Framework | BS16IN32OUT128 | BS1IN512OUT1024 | BS32IN32OUT128 |
|------------------------|--------------------------|----------------|-----------------|----------------|
| bitnet-3b-1.58bits | pytorch | 106.83 | 49.34 | 209.03 |
| bitnet-3b-1.58bits | pytorch-tilelang | 240.33 | 103.09 | 493.31 |
| bitnet-3b-1.58bits | vllm-tilelang | 379.25 | 117.43 | 752.55 |
| bitnet-3b-1.58bits | vllm-tilelang-cuda-graph | 2543.58 | 1621.08 | 2731.79 |
## BitBLAS Results
### Performance
**Note:** To reproduce the BitBLAS results, please check out `benchmark_inference_latency.py`. To reproduce the results of the original model, please check out the [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B) repo.
| Model | Device | batchsize | in_seq | model | bitnet-1.58b-3b-huggingface | bitnet-1.58b-3b-bitblas |
|:---------------:|:------:|:---------:|:------:|:--------:|:---------------------------:|:-----------------------:|
| bitnet_b1_58-3B | A100 | 1 | 1 | LLAMA-3B | 177.6729107 | 64.17962909 |
| bitnet_b1_58-3B | A100 | 128 | 1 | LLAMA-3B | 188.6145592 | 63.48158518 |
| bitnet_b1_58-3B | A100 | 1 | 2048 | LLAMA-3B | 348.7066031 | 202.6877999 |
### On-the-Fly GPU Memory Footprint
We measured the GPU memory footprint with the `nvidia-smi` command. Check out `nvidia_measure_memory.sh` to record real-time GPU memory usage, then start a `benchmark_model_10k_loops.py` workload to measure the overall GPU memory usage.
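A minimal sketch of such a measurement loop, assuming a single GPU and that `benchmark_model_10k_loops.py` runs without extra arguments (`nvidia_measure_memory.sh` in this repo is the authoritative version):

```bash
# launch the benchmark workload in the background
python benchmark_model_10k_loops.py &
BENCH_PID=$!
# sample used GPU memory once per second until the benchmark finishes
while kill -0 "$BENCH_PID" 2>/dev/null; do
    nvidia-smi --query-gpu=memory.used --format=csv,noheader
    sleep 1
done
```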
| **Model** | **Device** | **batchsize** | **in_seq** | **bitnet-1.58b-3b-huggingface** | **bitnet-1.58b-3b-bitblas** |
|:---------------:|:----------:|:-------------:|:----------:|:-------------------------------:|:---------------------------:|
| bitnet_b1_58-3B | A100 | 1 | 1 | 7595 MB | 1729 MB |
| bitnet_b1_58-3B | A100 | 128 | 1 | 7677 MB | 1789 MB |
| bitnet_b1_58-3B | A100 | 1 | 2048 | 8731 MB | 3163 MB |
## PPL and Zero-shot Accuracy
The numbers are reported from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B); please check out `eval_ppl.py`.
PPL and zero-shot accuracy:
| Models | PPL | ARCe | ARCc | HS | BQ | OQ | PQ | WGe | Avg |
|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|
| FP16 700M (reported) | 12.33 | 54.7 | 23.0 | 37.0 | 60.0 | 20.2 | 68.9 | 54.8 | 45.5 |
| BitNet b1.58 700M (reported) | 12.87 | 51.8 | 21.4 | 35.1 | 58.2 | 20.0 | 68.1 | 55.2 | 44.3 |
| BitNet b1.58 700M (reproduced) | 12.78 | 51.4 | 21.8 | 35.0 | 59.6 | 20.6 | 67.5 | 55.4 | 44.5 |
| FP16 1.3B (reported) | 11.25 | 56.9 | 23.5 | 38.5 | 59.1 | 21.6 | 70.0 | 53.9 | 46.2 |
| BitNet b1.58 1.3B (reported) | 11.29 | 54.9 | 24.2 | 37.7 | 56.7 | 19.6 | 68.8 | 55.8 | 45.4 |
| BitNet b1.58 1.3B (reproduced) | 11.19 | 55.8 | 23.7 | 37.6 | 59.0 | 20.2 | 69.2 | 56.0 | 45.9 |
| FP16 3B (reported) | 10.04 | 62.1 | 25.6 | 43.3 | 61.8 | 24.6 | 72.1 | 58.2 | 49.7 |
| BitNet b1.58 3B (reported) | 9.91 | 61.4 | 28.3 | 42.9 | 61.5 | 26.6 | 71.5 | 59.3 | 50.2 |
| BitNet b1.58 3B (reproduced) | 9.88 | 60.9 | 28.0 | 42.3 | 58.3 | 26.0 | 71.4 | 60.3 | 49.6 |
The differences between the reported numbers and the reproduced results are possibly variances from the training data processing, seeds, or other random factors.
## Citations
```bibtex
@article{ma2024era,
title={The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits},
author={Ma, Shuming and Wang, Hongyu and Ma, Lingxiao and Wang, Lei and Wang, Wenhui and Huang, Shaohan and Dong, Li and Wang, Ruiping and Xue, Jilong and Wei, Furu},
journal={arXiv preprint arXiv:2402.17764},
year={2024}
}
```
python benchmark_generate.py --bs 16 --in_seq_len 32 --out_seq_len 128 | tee b16_i32_o128.log
python benchmark_generate.py --bs 1 --in_seq_len 512 --out_seq_len 64 | tee b1_i512_o64.log
python benchmark_generate.py --bs 32 --in_seq_len 32 --out_seq_len 128 | tee b32_i32_o128.log
python benchmark_generate.py --bs 16 --in_seq_len 32 --out_seq_len 128 --bitblas | tee b16_i32_o128_bitblas.log
python benchmark_generate.py --bs 1 --in_seq_len 512 --out_seq_len 64 --bitblas | tee b1_i512_o64_bitblas.log
python benchmark_generate.py --bs 32 --in_seq_len 32 --out_seq_len 128 --bitblas | tee b32_i32_o128_bitblas.log