Unverified Commit bbbf4207 authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
......@@ -40,9 +40,9 @@ def get_heuristic_config() -> Tuple[Dict, int]:
sm_version = sm_major * 10 + sm_minor
print(f"CUDA device capability: {sm_version}")
if sm_version == 89:
cfg = dict(block_N=128, block_H=64, num_split=16, num_stages=0, threads=128)
cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=0, threads=128)
else:
cfg = dict(block_N=128, block_H=64, num_split=16, num_stages=2, threads=128)
cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=2, threads=128)
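# NOTE: num_split must match the split dimension of the glse / Output_partial
# buffers allocated by the caller (the benchmark below sizes them via config["num_split"]).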
return cfg, sm_version
......@@ -459,8 +459,9 @@ def main(batch: int = 1,
k = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
mask = torch.randint(0, 2, (batch, kv_seqlen, groups), device="cuda", dtype=torch.uint8)
glse = torch.empty(batch, heads, 16, device="cuda", dtype=torch.float16)
Output_partial = torch.empty(batch, heads, 16, dim, device="cuda", dtype=torch.float16)
split = config["num_split"]
glse = torch.empty(batch, heads, split, device="cuda", dtype=torch.float16)
Output_partial = torch.empty(batch, heads, split, dim, device="cuda", dtype=torch.float16)
o = kernel(q, k, v, mask, glse, Output_partial)
o_ref = ref_program(q, k, v, mask, glse, Output_partial)
o_ref_split = ref_split_program(q, k, v, mask, glse, Output_partial)
......
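For context, `Output_partial` and `glse` hold the per-split partial outputs and their log-sum-exp values, which is why their shapes now track `num_split`. A rough PyTorch sketch of the standard split-KV combine such buffers typically feed (illustrative only, not the kernel's own combine; it assumes each split's output is already normalized and `glse[b, h, s]` holds that split's log-sum-exp):

```python
import torch

def combine_splits(glse: torch.Tensor, output_partial: torch.Tensor) -> torch.Tensor:
    # glse: [batch, heads, num_split]; output_partial: [batch, heads, num_split, dim]
    lse_max = glse.max(dim=-1, keepdim=True).values      # [B, H, 1], for numerical stability
    weights = torch.exp(glse - lse_max)                  # [B, H, S]
    combined = (output_partial * weights.unsqueeze(-1)).sum(dim=2)
    return combined / weights.sum(dim=-1, keepdim=True)  # [B, H, dim]
```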
This diff is collapsed.
......@@ -302,9 +302,7 @@ def flash_split_ref(Q, K, V, causal):
3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)
def main():
BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128
causal = False
def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False):
flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
total_flops = 2 * flops_per_matmul
if causal:
......
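For reference, with the default shapes the FLOP count above works out as follows (illustrative arithmetic only):

```python
# Default shapes from main(): BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128
BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128
flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD  # ~8.59e9
total_flops = 2 * flops_per_matmul                            # ~1.72e10, halved again if causal
print(flops_per_matmul, total_flops)
```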
......@@ -12,7 +12,7 @@ def test_example_example_gqa_decode():
def test_example_example_mha_inference():
example_mha_inference.main()
example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False)
if __name__ == "__main__":
......
......@@ -7,8 +7,6 @@ import tilelang
import tilelang.language as T
from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401
print(tilelang.__file__)
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
......@@ -256,8 +254,9 @@ def tilelang_chunk_o_bwd_dqkwg(
# for i_kv in T.Parallel(block_DK * block_DV):
# dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
for i_kv in T.Parallel(block_DK * block_DV):
i_k, i_v = i_kv // block_DV, i_kv % block_DV
dg_last_fragment[i_kv] = h_shared[i_k, i_v] * dh_shared[i_k, i_v]
dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv %
block_DV] * dh_shared[i_kv // block_DV,
i_kv % block_DV]
T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
dg_last_local[0] += dg_last_fragment_scalar[0]
......
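For context, the rewritten block above computes, for each (block_DK, block_DV) tile, the scalar sum of h * dh that is accumulated into `dg_last_local`. A tiny PyTorch sketch of the same reduction (hypothetical helper, for illustration only):

```python
import torch

def dg_last_tile_contribution(h_tile: torch.Tensor, dh_tile: torch.Tensor) -> torch.Tensor:
    # h_tile, dh_tile: [block_DK, block_DV] tiles. The kernel flattens them into a
    # 1-D fragment, multiplies elementwise, and reduce-sums into a single scalar.
    return (h_tile * dh_tile).sum()
```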
......@@ -4,20 +4,23 @@ TileLang is a domain-specific language designed to simplify the process of writi
## Table of Contents
1. [Getting Started](#getting-started)
2. [Simple GEMM Example](#simple-gemm-example)
- [Code Walkthrough](#code-walkthrough)
- [Compiling and Profiling](#compiling-and-profiling)
3. [Advanced GEMM Features](#advanced-gemm-features)
- [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling)
- [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining)
- [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality)
4. [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
5. [Verifying Correctness](#verifying-correctness)
6. [Fine-grained MMA Computations](#fine-grained-mma-computations)
- [Example Workflow](#example-workflow)
- [Summary](#summary)
7. [References](#references)
- [Table of Contents](#table-of-contents)
- [Getting Started](#getting-started)
- [Prerequisites](#prerequisites)
- [Installation](#installation)
- [Simple GEMM Example](#simple-gemm-example)
- [Code Walkthrough](#code-walkthrough)
- [Compiling and Profiling](#compiling-and-profiling)
- [Advanced GEMM Features](#advanced-gemm-features)
- [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling)
- [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining)
- [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality)
- [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
- [Verifying Correctness](#verifying-correctness)
- [Fine-grained MMA Computations](#fine-grained-mma-computations)
- [Example Workflow](#example-workflow)
- [Summary](#summary)
- [References](#references)
---
......@@ -25,10 +28,10 @@ TileLang is a domain-specific language designed to simplify the process of writi
### Prerequisites
- **Python 3.8+**
- **NVIDIA GPU** with a recent CUDA toolkit installed
- **PyTorch** (optional, for easy correctness verification)
- **tilelang**
- **bitblas** (optional; used for swizzle layout utilities in the advanced examples)
### Installation
......@@ -87,26 +90,26 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
### Code Walkthrough
1. **Define the Kernel Launch Configuration:**
```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
```
This creates a grid of blocks (ceildiv(N, block_N) in x-dimension, ceildiv(M, block_M) in y-dimension), each with 128 threads.
2. **Shared Memory Allocation:**
```python
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
```
Tiles of \(A\) and \(B\) are loaded into these shared memory buffers for faster access.
3. **Local Fragment Accumulation:**
```python
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
```
Partial results are stored in registers (or local memory) to reduce writes to global memory.
4. **Pipelined Loading and GEMM:**
```python
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(...)
......@@ -114,7 +117,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
```
Loads blocks of \(A\) and \(B\) in a pipelined fashion (up to 3 stages). This exploits overlap of data transfer and computation.
5. **Copy Out the Results:**
```python
T.copy(C_local, C[by * block_M, bx * block_N])
```
......@@ -216,10 +219,10 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
return main
```
**Key Differences vs. Basic Example** (a condensed sketch follows this list)
1. **`T.annotate_layout(...)`**: Annotates how data should be organized in shared memory (swizzling).
2. **`T.use_swizzle(...)`**: Enables swizzle-based rasterization.
3. **Parallel Copy Loop** with `T.Parallel(...)`: Distributes the global-to-shared copy across all threads, potentially vectorizing load/store instructions.
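A condensed sketch of where these three features slot into the basic kernel (schematic only; the full annotated example in this section is the authoritative version, and the `make_swizzled_layout` helper is borrowed from the layout utilities used elsewhere in this commit):

```python
import tilelang
import tilelang.language as T


@tilelang.jit
def annotated_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # (1) Swizzled shared-memory layout for the A tile
            T.annotate_layout({A_shared: tilelang.layout.make_swizzled_layout(A_shared)})
            # (2) Swizzle-based rasterization for better L2 locality
            T.use_swizzle(panel_size=10)

            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # (3) Explicit parallel copy instead of T.copy
                for i, j in T.Parallel(block_M, block_K):
                    A_shared[i, j] = A[by * block_M + i, ko * block_K + j]
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```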
---
......@@ -247,7 +250,7 @@ print("Results match!")
## Fine-grained MMA Computations
For advanced users who require full control over warp-level matrix multiplication, TileLang lets you specify fine-grained MMA (Matrix Multiply-Accumulate) computations in a manner similar to writing raw CUDA. While higher-level abstractions like `T.gemm(...)` or the automatic MMA emitters are sufficient for many use cases, specialized workloads (for example, dequantized GEMM, which may require fine-grained layout transformations during the shared-to-register stage) can benefit from explicitly controlling each MMA instruction, the data layout, and the synchronization points.
### Example Workflow
......@@ -394,10 +397,10 @@ def tl_matmul(
]
```
1. **Set Up Tile Sizes and Thread Bindings**
Just like in CUDA, you will typically start by defining how many warps or threads per block you want and how your matrix is subdivided. In TileLang, this is done via `T.Kernel(...)` and `T.thread_binding(...)`, which ensure that the correct number of threads are active and that each thread is bound to a specific role (e.g., warp ID or lane ID).
2. **Allocate Warp-local Fragments**
Instead of using a single shared buffer for partial sums, you allocate local buffers (register fragments) to hold sub-blocks of matrices \(A\) and \(B\). In TileLang, this is done with something like:
```python
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
......@@ -406,7 +409,7 @@ def tl_matmul(
```
Each of these `local` allocations represents a region of per-thread storage, which collectively forms the warp’s register tiles.
3. **Load Data via `ldmatrix`**
Fine-grained loading instructions allow you to specify exactly how data moves from shared memory to the warp-level fragments. In the example below, `mma_emitter.ldmatrix_a()` and `.ldmatrix_b()` are higher-level wrappers around warp-synchronous intrinsics. You can write your own load logic as well:
```python
for ki in T.serial(0, (block_K // micro_size_k)):
......@@ -418,7 +421,7 @@ def tl_matmul(
```
Internally, these calls orchestrate how each thread in the warp issues the correct load instructions, performs address calculations, and stores the data into registers.
4. **Perform the MMA Instruction**
After loading sub-tiles (fragments), the warp executes the `mma` instruction. This operation is essentially:
\[
C_{\text{local}} \;+=\; A_{\text{local}} \;\times\; B_{\text{local}}
......@@ -429,7 +432,7 @@ def tl_matmul(
```
Under the hood, this translates into Tensor Core instructions (e.g., `wmma.mma.sync` in PTX), which process multiple data elements per warp in parallel.
5. **Store Results via `stmatrix`**
Finally, you write the results from the warp-level fragments back to shared memory or global memory. This step might happen multiple times in a loop or just once at the end. The code snippet:
```python
mma_emitter.stmatrix(C_local, C_shared)
......@@ -444,6 +447,6 @@ By combining warp-synchronous intrinsics (`ldmatrix`, `mma`, `stmatrix`) with ma
## References
- [NVIDIA CUTLASS Library](https://github.com/NVIDIA/cutlass): A collection of high-performance CUDA C++ template abstractions for GEMM.
- [NVIDIA CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html): Official documentation for CUDA.
- [PyTorch Documentation](https://pytorch.org/docs): For verifying correctness via CPU or GPU-based matmul.
......@@ -80,7 +80,6 @@ def tl_fused_chunk_fwd_kernel(
T.atomic_add(
O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
o_shared)
#TODO: consider using vectorized atomic add or tma reduce for sm90
# Output final state
T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV])
......@@ -91,6 +90,7 @@ def tl_fused_chunk_fwd_kernel(
def tl_fused_chunk_fwd(q, k, v):
B, S, H, D = q.shape
kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
print(kernel.get_kernel_source())
o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
h = kernel(q, k, v, o)
return o, h
......
......@@ -51,13 +51,6 @@ def chunk_retention_fwd_kernel(
o = T.alloc_fragment([chunk_size, BV], accum_dtype)
T.clear(h)
T.annotate_layout({
q: tl.layout.make_swizzled_layout(q),
k: tl.layout.make_swizzled_layout(k),
v: tl.layout.make_swizzled_layout(v),
h_shared: tl.layout.make_swizzled_layout(h_shared),
s_shared: tl.layout.make_swizzled_layout(s_shared),
})
T.use_swizzle(10)
for i in T.Pipelined(0, NT):
......
......@@ -21,7 +21,7 @@ def rms_norm_splitk(M, N, blk_m, blk_k):
A_local[i, j] += A_shared[i, j] * A_shared[i, j]
T.reduce_sum(A_local, A_powsum, dim=1)
for i in T.Parallel(blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
for k in range(num_k_step):
# reverse, better cache hit rate
......@@ -51,7 +51,7 @@ def rms_norm(M, N, blk_m):
A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
T.reduce_sum(A_pow_local, A_powsum, dim=1)
for i in T.Parallel(blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
for i, j in T.Parallel(blk_m, N):
A_local[i, j] *= A_powsum[i]
T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :])
......
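For reference, the change above moves the epsilon inside the reciprocal square root, matching the usual RMSNorm formulation x / sqrt(mean(x^2) + eps). A minimal PyTorch sketch of that reference (no learned scale, hypothetical helper name):

```python
import torch

def rms_norm_reference(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # RMSNorm without a learned weight: the epsilon belongs inside the square root,
    # which is what T.rsqrt(A_powsum[i] / N + 1e-12) now computes per row.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
```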
......@@ -22,7 +22,7 @@ def rms_norm_splitk(M, N, blk_m, blk_k):
A_local[i, j] += A_shared[i, j] * A_shared[i, j]
T.reduce_sum(A_local, A_powsum, dim=1)
for i in T.Parallel(blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
for k in range(num_k_step):
# reverse, better cache hit rate
......@@ -51,7 +51,7 @@ def rms_norm(M, N, blk_m):
A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
T.reduce_sum(A_pow_local, A_powsum, dim=1)
for i in T.Parallel(blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
for i, j in T.Parallel(blk_m, N):
A_local[i, j] *= A_powsum[i]
T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :])
......
import tilelang.language as T
from typing import Literal, Callable
from tvm.tir import IndexMap
from tilelang.intrinsics.utils import get_mma_micro_size
from tilelang.intrinsics.mfma_layout import (
shared_16x4_to_local_64x1_layout_A,
shared_16x16_to_local_64x4_layout_A,
shared_16x32_to_local_64x8_layout_A,
shared_16x64_to_local_64x16_layout_A,
)
def make_mfma_load_base_layout(dtype: str = "float16",
matrix: Literal["A", "B"] = "A",
k_dim: int = 16,
transposed: bool = False) -> T.Fragment:
"""
Create a layout function for loading an MFMA operand (A or B) into a fragment buffer.
The returned fragment layout maps fragment indices to threads and per-thread
local indices.
Parameters
----------
dtype : str
The data type of the matrix.
matrix : Literal["A", "B"]
The mfma operand to be loaded.
k_dim : int
The k dimension of the mfma.
transposed : bool
Whether the matrix is transposed, by default False.
Returns
-------
T.Fragment
Describes how threads and indices in fragment are laid out.
"""
assert matrix in ["A", "B"], "matrix should be either A or B"
# s represents spatial axis
# r represents reduction axis
# sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial
transform_func_sr_a: Callable = None
transform_func_sr_b: Callable = None
if k_dim == 4:
transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
elif k_dim == 16:
transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
elif k_dim == 32:
transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
elif k_dim == 64:
transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
else:
raise ValueError("k_dim must be one of 4, 16, 32, or 64")
is_sr_conditions = [False]
is_sr_conditions.append(matrix == "A" and not transposed)
is_sr_conditions.append(matrix == "B" and transposed)
is_sr_axis_order = any(is_sr_conditions)
micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype)
# The layout of mma.sync is row.col,
# so the B matrix expects a transposed base layout.
transform_func: Callable = None
if matrix == "A":
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
micro_size_s, micro_size_r = micro_size_x, micro_size_k
elif matrix == "B":
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
j, i)
micro_size_s, micro_size_r = micro_size_k, micro_size_y
else:
raise ValueError(f"Unsupported matrix {matrix}")
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
def forward_thread(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
return the lane (thread) ID that owns that element.
"""
lane_id, _ = inverse_mma_load_layout.map_indices([i, j])
return lane_id
def forward_index(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
return the per-thread local index of that element.
"""
_, local_id = inverse_mma_load_layout.map_indices([i, j])
return local_id
base_fragment = T.Fragment(
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
return base_fragment
block_rows = 2
block_cols = 2
warp_rows = 2
warp_cols = 2
chunk = 2
from tilelang.tools import plot_layout
# ldmatrix layout 16x16
base_layout = make_mfma_load_base_layout(dtype="float16", matrix="A", transposed=False)
print(base_layout)
plot_layout(base_layout, name="base_layout")
# warp layout 32x32
warp_layout = base_layout.repeat([warp_rows, warp_cols],
repeat_on_thread=False,
lower_dim_first=False)
print(warp_layout)
plot_layout(warp_layout, name="warp_layout")
# block layout 64x32
block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True,
lower_dim_first=True).replicate(block_cols)
print(block_layout)
plot_layout(block_layout, name="block_layout")
import tilelang
import tilelang.language as T
tilelang.disable_cache()
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
......@@ -52,11 +54,14 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
def main(M=16384, N=16384, K=16384):
tilelang.disable_cache()
block_M = 128
block_N = 128
block_K = 64
jit_kernel = matmul(M, N, K, block_M, block_N, block_K)
print(jit_kernel.get_kernel_source())
import torch
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
......
......@@ -29,10 +29,7 @@ ALL_FILES=''
ONLY_CHANGED=''
FILES=()
if (($# == 0)); then
if [[ -n "$(git status --porcelain --ignore-submodules --untracked-files=no)" ]]; then
echo "Detected uncommitted changes. Please commit or stash them before running $0." >&2
exit 1
fi
# Default: allow dirty workspace; run on changed files (committed + worktree)
ONLY_CHANGED='true'
else
while (($# > 0)); do
......@@ -78,14 +75,17 @@ if [[ -n "${ALL_FILES}" ]]; then
echo "Checking all files..." >&2
elif [[ -n "${ONLY_CHANGED}" ]]; then
MERGE_BASE="$(get_merge_base)"
echo "Checking changed files compared to merge base (${MERGE_BASE})..." >&2
echo "Checking changed files vs merge base (${MERGE_BASE}) and working tree..." >&2
elif [[ "${#FILES[@]}" -gt 0 ]]; then
echo "Checking specified files: ${FILES[*]}..." >&2
fi
# Some systems set pip's default to --user, which breaks isolated virtualenvs.
export PIP_USER=0
# If pre-commit is not installed, install it.
if ! python3 -m pre_commit --version &>/dev/null; then
python3 -m pip install pre-commit
python3 -m pip install pre-commit --user
fi
echo 'tile-lang pre-commit: Check Start'
......@@ -93,7 +93,17 @@ echo 'tile-lang pre-commit: Check Start'
if [[ -n "${ALL_FILES}" ]]; then
python3 -m pre_commit run --all-files
elif [[ -n "${ONLY_CHANGED}" ]]; then
python3 -m pre_commit run --from-ref "${MERGE_BASE}" --to-ref HEAD
# Collect changed files (committed since merge-base + current worktree)
CHANGED_FILES="$(git diff --name-only --diff-filter=ACM "${MERGE_BASE}" 2>/dev/null || true)"
if [[ -n "${CHANGED_FILES}" ]]; then
echo "Running pre-commit on changed files:"
echo "${CHANGED_FILES}"
# Convert newline-separated files to space-separated and run pre-commit once
CHANGED_FILES_SPACE="$(echo "${CHANGED_FILES}" | tr '\n' ' ')"
python3 -m pre_commit run --files ${CHANGED_FILES_SPACE}
else
echo "No files changed relative to merge base and worktree. Skipping pre-commit."
fi
elif [[ "${#FILES[@]}" -gt 0 ]]; then
python3 -m pre_commit run --files "${FILES[@]}"
fi
......@@ -105,7 +115,7 @@ echo 'tile-lang clang-tidy: Check Start'
if [[ -x "$(command -v run-clang-tidy)" ]]; then
# Check if clang-tidy is available
if [[ ! -x "$(command -v clang-tidy)" ]]; then
python3 -m pip install --upgrade --requirement "${ROOT}/requirements-lint.txt"
python3 -m pip install --upgrade --requirement "${ROOT}/requirements-lint.txt" --user
fi
# Get clang-tidy version
CLANG_TIDY_VERSION="$(clang-tidy --version | head -n1 | awk '{print $4}')"
......
This diff is collapsed.
# pytest maint/gemm_v2/correctness_evaluation_sm70.py -n 32
import pytest
from tilelang import tvm as tvm
import tilelang.testing
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
# T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def _compile_and_check(
program,
trans_A,
trans_B,
in_dtype,
out_dtype,
):
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
})
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
print("assert_allclose")
def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
_compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)
def matmul_rs(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(A_shared, A_frag)
T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B)
# T.gemm(A_frag, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_rs(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_rs(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
_compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)
M_VALUES = [64, 128]
N_VALUES = [32, 64, 128]
K_VALUES = [16, 32, 64]
FALSE_TRUE_CASES = ([
pytest.param(
k,
"float16",
"float16",
"float16",
id=f"K{k}-float16-float16-float16",
) for k in K_VALUES
] + [
pytest.param(
k,
"float16",
"float16",
"float32",
id=f"K{k}-float16-float16-float32",
) for k in K_VALUES
])
def _ensure_torch_dtypes(*dtype_names):
import torch
for name in set(dtype_names):
if not hasattr(torch, name):
pytest.skip(f"Torch does not expose dtype {name}")
def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, 2, 128)
def run_gemm_rs_false_false(m, n, k):
run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
TRANS_CASES = [
pytest.param(False, False, id="nn"),
pytest.param(False, True, id="nt"),
pytest.param(True, False, id="tn"),
pytest.param(True, True, id="tt"),
]
@pytest.fixture(scope="module", autouse=True)
def _setup_tilelang_environment():
tilelang.disable_cache()
tilelang.testing.set_random_seed(42)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
import torch
required_torch_attrs = {
in_dtype,
out_dtype,
accum_dtype,
}
for attr in required_torch_attrs:
if not hasattr(torch, attr):
pytest.skip(f"Torch does not expose dtype {attr}")
run_gemm(
m,
n,
k * 3,
False,
True,
in_dtype,
out_dtype,
accum_dtype,
m,
n,
k,
2,
128,
)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_false_false(m, n, k):
run_gemm(
m,
n,
k * 3,
False,
False,
"float16",
"float16",
"float16",
m,
n,
k,
2,
128,
)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
_ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype)
run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_false_false(m, n, k):
_ensure_torch_dtypes("float16")
run_gemm_rs_false_false(m, n, k)
if __name__ == "__main__":
tilelang.testing.main()
# # Test Pass
# for m in [64, 128]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64]:
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
# for m in [64, 128]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64]:
# print(f"======================= Test {m} {n} {k} False False =============================")
# run_gemm(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# pytest correctness_evaluation.py -n 32
import pytest
from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
mbar = T.alloc_barrier(1)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_tmem,
trans_A,
trans_B,
mbar=mbar,
wg_wait=-1,
clear_accum=k == 0)
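# The GEMM above is issued asynchronously (wg_wait=-1) and signals `mbar` on completion;
# the parity wait below ensures it has finished before C_tmem is read. The barrier phase
# alternates each loop iteration, hence the k % 2 parity.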
T.mbarrier_wait_parity(mbar, k % 2)
T.copy(C_tmem, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
def _compile_and_check(
program,
trans_A,
trans_B,
in_dtype,
out_dtype,
):
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
print("assert_allclose")
def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=2,
num_threads=128,
):
if block_N >= 256 or block_M >= 256 or block_K >= 256:
num_stages = 0
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
_compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)
M_VALUES = [32, 64, 128, 256]
N_VALUES = [64, 128, 256, 512]
K_VALUES = [16, 32, 64, 128]
K_VALUES_8Bit = [32, 64, 128]
FALSE_TRUE_CASES = ([
pytest.param(
k,
"float16",
"float32",
"float32",
id=f"K{k}-float16-float-float",
) for k in K_VALUES
] + [
pytest.param(
k,
"float8_e5m2",
"float32",
"float32",
id="K32-float8_e5m2-float32-float32",
) for k in K_VALUES_8Bit
])
TRANS_CASES = [
pytest.param(False, True, id="nt"),
]
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
import torch
required_torch_attrs = {
in_dtype,
out_dtype,
accum_dtype,
}
for attr in required_torch_attrs:
if not hasattr(torch, attr):
pytest.skip(f"Torch does not expose dtype {attr}")
run_gemm(
m,
n,
k * 3,
False,
True,
in_dtype,
out_dtype,
accum_dtype,
m,
n,
k,
)
if __name__ == "__main__":
# tilelang.testing.main()
# # Test Pass
# for m in [32, 64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# if m in [32, 64] and (n not in [64, 128, 256]):
# continue
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float", "float", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
# for m in [32, 64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [32, 64, 128]:
# if m in [32, 64] and (n not in [64, 128, 256]):
# continue
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float8_e5m2", "float", "float", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
tilelang.disable_cache()
run_gemm(32, 512, 16, False, True, "float16", "float32", "float32", 32, 512, 16, 0, 128)
run_gemm(32, 512, 32, False, True, "float16", "float32", "float32", 32, 512, 32, 0, 128)
run_gemm(32, 512, 64, False, True, "float16", "float32", "float32", 32, 512, 64, 0, 128)
run_gemm(64, 512, 16, False, True, "float16", "float32", "float32", 64, 512, 16, 0, 128)
run_gemm(64, 512, 16, False, True, "float16", "float32", "float32", 32, 512, 16, 0, 128)
run_gemm(128, 512, 16, False, True, "float16", "float32", "float32", 128, 512, 16, 0, 128)
# run_gemm(64, 512, 32, False, True, "float16", "float32", "float32", 64, 512, 32, 0, 128)
# run_gemm(64, 512, 64, False, True, "float16", "float32", "float32", 64, 512, 64, 0, 128)
# run_gemm(128, 512, 16, False, True, "float16", "float32", "float32", 128, 512, 16, 0, 128)
import tilelang
import tilelang.language as T
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()
use_v2 = args.use_v2
# @tilelang.jit(target="cuda")
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable rasterization for better L2 cache locality (Optional)
# T.use_swizzle(panel_size=10, enable=True)
# Clear local accumulation
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy tile of A
# This is a sugar syntax for parallelized copy
T.copy(A[by * block_M, ko * block_K], A_shared)
# Copy tile of B
T.copy(B[ko * block_K, bx * block_N], B_shared)
# Perform a tile-level GEMM on the shared buffers
# Currently this dispatches to CuTe on NVIDIA GPUs and HIP on AMD GPUs
if use_v2:
T.gemm_v2(A_shared, B_shared, C_local)
else:
T.gemm_v1(A_shared, B_shared, C_local)
# relu
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = T.max(C_local[i, j], 0)
# Copy result back to global memory
T.copy(C_local, C[by * block_M, bx * block_N])
return matmul_relu_kernel
M = 16384 # M = T.dynamic("m") if you want to use dynamic shape
N = 16384
K = 16384
block_M = 128
block_N = 128
block_K = 32
# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)
# Run the kernel through the Profiler
matmul_relu_kernel(a, b, c)
print(c)
# Reference multiplication using PyTorch
ref_c = torch.relu(a @ b)
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5. Profile the kernel latency
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
import tilelang
import tilelang.language as T
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()
use_v2 = args.use_v2
# @tilelang.jit(target="cuda")
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable rasterization for better L2 cache locality (Optional)
# T.use_swizzle(panel_size=10, enable=True)
# Clear local accumulation
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy tile of A
# This is a sugar syntax for parallelized copy
T.copy(A[by * block_M, ko * block_K], A_shared)
# Copy tile of B
T.copy(B[ko * block_K, bx * block_N], B_shared)
# Perform a tile-level GEMM on the shared buffers
# Currently this dispatches to CuTe on NVIDIA GPUs and HIP on AMD GPUs
if use_v2:
T.gemm_v2(A_shared, B_shared, C_local)
else:
T.gemm_v1(A_shared, B_shared, C_local)
# relu
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = T.max(C_local[i, j], 0)
# Copy result back to global memory
T.copy(C_local, C[by * block_M, bx * block_N])
return matmul_relu_kernel
M = 16384 # M = T.dynamic("m") if you want to use dynamic shape
N = 16384
K = 16384
block_M = 128
block_N = 128
block_K = 64
# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)
# Run the kernel through the Profiler
matmul_relu_kernel(a, b, c)
print(c)
# Reference multiplication using PyTorch
ref_c = torch.relu(a @ b)
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5. Profile the kernel latency
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=16, help='heads')
parser.add_argument('--seq_q', type=int, default=1024, help='query sequence length')
parser.add_argument('--seq_kv', type=int, default=1024, help='key/value sequence length')
parser.add_argument('--dim', type=int, default=256, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()
use_v2 = args.use_v2
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
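# The Cartesian product of the value lists yields one dict per combination; with the
# single-element lists above this is just
# [dict(block_M=128, block_N=128, num_stages=2, threads=256)].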
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
if use_v2:
T.gemm_v2(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
else:
T.gemm_v1(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
# T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if use_v2:
T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
else:
T.gemm_v1(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in the FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
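# scores_scale = exp2((m_prev - m_new) * scale), i.e. exp((m_prev - m_new) / sqrt(dim)):
# the factor used to rescale the previously accumulated output and logsum after the
# running max has been updated.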
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
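# With causal masking, KV blocks entirely beyond this query block's last (shifted)
# position contribute nothing, so the loop bound is tightened accordingly.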
loop_range = (
T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv(
(bx + 1) * block_M +
past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return main
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_q = Q.size(2)
seq_kv = K.size(2)
mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
return output
def main(
batch: int = 1,
heads: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 64,
is_causal: bool = False,
tune: bool = False,
):
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if (not tune):
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=128)
print(kernel.get_kernel_source())
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print(f"Ref: {latency:.2f} ms")
print(f"Ref: {total_flops / latency * 1e-9:.2f} TFlops")
latency = profiler.do_bench(warmup=500)
print(f"Tile-lang: {latency:.2f} ms")
print(f"Tile-lang: {total_flops / latency * 1e-9:.2f} TFlops")
else:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
tilelang.disable_cache()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
This diff is collapsed.