Commit bc2d5632 authored by root's avatar root
Browse files

init

parents
Pipeline #3222 failed with stages
in 0 seconds
import torch
import triton
import triton.language as tl
import argparse
from einops import rearrange, einsum
import torch.nn.functional as F
import math
import time
from heuristic import num_splits_heuristic
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [1, 2, 4]\
for num_stages in [1, 2, 3, 4, 7]
],
key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'],
)
@triton.jit
def _split_kernel(
q_ptr,
k_cache_ptr,
v_cache_ptr,
cache_seqlens_ptr,
o_partial_ptr,
lse_partial_ptr,
mask_ptr,
sm_scale,
num_splits,
gqa_group_size,
stride_q_b,
stride_q_h,
stride_q_d,
stride_k_b,
stride_k_s,
stride_k_h,
stride_k_d,
stride_v_b,
stride_v_s,
stride_v_h,
stride_v_d,
stride_o_b,
stride_o_h,
stride_o_split,
stride_o_d,
stride_lse_b,
stride_lse_h,
stride_lse_split,
stride_mask_b,
stride_mask_h,
stride_mask_s,
BLOCK_H: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr,
):
batch_idx = tl.program_id(0)
head_idx_kv = tl.program_id(1)
split_idx = tl.program_id(2)
head_idx_q = head_idx_kv * gqa_group_size
offs_h = tl.arange(0, BLOCK_H)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_D)
m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32)
cache_seqlens = tl.load(cache_seqlens_ptr + batch_idx)
num_blocks = (cache_seqlens + BLOCK_N - 1) // BLOCK_N
blocks_per_split = tl.floor(num_blocks / num_splits).to(tl.int32)
remaining_blocks = num_blocks % num_splits
if split_idx < remaining_blocks:
loop_range = blocks_per_split + 1
else:
loop_range = blocks_per_split
q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h
k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[
None, :] * stride_k_s + offs_d[:, None] * stride_k_d
v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:,
None] * stride_v_s + offs_d[
None, :] * stride_v_d
mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h
q = tl.load(
q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d,
mask=offs_h[:, None] < gqa_group_size)
start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks)
for block_idx in range(loop_range):
start_n = (start + block_idx) * BLOCK_N
mask_val = tl.load(mask_ptr + (start + block_idx) * stride_mask_s)
if mask_val == 1:
k_ptr = k_cache_ptr + start_n * stride_k_s
v_ptr = v_cache_ptr + start_n * stride_v_s
k = tl.load(k_ptr, mask=start_n + offs_n[None, :] < cache_seqlens, other=0.0)
v = tl.load(v_ptr, mask=start_n + offs_n[:, None] < cache_seqlens, other=0.0)
qk = tl.dot(q, k)
qk = qk * sm_scale
qk = tl.where(start_n + offs_n[None, :] < cache_seqlens, qk, float("-inf"))
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
p = tl.exp(qk)
l_ij = tl.sum(p, 1)
alpha = tl.exp(m_i - m_ij)
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
p = p.to(v.type.element_ty)
acc += tl.dot(p, v)
m_i = m_ij
m_i += tl.math.log(l_i)
l_recip = 1 / l_i[:, None]
acc = acc * l_recip
acc = acc.to(o_partial_ptr.dtype.element_ty)
lse_partial_ptr += batch_idx * stride_lse_b + (
head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size)
o_partial_ptr += batch_idx * stride_o_b + (
head_idx_q +
offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size)
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [1, 2, 4]\
for num_stages in [1, 2, 3, 4, 7]
],
key=['BLOCK_D'],
)
@triton.jit
def _merge_kernel(
o_partial_ptr,
lse_partial_ptr,
o_ptr,
lse_partial_stride_b,
lse_partial_stride_h,
lse_partial_stride_split,
o_partial_stride_b,
o_partial_stride_h,
o_partial_stride_split,
o_partial_stride_d,
o_stride_b,
o_stride_h,
o_stride_d,
BLOCK_D: tl.constexpr,
num_splits: tl.constexpr,
num_splits_pow2: tl.constexpr,
):
batch_idx = tl.program_id(0)
head_idx = tl.program_id(1)
offs_splits = tl.arange(0, num_splits_pow2)
offs_d = tl.arange(0, BLOCK_D)
lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h
lse = tl.load(
lse_offsets + offs_splits * lse_partial_stride_split,
mask=offs_splits < num_splits,
other=float("-inf"))
lse_max = tl.max(lse)
o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h
o_partial = tl.load(
o_offsets + offs_splits[:, None] * o_partial_stride_split +
offs_d[None, :] * o_partial_stride_d,
mask=offs_splits[:, None] < num_splits)
sumexp_normalized_splitk = tl.exp(lse - lse_max)
sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0)
numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0)
acc = numerator_normalized / sumexp_normalized
acc = acc.to(o_ptr.dtype.element_ty)
o_ptr += batch_idx * o_stride_b + head_idx * o_stride_h
tl.store(o_ptr + offs_d * o_stride_d, acc)
def block_sparse_flash_decode_gqa_mask_triton(
q,
k_cache,
v_cache,
cache_seqlens,
max_cache_seqlen,
block_mask,
block_size,
sm_scale=None,
):
batch, heads, dim = q.shape
if sm_scale is None:
sm_scale = 1 / math.sqrt(dim)
_, max_cache_seqlen_cache, heads_kv, dim_v = v_cache.shape
assert max_cache_seqlen == max_cache_seqlen_cache, "max_cache_seqlen mismatch"
group_size = heads // heads_kv
block_H = 16
max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * (
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks
num_sm = 64
# num_sm = self.num_sm
num_splits = num_splits_heuristic(
total_mblocks,
num_sm,
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
num_splits_pow2 = triton.next_power_of_2(num_splits)
o_partial = torch.empty((batch, heads, num_splits, dim_v), device=q.device, dtype=q.dtype)
lse_partial = torch.empty((batch, heads, num_splits), device=q.device, dtype=torch.float32)
BLOCK_D = dim
BLOCK_H = group_size if group_size > 16 else 16
grid = (batch, heads_kv, num_splits)
_split_kernel[grid](
q,
k_cache,
v_cache,
cache_seqlens,
o_partial,
lse_partial,
block_mask,
sm_scale,
num_splits,
group_size,
q.stride(0),
q.stride(1),
q.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(3),
o_partial.stride(0),
o_partial.stride(1),
o_partial.stride(2),
o_partial.stride(3),
lse_partial.stride(0),
lse_partial.stride(1),
lse_partial.stride(2),
block_mask.stride(0),
block_mask.stride(1),
block_mask.stride(2),
BLOCK_H=BLOCK_H,
BLOCK_N=block_size,
BLOCK_D=BLOCK_D,
)
output = torch.zeros((batch, heads, dim_v), device=q.device, dtype=q.dtype)
grid = (batch, heads)
_merge_kernel[grid](
o_partial,
lse_partial,
output,
lse_partial.stride(0),
lse_partial.stride(1),
lse_partial.stride(2),
o_partial.stride(0),
o_partial.stride(1),
o_partial.stride(2),
o_partial.stride(3),
output.stride(0),
output.stride(1),
output.stride(2),
BLOCK_D=dim_v,
num_splits=num_splits,
num_splits_pow2=num_splits_pow2,
)
return output
def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
block_size):
batch, heads, dim = query.shape
heads_kv = key.shape[2]
num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim]
query = rearrange(
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores)
# Assign mask values
for b in range(batch):
for h in range(heads_kv):
for idx in range(num_blocks):
if block_mask[b, h, idx]:
sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float('-inf'))
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out
def ref_program_fa(query, key, value, cache_seqlens):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache #fa2
query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1)
return output
def main(batch=64,
heads=32,
heads_kv=8,
max_cache_seqlen=8192,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32):
batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
block_size = block_size
sparse_ratio = sparse_ratio
qk_flops = 2 * batch * heads * max_cache_seqlen * dim
pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v
total_flops = qk_flops + pv_flops
dtype = torch.float16
Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda')
# Ensure at least one element equals cache_seqlen
random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index
cache_seqlens[
random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
num_blocks = (max_cache_seqlen + block_size - 1) // block_size
valid_num_blocks = torch.ceil(cache_seqlens * (1 - sparse_ratio) / block_size).int()
print("valid_num_blocks: ", valid_num_blocks)
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_mask with false (for padding blocks)
block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda')
# Assign valid indices while ensuring no duplicates within each batch-group
for b in range(batch):
max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch
valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch
if valid_num_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv):
perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block]
block_mask[b, h, perm] = True
ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
block_size)
triton_out = block_sparse_flash_decode_gqa_mask_triton(
Q,
K,
V,
cache_seqlens,
max_cache_seqlen,
block_mask,
block_size,
)
# print("max difference: ", torch.max(torch.abs(ref - triton_out)))
assert torch.allclose(
ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
print("Passed the ref test!")
# Measure performance
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
block_sparse_flash_decode_gqa_mask_triton(
Q,
K,
V,
cache_seqlens,
max_cache_seqlen,
block_mask,
block_size,
)
torch.cuda.synchronize()
end = time.time()
elapsed_time = end - start
avg_time = elapsed_time / 1000
avg_flops = total_flops / avg_time
print(f"Average time: {avg_time:.6f} seconds")
print(f"Average flops: {avg_flops:.2f} GFLOPS")
import flash_attn # noqa: F401
start = time.time()
for _ in range(1000):
ref_program_fa(Q, K, V, cache_seqlens)
torch.cuda.synchronize()
end = time.time()
elapsed_time_ref = end - start
avg_time_ref = elapsed_time_ref / 1000
avg_flops_ref = total_flops / avg_time_ref
print(f"Average time of ref: {avg_time_ref:.6f} seconds")
print(f"Average flops of ref: {avg_flops_ref:.2f} GFLOPS")
print(f"Speedup: {avg_time_ref / avg_time:.2f}x")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=64, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
parser.add_argument(
'--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
parser.add_argument('--block_size', type=int, default=32, help='block_size')
args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
args.sparse_ratio, args.block_size)
import math
def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head,
is_causal_or_local, max_splits):
"""
Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency.
Parameters:
- total_mblocks (int): Total number of m_blocks.
- num_SMs (int): Number of Streaming Multiprocessors (SMs) in the GPU.
- num_n_blocks (int): Number of n_blocks.
- num_m_blocks (int): Number of m_blocks.
- size_one_kv_head (int): Size of one KV head in bytes.
- is_causal_or_local (bool): Indicates whether the operation is causal or local.
- max_splits (int): Maximum number of allowed splits.
Returns:
- int: The optimal number of splits.
"""
# If we have enough m_blocks to almost fill the SMs, prefer 1 split unless memory constraints apply.
if total_mblocks >= 0.8 * num_SMs:
size_l2 = 50 * 1024 * 1024 # L2 cache size assumption (50MB)
# Only split if each KV head is too large for L2 and there are enough m_blocks
if size_one_kv_head > size_l2 and num_m_blocks >= num_SMs * 2 and not is_causal_or_local:
return min((size_one_kv_head + size_l2 - 1) // size_l2, max_splits)
else:
return 1
# If num_n_blocks is too small, we don't split
if num_n_blocks <= 4:
return 1
# Limit max_splits to a reasonable range
max_splits = min(max_splits, num_SMs, num_n_blocks)
max_efficiency = 0.0
efficiency = []
# Compute efficiency for different splits
for num_splits in range(1, max_splits + 1):
n_waves = (total_mblocks * num_splits) / num_SMs
eff = n_waves / math.ceil(n_waves)
# Track max efficiency
if eff > max_efficiency:
max_efficiency = eff
efficiency.append(eff)
# Find the smallest number of splits that achieves at least 85% of max efficiency
for num_splits in range(1, max_splits + 1):
if efficiency[num_splits - 1] >= 0.85 * max_efficiency:
return num_splits
return 1
import tilelang.testing
import block_sparse_attn_triton
import example_tilelang_block_sparse_attn
import example_tilelang_sparse_gqa_decode_varlen_indice
import example_tilelang_sparse_gqa_decode_varlen_mask
import example_triton_sparse_gqa_decode_varlen_indice
import example_triton_sparse_gqa_decode_varlen_mask
def test_block_sparse_attn_triton():
block_sparse_attn_triton.main()
def test_example_tilelang_block_sparse_attn():
example_tilelang_block_sparse_attn.main()
def test_example_tilelang_sparse_gqa_decode_varlen_indice():
example_tilelang_sparse_gqa_decode_varlen_indice.main(batch=1, max_cache_seqlen=2048)
def test_example_tilelang_sparse_gqa_decode_varlen_mask():
example_tilelang_sparse_gqa_decode_varlen_mask.main(batch=1, max_cache_seqlen=2048)
def test_example_triton_sparse_gqa_decode_varlen_indice():
example_triton_sparse_gqa_decode_varlen_indice.main(
batch=16,
heads=16,
heads_kv=8,
max_cache_seqlen=4096,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32)
def test_example_triton_sparse_gqa_decode_varlen_mask():
example_triton_sparse_gqa_decode_varlen_mask.main(
batch=16,
heads=16,
heads_kv=8,
max_cache_seqlen=4096,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32)
if __name__ == "__main__":
tilelang.testing.main()
import argparse
import itertools
import tilelang
import tilelang.language as T
from tilelang.engine.param import KernelParam
from tilelang.utils.tensor import get_tensor_supply, TensorSupplyType
import torch
from typing import List
DEFAULT_BLOCK_M = 128
DEFAULT_BLOCK_N = 128
DEFAULT_BLOCK_K = 32
DEFAULT_NUM_STAGES = 2
DEFAULT_THREAD_NUM = 128
DEFAULT_ENABLE_RASTERIZATION = True
parser = argparse.ArgumentParser(description="Autotuned BlockSparse MatMul Benchmark")
parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)")
parser.add_argument(
"--use_autotune", action="store_true", default=False, help="Whether to use autotune")
args, _ = parser.parse_known_args()
M, N, K = args.m, args.n, args.k
sparsity = args.sparsity
use_autotune = args.use_autotune
default_tensor_supply = get_tensor_supply(TensorSupplyType.Auto)
print(f"Running BlockSparse MatMul Benchmark for M={M}, N={N}, K={K}")
print(f"Target Block Sparsity: {sparsity}")
print(f"Using Autotuner: {use_autotune}\n")
def get_configs():
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [1, 2, 3]
thread_num = [128, 256]
enable_rasterization = [True, False]
_configs = list(
itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization))
return [{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5],
} for c in _configs]
def ref_program(A, B, BlockMask, block_M, block_N, block_K):
ref_c = torch.zeros((M, N), dtype=torch.float16, device=A.device)
for i in range(M // block_M):
for j in range(N // block_N):
accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
for k in range(K // block_K):
if BlockMask[i, j, k]:
accu += (
A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
torch.float32) @ B[k * block_K:(k + 1) * block_K,
j * block_N:(j + 1) * block_N].to(torch.float32))
ref_c[i * block_M:(i + 1) * block_M,
j * block_N:(j + 1) * block_N] = accu.to(torch.float16)
return ref_c
def supply_program(params: List[KernelParam]):
input_tensors = []
for p in params:
# Check if the kernel parameter is BlockMask tensor.
# Here, BlockMask is uniquely identified by having 3 dimensions.
if len(p.shape) != 3:
# For non-BlockMask tensors, use the default tensor generation logic.
input_tensors.append(default_tensor_supply(p))
else:
# For BlockMask tensor, randomly set elements to True based on desired
# sparsity level.
block_mask = torch.zeros(p.shape, dtype=torch.bool, device=torch.cuda.current_device())
block_mask[:, :, :] = torch.rand(p.shape) > sparsity
input_tensors.append(block_mask)
return input_tensors
@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(out_idx=[-1])
def blocksparse_matmul(M,
N,
K,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float"):
block_mask_shape = (M // block_M, N // block_N, K // block_K)
@T.prim_func
def block_sparse_matmul(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if BlockMask[by, bx, k]:
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return block_sparse_matmul
def main():
# Initialize input matrices A and B on the GPU with half precision
a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half()
if args.use_autotune:
# Run the autotuner to find the best kernel configuration and performance
# get_best_config is expected to return an object containing the compiled kernel,
# the best configuration found, latency, and reference latency.
kernel = blocksparse_matmul(M, N, K)
best_config = kernel.config
best_latency = kernel.latency
block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config[
"block_K"]
print(f"Best Config: {best_config}")
print(f"Sparsity Ratio: {sparsity}")
print(f"Best Kernel Latency: {best_latency:.6f} ms")
else:
kernel = blocksparse_matmul(
M,
N,
K,
block_M=DEFAULT_BLOCK_M,
block_N=DEFAULT_BLOCK_N,
block_K=DEFAULT_BLOCK_K,
num_stages=DEFAULT_NUM_STAGES,
thread_num=DEFAULT_THREAD_NUM,
enable_rasteration=DEFAULT_ENABLE_RASTERIZATION)
block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
# Run the compiled kernel (either tuned or default) with the inputs
c = kernel(a, b, block_mask)
# Compute the reference result using the naive PyTorch implementation
ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K)
try:
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("✅ Results are close! Verification successful.")
except AssertionError as e:
print("❌ Verification FAILED: Results differ significantly.")
print(e)
if __name__ == "__main__":
main()
import tilelang.testing
import example_blocksparse_gemm
def test_example_blocksparse_gemm():
example_blocksparse_gemm.main()
if __name__ == "__main__":
tilelang.testing.main()
import torch
import tilelang
import tilelang.language as T
from typing import Tuple
from tilelang.utils.tensor import torch_assert_close
# support bfloat16, float, float16
dtype = "bfloat16"
accum_dtype = "float"
@tilelang.jit(out_idx=[2, 3])
def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
group_size = 128
fp8_min = -448.0
fp8_max = 448.0
@T.prim_func
def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor(
(BG,), "int32"), X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"), X_amax: T.Tensor(
(BG, M_max, T.ceildiv(N, group_size)), accum_dtype)):
with T.Kernel(
T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
row = bx
row_g_id = by
bg = bz
y_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
y_amax_local = T.alloc_fragment((blk_m,), accum_dtype)
y_s_local = T.alloc_fragment((blk_m,), accum_dtype)
y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
row_offset = T.alloc_fragment((1,), "int32")
T.annotate_layout({
y_local:
T.Fragment(
y_local.shape,
forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
})
row_offset[0] = 0
for i in T.serial(bg):
row_offset[0] += batch_sizes[i]
T.copy(
X[row_offset[0] + row * blk_m:row_offset[0] + (row + 1) * blk_m,
row_g_id * group_size:(row_g_id + 1) * group_size], y_local)
T.reduce_absmax(y_local, y_amax_local, dim=1)
for i in T.Parallel(blk_m):
y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg],
y_amax_local[i] / fp8_max, 0)
for i, j in T.Parallel(blk_m, group_size):
y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max)
T.copy(y_q_local, y_q_local_fp8)
for i, j in T.Parallel(blk_m, group_size):
y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg],
y_q_local[i, j], 0)
for i in T.Parallel(blk_m):
X_amax[bg, row * blk_m + i, row_g_id] = y_s_local[i]
T.copy(
y_q_local_fp8, X_fp8[bg, row * blk_m:(row + 1) * blk_m,
row_g_id * group_size:(row_g_id + 1) * group_size])
return group_per_split_token_cast
def ceil_div(x: int, y: int) -> int:
"""
Perform ceiling division of two integers.
Args:
x: the dividend.
y: the divisor.
Returns:
The result of the ceiling division.
"""
return (x + y - 1) // y
def get_tma_aligned_size(x: int, element_size: int) -> int:
"""
Global memory address of TMA must be 16-byte aligned.
Since we use column-major layout for the LHS scaling tensor,
the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
Arguments:
x: original M-axis shape of the LHS scaling tensor.
element_size: element size of the LHS scaling tensor.
Returns:
M-axis shape of the LHS scaling tensor after padding.
"""
tma_alignment_bytes = 16
assert tma_alignment_bytes % element_size == 0
alignment = tma_alignment_bytes // element_size
return ceil_div(x, alignment) * alignment
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
"""
Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along the M axis
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
Arguments:
x: usually the LHS scaling tensor in GEMM.
Returns:
The LHS scaling tensor of TMA-aligned transposed format.
"""
# NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
assert x.dim() in (2, 3)
remove_dim = False
m, n = x.shape[-2], x.shape[-1]
aligned_m = get_tma_aligned_size(m, x.element_size())
if x.dim() == 2:
if x.stride(0) == 1 and x.stride(1) == aligned_m:
return x
x, remove_dim = x.unsqueeze(0), True
b = x.shape[0]
# The last kernel gives a column-major TMA aligned layout
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
return x.squeeze(0) if remove_dim else x
# Normal layout requires transposing
aligned_x = torch.transpose(
torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
aligned_x[:, :m, :] = x
aligned_x = aligned_x[:, :m, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x
def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# this function don't support cpu tensor
assert x.dim() == 2
m, n = x.shape
new_n = ceil_div(n, 128) * 128
x_padded = torch.nn.functional.pad(x, (0, new_n - n))
x_view = x_padded.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
x_fp8 = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous()
return x_fp8, (x_amax / 448.0).view(m, -1)
def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
Tuple[torch.Tensor, torch.Tensor]:
# assert x.shape[0] == batch_sizes.sum()
M_max = ceil_div(batch_sizes.max(), 128) * 128
split_x = torch.split(x, batch_sizes.tolist(), dim=0)
padded_x = [torch.nn.functional.pad(t, (0, 0, 0, M_max - t.shape[0])) for t in split_x]
num_groups, m, n = batch_sizes.shape[0], M_max, x.shape[1]
x_fp8 = (torch.empty((num_groups, m, n), device='cuda', dtype=torch.float8_e4m3fn),
torch.empty((num_groups, m, n // 128), device='cuda', dtype=torch.float))
for i in range(num_groups):
x_fp8[0][i], x_fp8[1][i] = ref_per_token_cast_to_fp8(padded_x[i])
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
return x_fp8
def main(M=8192, N=8192, BG=2, blk_m=8):
if dtype == "float":
x = torch.randn(M, N, device="cuda", dtype=torch.float32)
elif dtype == "float16":
x = torch.randn(M, N, device="cuda", dtype=torch.float16)
elif dtype == "bfloat16":
x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
else:
raise ValueError(f"Unsupported dtype: {dtype}")
batch_sizes = torch.tensor([2048, 6144], device="cuda", dtype=torch.int32)
M_max = int(ceil_div(batch_sizes.max(), 128) * 128)
print("batch_sizes:", batch_sizes)
print("M_max:", M_max)
kernel = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m)
print(kernel.get_kernel_source())
# profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
x_fp8, x_amax = kernel(x, batch_sizes)
x_fp8_ref, x_amax_ref = ref_program(x, batch_sizes)
torch_assert_close(x_fp8.to(torch.float32), x_fp8_ref.to(torch.float32), rtol=0.01, atol=0.01)
torch_assert_close(x_amax, x_amax_ref, rtol=0.01, atol=0.01)
print("All checks pass.")
from tilelang.profiler import do_bench
def run_tilelang():
x_fp8_tilelang_, x_amax_tilelang_ = kernel(x, batch_sizes)
return x_fp8_tilelang_, x_amax_tilelang_
def run_torch():
x_fp8_torch_, x_amax_torch_ = ref_program(x, batch_sizes)
return x_fp8_torch_, x_amax_torch_
latency = do_bench(run_tilelang)
print("Tile-lang: {:.2f} ms".format(latency))
latency = do_bench(run_torch)
print("Torch: {:.2f} ms".format(latency))
if __name__ == "__main__":
main()
import torch
import tilelang
import tilelang.language as T
from typing import Tuple
from tilelang.utils.tensor import torch_assert_close
@tilelang.jit(out_idx=[1, 2])
def per_token_cast_to_fp8(M, N, blk_m):
dtype = "float"
group_size = 128
fp8_min = -448.0
fp8_max = 448.0
@T.prim_func
def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"),
X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)):
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by):
row = bx
row_g_id = by
y_local = T.alloc_fragment((blk_m, group_size), dtype)
y_amax_local = T.alloc_fragment((blk_m,), dtype)
y_s_local = T.alloc_fragment((blk_m,), dtype)
y_q_local = T.alloc_fragment((blk_m, group_size), dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
T.annotate_layout({
y_local:
T.Fragment(
y_local.shape,
forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
})
T.copy(
X[row * blk_m:(row + 1) * blk_m, row_g_id * group_size:(row_g_id + 1) * group_size],
y_local)
T.reduce_absmax(y_local, y_amax_local, dim=1)
for i in T.Parallel(blk_m):
y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
y_s_local[i] = y_amax_local[i] / fp8_max
for i, j in T.Parallel(blk_m, group_size):
y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max)
T.copy(y_q_local, y_q_local_fp8)
for i in T.Parallel(blk_m):
X_amax[row * blk_m + i, row_g_id] = y_s_local[i]
T.copy(
y_q_local_fp8, X_fp8[row * blk_m:(row + 1) * blk_m,
row_g_id * group_size:(row_g_id + 1) * group_size])
return per_token_cast
def ceil_div(x: int, y: int) -> int:
"""
Perform ceiling division of two integers.
Args:
x: the dividend.
y: the divisor.
Returns:
The result of the ceiling division.
"""
return (x + y - 1) // y
def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# this function don't support cpu tensor
assert x.dim() == 2
m, n = x.shape
new_n = ceil_div(n, 128) * 128
x_padded = torch.nn.functional.pad(x, (0, new_n - n))
x_view = x_padded.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
x_fp8 = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous()
return x_fp8, (x_amax / 448.0).view(m, -1)
def main(M=8192, N=8192, blk_m=8):
kernel = per_token_cast_to_fp8(M, N, blk_m)
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
x = torch.randn(M, N, device="cuda", dtype=torch.float32)
x_fp8, x_amax = kernel(x)
x_fp8_ref, x_amax_ref = ref_program(x)
print("x_fp8:", x_fp8, x_fp8.shape)
print("x_amax:", x_amax, x_amax.shape)
print("x_fp8_ref:", x_fp8_ref, x_fp8_ref.shape)
print("x_amax_ref:", x_amax_ref, x_amax_ref.shape)
torch_assert_close(x_fp8.to(torch.float32), x_fp8_ref.to(torch.float32), rtol=0.01, atol=0.01)
torch_assert_close(x_amax, x_amax_ref, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
latency = profiler.do_bench()
print("Tile-lang: {:.2f} ms".format(latency))
from tilelang.profiler import do_bench
from example_triton_cast_to_fp8 import per_token_group_quant_fp8
def run_triton():
x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(
x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False)
return x_fp8_triton_, x_amax_triton_
x_fp8_triton, x_amax_triton = run_triton()
latency = do_bench(run_triton)
print("Triton: {:.2f} ms".format(latency))
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/sgl-project/sglang/pull/2575
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
@triton.jit
def _per_token_group_quant_fp8(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
group_size,
# Num columns of y
y_num_columns,
y_row_stride,
# Avoid to divide zero
eps,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""
groups_per_row = y_num_columns // group_size
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
row = g_id // groups_per_row
row_g_id = g_id % groups_per_row
y_ptr += (row * y_row_stride) + (row_g_id * group_size)
y_q_ptr += g_id * group_size
y_s_ptr += g_id
cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < group_size
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
@triton.jit
def _per_token_group_quant_fp8_colmajor(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
group_size,
# Num columns of y
y_num_columns,
y_row_stride,
# Stride from one column to the next of y_s
y_s_col_stride,
# Avoid to divide zero
eps,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""
groups_per_row = y_num_columns // group_size
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
row = g_id // groups_per_row
row_g_id = g_id % groups_per_row
y_ptr += (row * y_row_stride) + (row_g_id * group_size)
y_q_ptr += g_id * group_size
# Convert g_id the flattened block coordinate to 2D so we can index
# into the output y_scales matrix
blocks_per_row = y_num_columns // group_size
scale_col = g_id % blocks_per_row
scale_row = g_id // blocks_per_row
y_s_ptr += scale_col * y_s_col_stride + scale_row
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
mask = cols < group_size
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
assert (x.shape[-1] %
group_size == 0), (f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}")
assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
if column_major_scales:
shape = (x.shape[-1] // group_size,) + x.shape[:-1]
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
else:
shape = x.shape[:-1] + (x.shape[-1] // group_size,)
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
if column_major_scales:
_per_token_group_quant_fp8_colmajor[(M,)](
x,
x_q,
x_s,
group_size,
x.shape[1],
x.stride(0),
x_s.stride(1),
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
else:
_per_token_group_quant_fp8[(M,)](
x,
x_q,
x_s,
group_size,
x.shape[1],
x.stride(0),
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
return x_q, x_s
import tilelang.testing
import example_group_per_split_token_cast_to_fp8
import example_per_token_cast_to_fp8
def test_example_group_per_split_token_cast_to_fp8():
example_group_per_split_token_cast_to_fp8.main(M=8192, N=2048, BG=2, blk_m=8)
def test_example_per_token_cast_to_fp8():
example_per_token_cast_to_fp8.main(M=8192, N=2048, blk_m=8)
if __name__ == "__main__":
tilelang.testing.main()
import tilelang
import tilelang.language as T
# @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[ko * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
M = 1024
N = 1024
K = 1024
block_M = 128
block_N = 128
block_K = 32
func = matmul(M, N, K, block_M, block_N, block_K)
jit_kernel = tilelang.compile(
func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr")
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"])
import torch
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = jit_kernel(a, b)
print(c)
ref_c = a @ b
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
import os
import random
import pytest
os.environ["PYTHONHASHSEED"] = "0"
random.seed(0)
try:
import torch
except ImportError:
pass
else:
torch.manual_seed(0)
try:
import numpy as np
except ImportError:
pass
else:
np.random.seed(0)
def pytest_terminal_summary(terminalreporter, exitstatus, config):
"""Ensure that at least one test is collected. Error out if all tests are skipped."""
known_types = {
"failed",
"passed",
"skipped",
"deselected",
"xfailed",
"xpassed",
"warnings",
"error",
}
if (sum(
len(terminalreporter.stats.get(k, []))
for k in known_types.difference({"skipped", "deselected"})) == 0):
terminalreporter.write_sep(
"!",
(f"Error: No tests were collected. "
f"{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
)
pytest.exit("No tests were collected.", returncode=5)
import torch
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import argparse
def check_hopper():
if not torch.cuda.is_available():
return None
props = torch.cuda.get_device_properties(0)
compute_capability = props.major, props.minor
return compute_capability == (9, 0)
def ref_program(stride, padding, dilation):
def main(A, B):
A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W
C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation)
C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C
return C
return main
@tilelang.jit(out_idx=[2])
def convolution(N,
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
threads,
dtype="float16",
accum_dtype="float"):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
dtype = "float16"
accum_dtype = "float"
is_hopper = check_hopper()
@T.prim_func
def main(
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
out_shared = T.alloc_shared((block_M, block_N), dtype)
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout({
out_shared: tilelang.layout.make_swizzled_layout(out_shared),
data_shared: tilelang.layout.make_swizzled_layout(data_shared),
kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
})
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
if is_hopper:
T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
else:
for i, j in T.Parallel(block_M, block_K):
k = k_iter * block_K + j
m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
(access_w < W))
data_shared[i, j] = T.if_then_else(
in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
T.copy(out_local, out_shared)
T.copy(out_shared, out_flat[by * block_M, bx * block_N])
return main
def main(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument('--n', type=int, default=128, help='n')
parser.add_argument('--c', type=int, default=128, help='c')
parser.add_argument('--h', type=int, default=64, help='h')
parser.add_argument('--w', type=int, default=64, help='w')
parser.add_argument('--f', type=int, default=128, help='f')
parser.add_argument('--k', type=int, default=3, help='k')
parser.add_argument('--s', type=int, default=1, help='s')
parser.add_argument('--d', type=int, default=1, help='d')
parser.add_argument('--p', type=int, default=1, help='p')
args = parser.parse_args(argv)
N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p
a = torch.randn(N, H, W, C).cuda().half()
b = torch.randn(K, K, C, F).cuda().half()
block_m = 64
block_n = 128
block_k = 32
num_stages = 3
threads = 256
kernel = convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k, num_stages, threads)
out_c = kernel(a, b)
ref_c = ref_program(S, P, D)(a, b)
torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
if __name__ == "__main__":
main()
import torch
import argparse
import itertools
import tilelang
import tilelang.language as T
def check_hopper():
if not torch.cuda.is_available():
return None
props = torch.cuda.get_device_properties(0)
compute_capability = props.major, props.minor
return compute_capability == (9, 0)
def ref_program(stride, padding, dilation):
def main(A, B):
A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W
C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation)
C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C
return C
return main
def get_configs():
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat
} for c in _configs
]
return configs
def get_heuristic_config() -> dict:
# Get CUDA device properties
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
device = torch.cuda.current_device()
sm_major, sm_minor = torch.cuda.get_device_capability(device)
sm_version = sm_major * 10 + sm_minor
print(f"CUDA device capability: {sm_version}")
if sm_version in {80}:
return {
"block_M": 128,
"block_N": 256,
"block_K": 32,
"num_stages": 2,
"thread_num": 128,
"enable_rasteration": True
}
elif sm_version in {90}:
return {
"block_M": 128,
"block_N": 256,
"block_K": 64,
"num_stages": 3,
"thread_num": 256,
"enable_rasteration": True
}
else:
return {
"block_M": 128,
"block_N": 256,
"block_K": 32,
"num_stages": 0,
"thread_num": 128,
"enable_rasteration": True
}
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[2])
def convolution(N,
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float"):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
dtype = "float16"
accum_dtype = "float"
is_hopper = check_hopper()
@T.prim_func
def main(
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=thread_num) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
out_shared = T.alloc_shared((block_M, block_N), dtype)
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
if is_hopper:
T.annotate_layout({
out_shared: tilelang.layout.make_swizzled_layout(out_shared),
})
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
if is_hopper:
T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
else:
for i, j in T.Parallel(block_M, block_K):
k = k_iter * block_K + j
m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
(access_w < W))
data_shared[i, j] = T.if_then_else(
in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
if is_hopper:
T.copy(out_local, out_shared)
T.copy(out_shared, out_flat[by * block_M, bx * block_N])
else:
T.copy(out_local, out_flat[by * block_M, bx * block_N])
return main
def main(n: int = 128,
c: int = 128,
h: int = 64,
w: int = 64,
f: int = 128,
k: int = 3,
s: int = 1,
d: int = 1,
p: int = 1,
use_autotune: bool = False,
with_roller: bool = True):
N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p
ref_prog = ref_program(S, P, D)
if use_autotune:
kernel = convolution(N, C, H, W, F, K, S, D, P)
else:
config = get_heuristic_config()
kernel = convolution(N, C, H, W, F, K, S, D, P, **config)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
tilelang_latency = profiler.do_bench()
ref_latency = profiler.do_bench(ref_prog)
profiler.assert_allclose(ref_prog, atol=1e-2, rtol=1e-2)
print(f"TileLang latency: {tilelang_latency}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument('--n', type=int, default=128, help='n')
parser.add_argument('--c', type=int, default=128, help='c')
parser.add_argument('--h', type=int, default=64, help='h')
parser.add_argument('--w', type=int, default=64, help='w')
parser.add_argument('--f', type=int, default=128, help='f')
parser.add_argument('--k', type=int, default=3, help='k')
parser.add_argument('--s', type=int, default=1, help='s')
parser.add_argument('--d', type=int, default=1, help='d')
parser.add_argument('--p', type=int, default=1, help='p')
parser.add_argument(
"--use_autotune",
action="store_true",
default=False,
help="Whether to use autotune for matmul configs")
parser.add_argument(
"--with_roller",
action="store_true",
default=True,
help="Whether to enable BitBLAS roller for search space")
args = parser.parse_args()
main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune,
args.with_roller)
import tilelang.testing
import example_convolution
import example_convolution_autotune
# TODO(@cy): TMA with convolution must be fixed in future.
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_example_convolution():
example_convolution.main([])
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_example_convolution_autotune():
example_convolution_autotune.main()
if __name__ == "__main__":
tilelang.testing.main()
from typing import Tuple
import torch
import tilelang.testing
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type
tilelang.testing.set_random_seed(42)
@tilelang.jit
def tl_gemm(
M,
N,
K,
block_N,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"float8_e4m3",
], "Currently only float8_e4m3 is supported"
assert out_dtype in [
"bfloat16",
"float32",
], "Currently only float16 and float32 are supported"
group_size = 128
block_M = 128
block_K = 128
A_shape = (M, K)
Scales_A_shape = (M, T.ceildiv(K, group_size))
B_shape = (N, K)
Scales_B_shape = (T.ceildiv(N, group_size), T.ceildiv(K, group_size))
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (block_M, block_N)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
scales_a: T.Tensor(Scales_A_shape, "float32"),
scales_b: T.Tensor(Scales_B_shape, "float32"),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_shared = T.alloc_shared(C_shared_shape, out_dtype)
Scale_C_shared = T.alloc_shared((block_M), "float32")
C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
T.clear(C_local_accum)
K_iters = T.ceildiv(K, block_K)
for k in T.Pipelined(K_iters, num_stages=4):
# Load A into shared memory
T.copy(A[by * block_M, k * block_K], A_shared)
# Load B into shared memory
T.copy(B[bx * block_N, k * block_K], B_shared)
# Load scale into shared memory
Scale_B = scales_b[bx * block_N // group_size, k]
for i in T.Parallel(block_M):
Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
# Promote to enable 2xAcc
for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
T.clear(C_local)
# TMA store
T.copy(C_local_accum, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
def ceildiv(a, b):
return (a + b - 1) // b
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
m, n), (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
x_view.size(0), x_view.size(2))
def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
# A_scale: (M, K//128) ==> (M//128, K//128, 128)
# B_scale: (N//128, K//128) ==> (N//128, K//128, 128)
# A_fp8: (M, K)
# B_fp8: (N, K)
# out_dtype: float16 or float32
# return C: (M, N)
M, N, K = A_fp8.shape[0], B_fp8.shape[0], A_fp8.shape[1]
A_scales = A_scale.view(M // 128, 128, K // 128).permute(0, 2, 1)
B_scales = B_scale.repeat_interleave(128, dim=1).view(N // 128, K // 128, 128)
C = torch.zeros(M, N, device="cuda", dtype=out_dtype)
c_acc = torch.zeros(128, 128, device="cuda", dtype=torch.float32)
for i in range(ceildiv(M, 128)):
for j in range(ceildiv(N, 128)):
c_acc.zero_()
for k in range(ceildiv(K, 128)):
c = torch._scaled_mm(
A_fp8[i * 128:(i + 1) * 128, k * 128:(k + 1) * 128],
B_fp8[j * 128:(j + 1) * 128, k * 128:(k + 1) * 128].T,
scale_a=A_scales[i, k].view(128, 1).contiguous(),
scale_b=B_scales[j, k].view(1, 128).contiguous(),
out_dtype=torch.bfloat16)
c_acc += c.to(torch.float32)
C[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = c_acc.to(out_dtype)
return C
def calc_diff(x, y):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtype):
kernel = tl_gemm(M, N, K, block_N, in_dtype, out_dtype, accum_dtype)
src_code = kernel.get_kernel_source()
# src_code is the generated cuda source
assert src_code is not None
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
accum_dtype = map_torch_type(accum_dtype)
A = torch.randn(M, K).to(torch.bfloat16).cuda()
B = torch.randn(N, K).to(torch.bfloat16).cuda()
A_fp8, A_scale = per_token_cast_to_fp8(A.clone())
B_fp8, B_scale = per_block_cast_to_fp8(B.clone())
C = torch.zeros(M, N, device="cuda", dtype=out_dtype)
kernel(A_fp8, B_fp8, C, A_scale, B_scale)
# Get Reference Result
ref_c = ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype)
diff = calc_diff(C, ref_c)
print(f"diff: {diff}")
assert diff < 1e-3
profiler = kernel.get_profiler()
latency = profiler.do_bench(warmup=25)
# Ensure that the latency is not None
assert latency is not None
print(f"latency: {latency} ms")
tflops = 2 * M * N * K / latency / 1e9
print(f"tflops: {tflops}")
def main():
assert_tl_gemm_correctness(1024, 1024, 8192, 128, "float8_e4m3", "bfloat16", "float32")
if __name__ == "__main__":
for dtype in ["float8_e4m3"]:
for out_dtype in ["bfloat16", "float32"]:
for block_N in [16, 32, 64, 128]:
assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, "float32")
import tilelang.testing
from example_deepgemm_fp8_2xAcc import main
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_deepgemm_fp8_2xAcc():
main()
if __name__ == "__main__":
tilelang.testing.main()
# 🚀 How to write high-performance kernel with TileLang: take MLA as an example
TileLang is a user-friendly AI programming language that significantly lowers the barrier to kernel programming, helping users quickly build customized operators. However, users still need to master certain programming techniques to better leverage TileLang's powerful capabilities. Here, we'll use MLA as an example to demonstrate how to write high-performance kernels with TileLang.
## Introduction to MLA
DeepSeek's MLA (Multi-Head Latent Attention) is a novel attention mechanism known for its hardware efficiency and significant improvements in model inference speed. Several deep learning compilers (such as [Triton](https://github.com/triton-lang/triton)) and libraries (such as [FlashInfer](https://github.com/flashinfer-ai/flashinfer)) have developed their own implementations of MLA. In February 2025, [FlashMLA](https://github.com/deepseek-ai/FlashMLA) was open-sourced on GitHub. FlashMLA utilizes [CUTLASS](https://github.com/NVIDIA/cutlass) templates and incorporates optimization techniques from [FlashAttention](https://github.com/Dao-AILab/flash-attention), achieving impressive performance.
## Benchmark Results
We benchmarked the performance of FlashMLA, TileLang, Torch, Triton, and FlashInfer under batch sizes of 64 and 128, with float16 data type, as shown in the figures below.
<figure style="text-align: center">
<a href="./figures/bs64_float16.png">
<img src="./figures/bs64_float16.png" alt="bs64_float16">
</a>
<figcaption style="text-align: center;">Figure 1:Performance under batch size=64</figcaption>
</figure>
<figure style="text-align: center">
<a href="./figures/bs128_float16.png">
<img src="./figures/bs128_float16.png" alt="bs128_float16">
</a>
<figcaption style="text-align: center;">Figure 2:Performance under batch size=128</figcaption>
</figure>
As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton.
Notably, **TileLang accomplishes this with just around 80 lines of Python code**, demonstrating its exceptional ease of use and efficiency. Let's dive in and see how TileLang achieves this.
## Implementation
First, let's review the core computation logic of traditional FlashAttention:
```python
# acc_s: [block_M, block_N]
# scores_max: [block_M]
# scores_scale: [block_M]
# acc_o: [block_M, dim]
for i in range(loop_range):
acc_s = Q @ K[i]
scores_max_prev = scores_max
scores_max = max(acc_s, dim=1)
scores_scale = exp(scores_max_prev - scores_max)
acc_o *= scores_scale
acc_s = exp(acc_s - scores_max)
acc_o = acc_s @ V[i]
...
```
Here, `acc_s` represents the `Q @ K` result in each iteration with dimensions `[block_M, block_N]`, while `acc_o` represents the current iteration's output with dimensions `[block_M, dim]`. Both `acc_s` and `acc_o` need to be stored in registers to reduce latency.
Compared to traditional attention operators like MHA (Multi-Headed Attention) or GQA (Grouped Query Attention), a major challenge in optimizing MLA is its large head dimensions - `query` and `key` have head dimensions of 576 (512 + 64), while `value` has a head dimension of 512. This raises a significant issue: `acc_o` becomes too large, and with insufficient threads (e.g., 128 threads), register spilling occurs, severely impacting performance.
This raises the question of how to partition the matrix multiplication operation. On the Hopper architecture, most computation kernels use [`wgmma.mma_async`](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) instructions for optimal performance. The `wgmma.mma_async` instruction organizes 4 warps (128 threads) into a warpgroup for collective MMA operations. However, `wgmma.mma_async` instructions require a minimum M dimension of 64. This means each warpgroup's minimum M dimension can only be reduced to 64, but a tile size of 64*512 is too large for a single warpgroup, leading to register spilling.
Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right part of `acc_o` respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input.
Our solution is to have each warpgroup compute half of `acc_s` during `Q @ K` computation, then obtain the other half computed by the other warpgroup through shared memory.
### Layout Inference
While the above process may seem complex, but don't worry - TileLang will handle all these intricacies for you.
Figure 3 and Figure 4 illustrate the frontend TileLang script and its corresponding execution plan for MLA. Here, `T.gemm` represents matrix multiplication operations, `transpose_B=True` indicates transposition of matrix B, and `policy=FullCol` specifies that each warpgroup computes one column (e.g., split the result matrix in vertical dimension). `T.copy` represents buffer-to-buffer copying operations.
<figure style="text-align: center">
<a href="./figures/qk_layout.jpg">
<img src="./figures/qk_layout.jpg" alt="QK Layout">
</a>
<figcaption style="text-align: center;">Figure 3:Buffer shapes in Q @ K</figcaption>
</figure>
<figure style="text-align: center">
<a href="./figures/pv_layout.jpg">
<img src="./figures/pv_layout.jpg" alt="PV Layout">
</a>
<figcaption style="text-align: center;">Figure 4:Buffer shapes in acc_s @ V</figcaption>
</figure>
The mapping from TileLang frontend code to execution plan is accomplished through Layout Inference. Layout inference is a core optimization technique in TileLang. It automatically deduces the required buffer shapes and optimal layouts based on Tile-Operators (like `T.gemm`, `T.copy`, etc.), then generates the corresponding code. Here, we demonstrate a concrete example of buffer shape inference in MLA.
For instance, when computing `Q @ K`, TileLang infers that each warpgroup's `acc_s_0` shape should be `[blockM, blockN / 2]` based on the `policy=FullCol` annotation in `T.gemm`. Since this is followed by an `acc_s @ V` operation with `policy=FullCol`, which requires each warpgroup to have the complete `acc_s` result, TileLang deduces that `acc_s`'s shape at this point should be `[blockM, blockN]`. Consequently, TileLang can continue the inference process forward, determining that both `S_shared` and `acc_s` in `T.copy(S_shared, acc_s)` should have shapes of `[blockM, blockN]`.
It's worth noting that our scheduling approach differs from FlashMLA's implementation strategy. In FlashMLA, `Q @ K` is assigned to a single warpgroup, while the `acc_o` partitioning scheme remains consistent with ours. Nevertheless, our scheduling approach still achieves comparable performance.
### Threadblock Swizzling
Threadblock swizzling is a common performance optimization technique in GPU kernel optimization. In GPU architecture, the L2 cache is a high-speed cache shared among multiple SMs (Streaming Multiprocessors). Threadblock swizzling optimizes data access patterns by remapping the scheduling order of threadblocks, thereby improving L2 cache hit rates. Traditional scheduling typically executes threadblocks in the natural order of the grid, which can lead to non-contiguous data access patterns between adjacent threadblocks, resulting in inefficient utilization of cached data. The swizzle technique employs mathematical mapping methods (such as diagonal or interleaved mapping) to adjust the execution order of threadblocks, ensuring that consecutively scheduled threadblocks access adjacent or overlapping data regions.
In TileLang, threadblock swizzling optimization can be implemented with just a single line of Python code:
```python
T.use_swizzle(panel_size: int, order: str = "row")
```
Here, `panel_size` specifies the width of the swizzled threadblock group, and `order` determines the swizzling pattern, which can be either "row" or "col".
### Shared Memory Swizzling
In CUDA programming, shared memory is divided into multiple memory banks, with each bank capable of servicing one thread request per clock cycle in parallel. Bank conflicts occur when multiple threads simultaneously access different addresses mapped to the same bank, forcing these accesses to be serialized and degrading performance.
One common strategy to address bank conflicts is shared memory swizzling. This technique rearranges how data is stored in shared memory by remapping addresses that would originally fall into the same bank to different banks, thereby reducing conflicts. For example, XOR operations or other bit manipulations can be incorporated into address calculations to alter the data layout, resulting in more evenly distributed memory accesses across consecutive threads. This approach is particularly crucial for implementing high-performance computing tasks like matrix multiplication and convolution, as it can significantly improve memory access parallelism and overall execution efficiency.
Similarly, TileLang also supports shared memory swizzling. Users only need to add a single line of Python code:
```python
T.annotate_layout({
S_shared: TileLang.layout.make_swizzled_layout(S_shared),
})
```
Here, `T.annotate_layout` allows users to specify any desired layout for a buffer. For convenience, TileLang provides the `make_swizzled_layout` primitive to automatically generate a swizzled layout.
### Warp-Specialization
The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (Tensor Memory Accelerator), while the remaining warpgroups serve as consumers performing computations. However, this programming pattern is complex, requiring developers to manually manage the execution logic for producers and consumers, including synchronization through the `mbarrier` objects.
In TileLang, users are completely shielded from these implementation details. The frontend script is automatically transformed into a warp-specialized form, where TileLang handles all producer-consumer synchronization automatically, enabling efficient computation.
### Pipeline
Pipeline is a technique used to improve memory access efficiency by overlapping memory access and computation. In TileLang, pipeline can be implemented through the `T.pipelined` annotation:
```python
T.pipelined(range: int, stage: int)
```
Here, `range` specifies the range of the pipeline, and `stage` specifies the stage of the pipeline. Multi-stage pipelining enables overlapping of computation and memory access, which can significantly improve performance for memory-intensive operators. However, setting a higher number of stages consumes more shared memory resources, so the optimal configuration needs to be determined based on specific use cases.
### Split-KV
We have also implemented Split-KV optimization similar to [FlashDecoding](https://pytorch.org/blog/flash-decoding/). Specifically, when the batch size is small, parallel SM resources cannot be fully utilized due to low parallelism. In such cases, we can split the kv_ctx dimension across multiple SMs for parallel computation and then merge the results.
In our implementation, we have developed both split and combine kernels, allowing users to control the split size through a `num_split` parameter.
\ No newline at end of file
# 🚀 High-Performance FlashMLA Implementation Using TileLang on AMD MI300X Accelerators
Following our previous demonstration of [high-performance FlashMLA implementation on NVIDIA Hopper architectures using TileLang](https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_mla/README.md), this work presents an optimized implementation for AMD MI300X accelerators. We examine architectural differences and corresponding optimization strategies between these platforms.
## Architectural Considerations and Optimization Strategies
Key implementation differences between Hopper and MI300X architectures include:
1. **Instruction Set Variations**: The MI300X architecture eliminates the need for explicit Tensor Memory Access (TMA) instructions and warp specialization, which are automatically handled by the compiler on Hopper architectures, resulting in identical source code manifestations.
2. **Shared Memory Constraints**: With 64KB of shared memory compared to Hopper's 228KB, MI300X implementations require careful memory management. Our optimization strategy includes:
- Reducing software pipeline stages
- Register-based caching of Q matrices instead of shared memory utilization:
```python
# Original shared memory allocation
Q_shared = T.alloc_shared([block_H, dim], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
# Optimized register allocation
Q_local = T.alloc_fragment([block_H, dim], dtype)
Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype)
```
3. **Tile Size Flexibility**: The absence of WGMMA instructions on MI300X permits more flexible tile size selection, removing the requirement for block_m to be multiples of 64.
4. **Memory Bank Conflict Swizzling**: MI300x has different memory bank conflict rules compared to NVIDIA, so we need to use different swizzling strategies. This is also automatically handled by TileLang, resulting in no visible differences in the code.
## Performance Evaluation
We conducted comparative performance analysis across multiple frameworks using float16 precision with batch sizes 64 and 128. The experimental results demonstrate:
<figure style="text-align: center">
<a href="../figures/flashmla-amd.png">
<img src="../figures/flashmla-amd.png" alt="AMD FlashMLA Performance Comparison">
</a>
<figcaption style="text-align: center;">Figure 1: Computational throughput comparison across frameworks (Batch sizes 64 and 128)</figcaption>
</figure>
Notably, TileLang achieves performance parity with hand-optimized assembly kernels (aiter-asm) (from 0.73x to 1.21x) in most test cases, while significantly outperforming Triton (up to 6.5x faster)implementations. This performance is achieved through a concise 70-line Python implementation!
## Future Optimization Opportunities
1. **Memory Bank Conflict Mitigation**: Current implementations primarily address bank conflicts in NT layouts through TileLang's automatic optimization. Further investigation of swizzling techniques for alternative memory layouts remains an open research direction.
2. **Dimension Parallelization**: For large MLA dimensions (e.g., 576 elements), we propose investigating head dimension partitioning strategies to:
- Reduce shared memory pressure
- Improve compute-to-memory access ratios
- Enhance parallelism through dimension-wise task distribution
## Acknowledgment
We would like to express our sincere gratitude to the AMD ROCm and Composable Kernel team for their outstanding contributions. We have learned a great deal from the ROCm software stack.
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from einops import rearrange, einsum
import argparse
def get_configs():
import itertools
BLOCK_N = [16, 32, 64, 128]
BLOCK_H = [16, 32, 64, 128]
num_split = [1, 2, 4, 8, 16, 32]
threads = [128, 256]
_configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, threads))
return [{
"block_N": c[0],
"block_H": c[1],
"num_split": c[2],
"threads": c[3],
} for c in _configs]
@tilelang.autotune(configs=get_configs())
@tilelang.jit(
out_idx=[6], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashmla_decode(batch,
heads,
kv_head_num,
seqlen_kv,
dim,
pe_dim,
block_N,
block_H,
num_split,
threads=128):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // kv_head_num
VALID_BLOCK_H = min(block_H, kv_group_num)
assert kv_head_num == 1, "kv_head_num must be 1"
@T.macro
def flash_attn(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=threads) as (bx, by):
Q_local = T.alloc_fragment([block_H, dim], dtype)
Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype)
KV_shared = T.alloc_shared([block_N, dim], dtype)
K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_local)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_local)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=0):
T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.gemm(
Q_pe_local,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
# T.copy(acc_s, S_shared)
T.copy(acc_s, acc_s_cast)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :])
@T.macro
def flash_attn_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
):
with T.Kernel(
batch, heads // min(block_H, kv_group_num), num_split,
threads=threads) as (bx, by, bz):
Q_local = T.alloc_fragment([block_H, dim], dtype)
Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype)
KV_shared = T.alloc_shared([block_N, dim], dtype)
K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_local)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_local)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=0):
kv_start = (seqlen_kv // num_split) * bz + k * block_N
kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N
T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared)
T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.gemm(
Q_pe_local,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
T.copy(acc_s, acc_s_cast)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz])
T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :])
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim], dtype)
o_accum_local = T.alloc_fragment([dim], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
})
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dim):
Output[bz, by, i] = o_accum_local[i]
@T.prim_func
def main_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
combine(glse, Output_partial, Output)
@T.prim_func
def main_no_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn(Q, Q_pe, KV, K_pe, Output)
if num_split > 1:
return main_split
else:
return main_no_split
def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
# """
# Inputs:
# - q (Tensor): [batch, heads, dim]
# - q_pe (Tensor): [batch, heads, pe_dim]
# - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
# - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
# - glse (Tensor): [batch, heads, num_split]
# - Output_partial (Tensor): [batch, heads, num_split, dim]
# Outputs:
# - output (Tensor): [batch, heads, dim]
# """
dim = q.shape[-1]
pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim)**0.5
q = rearrange(
q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
q_pe = rearrange(
q_pe, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim]
k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim]
query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1)
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
parser.add_argument('--autotune', action='store_true', help='auto tune')
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
enable_autotune = args.autotune
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
BLOCK_N = 32
BLOCK_H = 64
num_split = 4
threads = 128
if enable_autotune:
kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
else:
kernel = flashmla_decode(
batch,
heads,
kv_heads,
kv_ctx,
dim,
pe_dim,
BLOCK_N,
BLOCK_H,
num_split,
threads=threads)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
input_tensors = profiler._get_inputs()
tilelang_output = kernel(*input_tensors)
ref_output = ref_program(*input_tensors)
print(f"Tilelang output: {tilelang_output}")
print(f"Ref output: {ref_output}")
torch.testing.assert_close(tilelang_output, ref_output, rtol=0.01, atol=0.01)
latency = profiler.do_bench(warmup=500)
print(f"Latency: {latency} ms")
print(f"TFlops: {total_flops / latency * 1e-9} TFlops")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment