Unverified Commit bbbf4207 authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
......@@ -40,9 +40,9 @@ def get_heuristic_config() -> Tuple[Dict, int]:
sm_version = sm_major * 10 + sm_minor
print(f"CUDA device capability: {sm_version}")
if sm_version == 89:
cfg = dict(block_N=128, block_H=64, num_split=16, num_stages=0, threads=128)
cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=0, threads=128)
else:
cfg = dict(block_N=128, block_H=64, num_split=16, num_stages=2, threads=128)
cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=2, threads=128)
return cfg, sm_version
......@@ -459,8 +459,9 @@ def main(batch: int = 1,
k = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
mask = torch.randint(0, 2, (batch, kv_seqlen, groups), device="cuda", dtype=torch.uint8)
glse = torch.empty(batch, heads, 16, device="cuda", dtype=torch.float16)
Output_partial = torch.empty(batch, heads, 16, dim, device="cuda", dtype=torch.float16)
split = config["num_split"]
glse = torch.empty(batch, heads, split, device="cuda", dtype=torch.float16)
Output_partial = torch.empty(batch, heads, split, dim, device="cuda", dtype=torch.float16)
o = kernel(q, k, v, mask, glse, Output_partial)
o_ref = ref_program(q, k, v, mask, glse, Output_partial)
o_ref_split = ref_split_program(q, k, v, mask, glse, Output_partial)
......
import torch
import triton
import triton.language as tl
import math
import argparse
import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune
torch.manual_seed(0)
tilelang.disable_cache()
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen,
head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
@triton.jit
def _fwd_inner(
q,
k_ptrs,
v_ptrs,
s_ptrs,
m_i,
l_i,
acc,
offs_h,
mask_h,
offs_n,
seqlen,
softmax_scale,
lo,
hi,
stride_kt,
stride_vt,
stride_sh,
stride_sn,
BLOCK_N: tl.constexpr,
):
"""Inner loop computation for attention"""
for blk_idx in tl.range(lo, hi):
start_n = blk_idx * BLOCK_N
k = tl.load(k_ptrs + start_n * stride_kt, mask=offs_n[None, :] + start_n < seqlen)
v = tl.load(v_ptrs + start_n * stride_vt, mask=offs_n[:, None] + start_n < seqlen)
qk = tl.dot(q, k)
qk *= softmax_scale
qk += tl.where(offs_n[None, :] + start_n < seqlen, 0, -1.0e9)
row_max = tl.max(qk, 1)
tl.store(s_ptrs + offs_h * stride_sh + blk_idx * stride_sn, row_max, mask=mask_h)
m_ij = tl.maximum(m_i, row_max)
qk -= m_ij[:, None]
p = tl.math.exp(qk)
l_ij = tl.sum(p, 1)
alpha = tl.math.exp(m_i - m_ij)
l_i = l_i * alpha + l_ij
m_i = m_ij
acc *= alpha[:, None]
p = p.to(v.type.element_ty)
acc += tl.dot(p, v)
return m_i, l_i, acc
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [4, 8]\
for num_stages in [2, 4]\
],
key=['gqa_group_size', 'BLOCK_N', 'BLOCK_D', 'BLOCK_H'],
)
@triton.jit
def _fwd_kernel_varlen(
Q, # [token_q = b, h_q, dim]
K, # [token_k, h_kv, dim]
V,
O,
S,
s_aux,
softmax_scale,
cu_seqlens_k,
stride_qt,
stride_qh,
stride_qd,
stride_kt,
stride_kh,
stride_kd,
stride_vt,
stride_vh,
stride_vd,
stride_ot,
stride_oh,
stride_od,
stride_sb,
stride_sh,
stride_sn, #bmask shape [b, q_h, seq/BLOCK_N]
gqa_group_size: tl.constexpr,
BLOCK_H: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr,
):
off_z = tl.program_id(0)
off_h_for_kv = tl.program_id(1)
off_h_q = off_h_for_kv * gqa_group_size
cu_k_start = tl.load(cu_seqlens_k + off_z)
cu_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_k_end - cu_k_start
offs_h = tl.arange(0, BLOCK_H)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_D)
Q_ptrs = Q + off_z * stride_qt + off_h_q * stride_qh
K_ptrs = K + (cu_k_start) * stride_kt + off_h_for_kv * stride_kh
V_ptrs = V + (cu_k_start) * stride_vt + off_h_for_kv * stride_vh
O_ptrs = O + off_z * stride_ot + off_h_q * stride_oh
S_ptrs = S + off_z * stride_sb + off_h_q * stride_sh
mask_h = offs_h < gqa_group_size
q = tl.load(
Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None])
if s_aux is not None:
sink = tl.load(s_aux + off_h_q + offs_h, mask=mask_h).to(tl.float32)
l_i = tl.zeros([BLOCK_H], dtype=tl.float32)
m_i = tl.zeros([BLOCK_H], dtype=tl.float32) + sink
else:
l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32)
m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32)
k_ptrs = K_ptrs + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
v_ptrs = V_ptrs + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
lo, hi = 0, tl.cdiv(seqlen_k, BLOCK_N)
m_i, l_i, acc = _fwd_inner(
q,
k_ptrs,
v_ptrs,
S_ptrs,
m_i,
l_i,
acc,
offs_h,
mask_h,
offs_n,
seqlen_k,
softmax_scale,
lo,
hi,
stride_kt,
stride_vt,
stride_sh,
stride_sn,
BLOCK_N,
)
if s_aux is not None:
sink = tl.math.exp(sink - m_i)
l_i = l_i + sink
acc = acc / l_i[:, None]
else:
l_recip = 1 / l_i[:, None]
acc = acc * l_recip
for blk_idx in tl.range(lo, hi):
s = tl.load(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, mask=mask_h)
s = tl.exp(s - m_i) / l_i
tl.store(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, s, mask=mask_h)
acc = acc.to(O.dtype.element_ty)
tl.store(
O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od,
acc,
mask=mask_h[:, None])
def get_configs():
import itertools
block_N = [64, 128]
block_H = [64]
num_split = [1]
num_stages = [1, 2, 3]
threads = [128]
_configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
configs = [{
'block_N': c[0],
'block_H': c[1],
'num_split': c[2],
'num_stages': c[3],
'threads': c[4]
} for c in _configs]
return configs
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding")
def flashattn(batch,
heads,
k_heads,
max_seqlen_kv,
total_seqlen_k,
dim,
has_sink,
block_N=128,
block_H=64,
num_split=1,
num_stages=1,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, dim]
shape_k = [total_seqlen_k, k_heads, dim]
shape_v = [total_seqlen_k, k_heads, dim]
shape_o = [batch, heads, dim]
shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)]
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // k_heads
valid_block_H = min(block_H, kv_group_num)
# TODO: check if max_seqlen_kv is correct for varlen case
@T.macro
def flash_attn(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
cu_seqlens_k: T.Tensor([batch + 1], "int32"),
s_aux: T.Tensor([heads], "float32"),
Output: T.Tensor([batch, heads, dim], dtype),
S: T.Tensor(shape_s, dtype),
):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([valid_block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype)
# S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype)
s_aux_shared = T.alloc_shared([block_H], "float32")
T.annotate_layout({
# Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
# K_shared: tilelang.layout.make_swizzled_layout(K_shared),
# V_shared: tilelang.layout.make_swizzled_layout(V_shared),
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
# S_shared: tilelang.layout.make_swizzled_layout(S_shared),
})
bid = bx
hid = by
cur_kv_head = hid // (kv_group_num // valid_block_H)
cur_start_k = cu_seqlens_k[bid]
cur_end_k = cu_seqlens_k[bid + 1]
cur_seqlen_k = cur_end_k - cur_start_k
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
# loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K[cur_start_k + k * block_N:cur_start_k + (k + 1) * block_N, cur_kv_head, :],
K_shared)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
# acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j],
# -T.infinity(accum_dtype))
acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j],
-T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# scores_max_prev is m_i
# scores_max is row_max->m_ij in triton
T.copy(scores_max, S_shared[:, k])
# scores_scale is alpha in triton
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
# scores_sum is l_ij in triton
# logsum is l_i in triton
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.copy(V[cur_start_k + k * block_N:cur_start_k + (k + 1) * block_N, cur_kv_head, :],
V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_sink:
T.copy(s_aux[hid * valid_block_H:hid * valid_block_H + block_H], s_aux_shared)
for i in T.Parallel(block_H):
logsum[i] += s_aux_shared[i]
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)):
S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h]
# T.copy(S_shared, S_fragment)
# for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)):
# S_fragment[h, k] = T.exp2((S_fragment[h, k] - scores_max[h]) * scale) / logsum[h]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])
# T.copy(S_fragment, S_shared)
T.copy(S_shared[:valid_block_H, :], S[bid,
hid * valid_block_H:(hid + 1) * valid_block_H, :])
@T.prim_func
def flashattn_gqa_decode_no_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
cu_seqlens_k: T.Tensor([batch + 1], "int32"),
s_aux: T.Tensor([heads], "float32"),
Output: T.Tensor(shape_o, dtype),
S: T.Tensor(shape_s, dtype),
):
flash_attn(Q, K, V, cu_seqlens_k, s_aux, Output, S)
# TODO: split version
return flashattn_gqa_decode_no_split
def flash_attn_with_attn_pool_decode_tilelang(
Q: torch.Tensor, ## [tq = b, q_h, q_dim]
K: torch.Tensor, ## [tk, k_h, k_dim]
V: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_k: int,
real_max_k_seqlen: int,
num_split: int,
softmax_scale: float,
s_aux: torch.Tensor = None,
block_size: int = 64,
use_per_kv_head_sparse_index: bool = False,
tl_kernel=None,
):
num_tokens, q_h, head_size = Q.shape
batch = cu_seqlens_k.size(0) - 1
k_h = K.size(1)
assert Q.dim() == K.dim() == 3
assert Q.size(2) == K.size(2)
assert cu_seqlens_k.dim() == 1
assert head_size in {64, 128, 256}
assert Q.is_contiguous()
assert K.is_contiguous()
assert V.is_contiguous()
gqa_group_size = q_h // k_h
O_tl = torch.zeros_like(Q)
S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)),
dtype=Q.dtype,
device=Q.device)
O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux)
if use_per_kv_head_sparse_index:
S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1))
else:
S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1))
return O_tl, S_tl
def flash_attn_with_attn_pool_decode(
Q: torch.Tensor, ## [tq = b, q_h, q_dim]
K: torch.Tensor, ## [tk, k_h, k_dim]
V: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_k: int,
real_max_k_seqlen: int,
num_split: int,
softmax_scale: float,
s_aux: torch.Tensor = None,
block_size: int = 64,
use_per_kv_head_sparse_index: bool = False,
):
num_tokens, q_h, head_size = Q.shape
batch = cu_seqlens_k.size(0) - 1
k_h = K.size(1)
assert Q.dim() == K.dim() == 3
assert Q.size(2) == K.size(2)
assert cu_seqlens_k.dim() == 1
assert head_size in {64, 128, 256}
assert Q.is_contiguous()
assert K.is_contiguous()
assert V.is_contiguous()
gqa_group_size = q_h // k_h
BLOCK_D = head_size
BLOCK_N = block_size
BLOCK_H = 64
O = torch.zeros_like(Q)
S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)),
dtype=Q.dtype,
device=Q.device)
def grid(META):
return (batch, k_h)
with torch.cuda.device(Q.device.index):
_fwd_kernel_varlen[grid](
Q,
K,
V,
O,
S,
s_aux,
softmax_scale,
cu_seqlens_k,
*Q.stride(),
*K.stride(),
*V.stride(),
*O.stride(),
*S.stride(),
gqa_group_size,
BLOCK_H=BLOCK_H,
BLOCK_N=BLOCK_N,
BLOCK_D=BLOCK_D,
)
if use_per_kv_head_sparse_index:
S = torch.max_pool2d(S, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1))
else:
S = torch.max_pool2d(S, kernel_size=(q_h, 1), stride=(q_h, 1))
return O, S
def test_equal_seqlen_decode_main(args):
"""Test decode kernel with equal sequence lengths"""
print("Testing decode kernel with equal sequence lengths")
batch_size = args.batch_size
q_heads = args.q_heads
kv_heads = args.kv_heads
k_seqlen = args.k_seqlen
real_max_k_seqlen = args.k_seqlen
head_size = args.head_size
block_size = args.block_size
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
# For decode, query is just 1 token per batch
q = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype)
k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype)
v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)
# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}")
# Convert to varlen format for K, V
k_varlen = k.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size)
v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size)
# Generate cumulative sequence lengths
cu_seqlens_k = torch.arange(
0, (batch_size + 1) * k_seqlen, k_seqlen, device='cuda', dtype=torch.int32)
max_seqlen_k = k_seqlen
print(f"q shape: {q.shape}")
print(f"k_varlen shape: {k_varlen.shape}")
print(f"v_varlen shape: {v_varlen.shape}")
num_tokens, q_h, head_size = q.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size,
args.test_sink)
# Test our decode kernel
O_triton, S_triton = flash_attn_with_attn_pool_decode(
q,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
real_max_k_seqlen,
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
real_max_k_seqlen,
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size,
tl_kernel=tl_kernel,
)
for i in range(batch_size):
S_tilelang[i, :,
math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) /
block_size):] = 0
# Compute torch reference
q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size]
k_repeat = repeat_kv(k, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size]
v_repeat = repeat_kv(v, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size]
if sink is None:
# Standard scaled dot-product attention
logits = torch.matmul(q_expanded, k_repeat.transpose(
-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
attn_weights = torch.softmax(logits, dim=-1)
O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size]
else:
# s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose(
-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(logits_max, sink_expanded)
sinks = torch.exp(sink_expanded - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
attn_weights = unnormalized_scores / normalizer
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype),
v_repeat).squeeze(2) # [batch, q_heads, head_size]
# Compute attention score pooling
attn_score_pooled = torch.max_pool2d(
attn_weights.squeeze(2), # [b, q_heads, k_seqlen]
kernel_size=(q_heads, block_size),
stride=(q_heads, block_size),
ceil_mode=True).to(torch.float16)
print("S_tilelang", S_tilelang)
print("attn_score_pooled", attn_score_pooled)
max_diff_o = torch.max(torch.abs(O_triton - O_torch))
max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
max_diff_o_tilelang = torch.max(torch.abs(O_tilelang - O_torch))
max_diff_s_tilelang = torch.max(torch.abs(S_tilelang - attn_score_pooled))
print(f"Max difference in O: {max_diff_o.item()}")
print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}")
assert torch.allclose(
O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(
S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(
O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
assert torch.allclose(
S_tilelang, attn_score_pooled, atol=1e-2,
rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
print("✅ All tests passed!")
def test_varlen_decode_main(args):
"""Test decode kernel with variable sequence lengths"""
batch_size = args.batch_size
q_heads = args.q_heads
kv_heads = args.kv_heads
max_k_seqlen = args.k_seqlen # Use as max sequence length
real_max_k_seqlen = args.k_seqlen
head_size = args.head_size
block_size = args.block_size
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})")
# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}")
# Generate variable length k sequences
k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,))
print(f"k_seqlens: {k_seqlens}")
# Generate cumulative sequence lengths for k
cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32)
total_k_tokens = 0
for i in range(batch_size):
cu_seqlens_k[i] = total_k_tokens
total_k_tokens += k_seqlens[i]
cu_seqlens_k[batch_size] = total_k_tokens
print(f"cu_seqlens_k: {cu_seqlens_k}")
# Generate tensors - Q is [batch_size, q_heads, head_size] for decode
q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)
max_seqlen_k = int(k_seqlens.max())
print(f"Actual max_seqlen_k: {max_seqlen_k}")
print(f"q_decode shape: {q_decode.shape}")
print(f"k_varlen shape: {k_varlen.shape}")
print(f"v_varlen shape: {v_varlen.shape}")
num_tokens, q_h, head_size = q_decode.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size,
args.test_sink)
# Test our decode kernel
O_triton, S_triton = flash_attn_with_attn_pool_decode(
q_decode,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
real_max_k_seqlen,
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q_decode,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
real_max_k_seqlen,
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size,
tl_kernel=tl_kernel,
)
for i in range(batch_size):
S_tilelang[i, :,
math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) /
block_size):] = 0
# Create torch reference - pad tensors for comparison
k_padded_list = []
v_padded_list = []
for i in range(batch_size):
actual_k_len = k_seqlens[i]
# Extract and pad k, v for this batch
k_start = cu_seqlens_k[i]
k_end = cu_seqlens_k[i + 1]
# Pad to max_seqlen_k
k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype)
v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype)
k_padded[:actual_k_len] = k_varlen[k_start:k_end]
v_padded[:actual_k_len] = v_varlen[k_start:k_end]
k_padded_list.append(k_padded)
v_padded_list.append(v_padded)
# Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
k_padded_batched = torch.stack(
k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
v_padded_batched = torch.stack(
v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
# Expand q to match kv heads: [b, q_heads, 1, head_size]
q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size]
print(f"q_expanded shape: {q_expanded.shape}")
print(f"k_padded_batched shape: {k_padded_batched.shape}")
print(f"v_padded_batched shape: {v_padded_batched.shape}")
# Compute torch reference
k_repeat = repeat_kv(k_padded_batched,
q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
v_repeat = repeat_kv(v_padded_batched,
q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
if sink is None:
# Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
attn_score = torch.matmul(q_expanded, k_repeat.transpose(
-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for i in range(batch_size):
actual_k_len = k_seqlens[i]
attn_score[i, :, :, actual_k_len:] = float('-inf')
attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen]
# Mask out invalid positions
for i in range(batch_size):
actual_k_len = k_seqlens[i]
attn_weights[i, :, :, actual_k_len:] = 0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size]
else:
# s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose(
-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for i in range(batch_size):
actual_k_len = k_seqlens[i]
logits[i, :, :, actual_k_len:] = float('-inf')
sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(logits_max, sink_expanded)
sinks = torch.exp(sink_expanded - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
attn_weights = unnormalized_scores / normalizer
# Mask out invalid positions
for i in range(batch_size):
actual_k_len = k_seqlens[i]
attn_weights[i, :, :, actual_k_len:] = 0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype),
v_repeat) # [b, q_heads, 1, head_size]
O_torch = O_torch.squeeze(2) # [b, q_heads, head_size]
# Compute attention score pooling for S
attn_score_pooled = torch.max_pool2d(
attn_weights.squeeze(2), # [b, q_heads, max_seqlen]
kernel_size=(q_heads, block_size),
stride=(q_heads, block_size),
ceil_mode=True).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)]
print(f"O_triton shape: {O_triton.shape}")
print(f"O_tilelang shape: {O_tilelang.shape}")
print(f"O_torch shape: {O_torch.shape}")
print(f"S_triton shape: {S_triton.shape}")
print(f"S_tilelang shape: {S_tilelang.shape}")
print(f"attn_score_pooled shape: {attn_score_pooled.shape}")
# Compare results
max_diff_o = torch.max(torch.abs(O_triton - O_torch))
max_diff_o_tl = torch.max(torch.abs(O_tilelang - O_torch))
print(f"Max difference in O: {max_diff_o.item()}")
print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}")
max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
max_diff_s_tl = torch.max(
torch.abs(S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}")
assert torch.allclose(
O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(
S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(
O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}"
assert torch.allclose(
S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)],
attn_score_pooled,
atol=1e-2,
rtol=1e-2), f"Score mismatch: {max_diff_s_tl.item()}"
print("✅ All tests passed!")
def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
"""
Do benchmark for a function.
"""
start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
for _ in range(warmup):
fn(*args, **kwargs)
torch.cuda.synchronize()
for i in range(rep):
start_event[i].record()
fn(*args, **kwargs)
end_event[i].record()
torch.cuda.synchronize()
# Record clocks
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)],
dtype=torch.float,
)
return times.mean().item()
def speed_benchmark_decode_comparison(args):
"""Speed benchmark for decode kernel"""
batch_size = args.batch_size
q_heads = args.q_heads
kv_heads = args.kv_heads
max_k_seqlen = args.k_seqlen
head_size = args.head_size
block_size = args.block_size
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
print("\n=== Decode Speed Benchmark Comparison ===")
print("Configuration:")
print(f" Batch size: {batch_size}")
print(f" Q heads: {q_heads}, KV heads: {kv_heads}")
print(f" Max K sequence length: {max_k_seqlen}")
print(f" Head size: {head_size}")
print(f" Block size: {block_size}")
print(f" Data type: {dtype}")
print(f" Variable lengths: {args.test_varlen}")
print(f" s_aux attention: {args.test_sink}")
print()
# Generate input data
if args.test_varlen:
k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,))
else:
k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int)
# Generate cumulative sequence lengths for k
cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32)
total_k_tokens = 0
for i in range(batch_size):
cu_seqlens_k[i] = total_k_tokens
total_k_tokens += k_seqlens[i]
cu_seqlens_k[batch_size] = total_k_tokens
# Generate tensors
q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)
max_seqlen_k = int(k_seqlens.max())
# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values
print(" Using sink attention with sink values")
print("Setup complete:")
print(f" Total K tokens: {total_k_tokens}")
print(f" Actual max K seq len: {max_seqlen_k}")
if args.test_varlen:
print(f" K sequence lengths: {k_seqlens.tolist()}")
# Warmup
num_tokens, q_h, head_size = q_decode.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size,
args.test_sink)
# Benchmark
print("⚡ Benchmarking Tilelang kernel (100 iterations)...")
tilelang_time = do_bench(
flash_attn_with_attn_pool_decode_tilelang,
q_decode,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
args.k_seqlen,
1,
softmax_scale,
sink,
block_size,
False,
tl_kernel,
)
print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms")
# Benchmark
print("⚡ Benchmarking Triton kernel (100 iterations)...")
triton_time = do_bench(flash_attn_with_attn_pool_decode, q_decode, k_varlen, v_varlen,
cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale, sink,
block_size)
print(f"Average decode kernel time Triton: {triton_time:.3f} ms")
print(f"Speedup: {(triton_time / tilelang_time):.3f}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Flash Attention Decode with Attention Pooling')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
parser.add_argument('--q_heads', type=int, default=32, help='Number of query heads')
parser.add_argument('--kv_heads', type=int, default=8, help='Number of key-value heads')
parser.add_argument('--k_seqlen', type=int, default=8192, help='Key sequence length')
parser.add_argument(
'--head_size', type=int, default=128, choices=[64, 128, 256], help='Head dimension')
parser.add_argument('--block_size', type=int, default=64, help='Block size for computation')
parser.add_argument(
'--dtype', type=str, default='bfloat16', choices=['float16', 'bfloat16'], help='Data type')
parser.add_argument(
'--test_varlen', action='store_true', help='Test with truly variable sequence lengths')
parser.add_argument(
'--test_sink', action='store_true', help='Test with sink attention mechanism')
parser.add_argument('--benchmark', action='store_true', help='Run speed benchmark')
parser.add_argument(
'--num_split', type=int, default=1, choices=[1, 16], help='Number of splits')
args = parser.parse_args()
args.test_sink = True
args.test_varlen = False
args.dtype = 'float16'
args.num_split = 1
if args.benchmark:
speed_benchmark_decode_comparison(args)
elif args.test_varlen:
test_varlen_decode_main(args)
else:
test_equal_seqlen_decode_main(args)
......@@ -302,9 +302,7 @@ def flash_split_ref(Q, K, V, causal):
3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)
def main():
BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128
causal = False
def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False):
flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
total_flops = 2 * flops_per_matmul
if causal:
......
......@@ -12,7 +12,7 @@ def test_example_example_gqa_decode():
def test_example_example_mha_inference():
example_mha_inference.main()
example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False)
if __name__ == "__main__":
......
......@@ -7,8 +7,6 @@ import tilelang
import tilelang.language as T
from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401
print(tilelang.__file__)
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
......@@ -256,8 +254,9 @@ def tilelang_chunk_o_bwd_dqkwg(
# for i_kv in T.Parallel(block_DK * block_DV):
# dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
for i_kv in T.Parallel(block_DK * block_DV):
i_k, i_v = i_kv // block_DV, i_kv % block_DV
dg_last_fragment[i_kv] = h_shared[i_k, i_v] * dh_shared[i_k, i_v]
dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv %
block_DV] * dh_shared[i_kv // block_DV,
i_kv % block_DV]
T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
dg_last_local[0] += dg_last_fragment_scalar[0]
......
......@@ -4,20 +4,23 @@ TileLang is a domain-specific language designed to simplify the process of writi
## Table of Contents
1. [Getting Started](#getting-started)
2. [Simple GEMM Example](#simple-gemm-example)
- [Table of Contents](#table-of-contents)
- [Getting Started](#getting-started)
- [Prerequisites](#prerequisites)
- [Installation](#installation)
- [Simple GEMM Example](#simple-gemm-example)
- [Code Walkthrough](#code-walkthrough)
- [Compiling and Profiling](#compiling-and-profiling)
3. [Advanced GEMM Features](#advanced-gemm-features)
- [Advanced GEMM Features](#advanced-gemm-features)
- [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling)
- [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining)
- [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality)
4. [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
5. [Verifying Correctness](#verifying-correctness)
6. [Fine-grained MMA Computations](#fine-grained-mma-computations)
- [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
- [Verifying Correctness](#verifying-correctness)
- [Fine-grained MMA Computations](#fine-grained-mma-computations)
- [Example Workflow](#example-workflow)
- [Summary](#summary)
7. [References](#references)
- [References](#references)
---
......
......@@ -80,7 +80,6 @@ def tl_fused_chunk_fwd_kernel(
T.atomic_add(
O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
o_shared)
#TODO: consider using vectorized atomic add or tma reduce for sm90
# Output final state
T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV])
......@@ -91,6 +90,7 @@ def tl_fused_chunk_fwd_kernel(
def tl_fused_chunk_fwd(q, k, v):
B, S, H, D = q.shape
kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
print(kernel.get_kernel_source())
o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
h = kernel(q, k, v, o)
return o, h
......
......@@ -51,13 +51,6 @@ def chunk_retention_fwd_kernel(
o = T.alloc_fragment([chunk_size, BV], accum_dtype)
T.clear(h)
T.annotate_layout({
q: tl.layout.make_swizzled_layout(q),
k: tl.layout.make_swizzled_layout(k),
v: tl.layout.make_swizzled_layout(v),
h_shared: tl.layout.make_swizzled_layout(h_shared),
s_shared: tl.layout.make_swizzled_layout(s_shared),
})
T.use_swizzle(10)
for i in T.Pipelined(0, NT):
......
......@@ -21,7 +21,7 @@ def rms_norm_splitk(M, N, blk_m, blk_k):
A_local[i, j] += A_shared[i, j] * A_shared[i, j]
T.reduce_sum(A_local, A_powsum, dim=1)
for i in T.Parallel(blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
for k in range(num_k_step):
# reverse, better cache hit rate
......@@ -51,7 +51,7 @@ def rms_norm(M, N, blk_m):
A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
T.reduce_sum(A_pow_local, A_powsum, dim=1)
for i in T.Parallel(blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
for i, j in T.Parallel(blk_m, N):
A_local[i, j] *= A_powsum[i]
T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :])
......
......@@ -22,7 +22,7 @@ def rms_norm_splitk(M, N, blk_m, blk_k):
A_local[i, j] += A_shared[i, j] * A_shared[i, j]
T.reduce_sum(A_local, A_powsum, dim=1)
for i in T.Parallel(blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
for k in range(num_k_step):
# reverse, better cache hit rate
......@@ -51,7 +51,7 @@ def rms_norm(M, N, blk_m):
A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
T.reduce_sum(A_pow_local, A_powsum, dim=1)
for i in T.Parallel(blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12)
for i, j in T.Parallel(blk_m, N):
A_local[i, j] *= A_powsum[i]
T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :])
......
import tilelang.language as T
from typing import Literal, Callable
from tvm.tir import IndexMap
from tilelang.intrinsics.utils import get_mma_micro_size
from tilelang.intrinsics.mfma_layout import (
shared_16x4_to_local_64x1_layout_A,
shared_16x16_to_local_64x4_layout_A,
shared_16x32_to_local_64x8_layout_A,
shared_16x64_to_local_64x16_layout_A,
)
def make_mfma_load_base_layout(dtype: str = "float16",
matrix: Literal["A", "B"] = "A",
k_dim: int = 16,
transposed: bool = False) -> T.Fragment:
"""
Create a layout function for storing MFMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mfma_store_layout` to
map fragment indices to threads and local indices.
Parameters
----------
dtype : str
The data type of the matrix.
matrix : Literal["A", "B"]
The mfma operand to be loaded.
k_dim : int
The k dimension of the mfma.
transposed : bool
Whether the matrix is transposed, by default False.
Returns
-------
T.Fragment
Describes how threads and indices in fragment are laid out.
"""
assert matrix in ["A", "B"], "matrix should be either A or B"
# s represents spatial axis
# r represents reduction axis
# sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial
transform_func_sr_a: Callable = None
transform_func_sr_b: Callable = None
if k_dim == 4:
transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
elif k_dim == 16:
transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
elif k_dim == 32:
transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
elif k_dim == 64:
transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
else:
raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
is_sr_conditions = [False]
is_sr_conditions.append(matrix == "A" and not transposed)
is_sr_conditions.append(matrix == "B" and transposed)
is_sr_axis_order = any(is_sr_conditions)
micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype)
# the layout of mma.sync is row.col.
# so the b matrix expected a transposed basic layout
transform_func: Callable = None
if matrix == "A":
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
micro_size_s, micro_size_r = micro_size_x, micro_size_k
elif matrix == "B":
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
j, i)
micro_size_s, micro_size_r = micro_size_k, micro_size_y
else:
raise ValueError(f"Unsupported matrix {matrix}")
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
def forward_thread(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
"""
lane_id, _ = inverse_mma_load_layout.map_indices([i, j])
return lane_id
def forward_index(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
"""
_, local_id = inverse_mma_load_layout.map_indices([i, j])
return local_id
base_fragment = T.Fragment(
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
return base_fragment
block_rows = 2
block_cols = 2
warp_rows = 2
warp_cols = 2
chunk = 2
from tilelang.tools import plot_layout
# ldmatrix layout 16x16
base_layout = make_mfma_load_base_layout(dtype="float16", matrix="A", transposed=False)
print(base_layout)
plot_layout(base_layout, name="base_layout")
# warp layout 32x32
warp_layout = base_layout.repeat([warp_rows, warp_cols],
repeat_on_thread=False,
lower_dim_first=False)
print(warp_layout)
plot_layout(warp_layout, name="warp_layout")
# block layout 64x32
block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True,
lower_dim_first=True).replicate(block_cols)
print(block_layout)
plot_layout(block_layout, name="block_layout")
import tilelang
import tilelang.language as T
tilelang.disable_cache()
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
......@@ -52,11 +54,14 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
def main(M=16384, N=16384, K=16384):
tilelang.disable_cache()
block_M = 128
block_N = 128
block_K = 64
jit_kernel = matmul(M, N, K, block_M, block_N, block_K)
print(jit_kernel.get_kernel_source())
import torch
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
......
......@@ -29,10 +29,7 @@ ALL_FILES=''
ONLY_CHANGED=''
FILES=()
if (($# == 0)); then
if [[ -n "$(git status --porcelain --ignore-submodules --untracked-files=no)" ]]; then
echo "Detected uncommitted changes. Please commit or stash them before running $0." >&2
exit 1
fi
# Default: allow dirty workspace; run on changed files (committed + worktree)
ONLY_CHANGED='true'
else
while (($# > 0)); do
......@@ -78,14 +75,17 @@ if [[ -n "${ALL_FILES}" ]]; then
echo "Checking all files..." >&2
elif [[ -n "${ONLY_CHANGED}" ]]; then
MERGE_BASE="$(get_merge_base)"
echo "Checking changed files compared to merge base (${MERGE_BASE})..." >&2
echo "Checking changed files vs merge base (${MERGE_BASE}) and working tree..." >&2
elif [[ "${#FILES[@]}" -gt 0 ]]; then
echo "Checking specified files: ${FILES[*]}..." >&2
fi
# Some systems set pip's default to --user, which breaks isolated virtualenvs.
export PIP_USER=0
# If pre-commit is not installed, install it.
if ! python3 -m pre_commit --version &>/dev/null; then
python3 -m pip install pre-commit
python3 -m pip install pre-commit --user
fi
echo 'tile-lang pre-commit: Check Start'
......@@ -93,7 +93,17 @@ echo 'tile-lang pre-commit: Check Start'
if [[ -n "${ALL_FILES}" ]]; then
python3 -m pre_commit run --all-files
elif [[ -n "${ONLY_CHANGED}" ]]; then
python3 -m pre_commit run --from-ref "${MERGE_BASE}" --to-ref HEAD
# Collect changed files (committed since merge-base + current worktree)
CHANGED_FILES="$(git diff --name-only --diff-filter=ACM "${MERGE_BASE}" 2>/dev/null || true)"
if [[ -n "${CHANGED_FILES}" ]]; then
echo "Running pre-commit on changed files:"
echo "${CHANGED_FILES}"
# Convert newline-separated files to space-separated and run pre-commit once
CHANGED_FILES_SPACE="$(echo "${CHANGED_FILES}" | tr '\n' ' ')"
python3 -m pre_commit run --files ${CHANGED_FILES_SPACE}
else
echo "No files changed relative to merge base and worktree. Skipping pre-commit."
fi
elif [[ "${#FILES[@]}" -gt 0 ]]; then
python3 -m pre_commit run --files "${FILES[@]}"
fi
......@@ -105,7 +115,7 @@ echo 'tile-lang clang-tidy: Check Start'
if [[ -x "$(command -v run-clang-tidy)" ]]; then
# Check if clang-tidy is available
if [[ ! -x "$(command -v clang-tidy)" ]]; then
python3 -m pip install --upgrade --requirements "${ROOT}/requirements-lint.txt"
python3 -m pip install --upgrade --requirements "${ROOT}/requirements-lint.txt" --user
fi
# Get clang-tidy version
CLANG_TIDY_VERSION="$(clang-tidy --version | head -n1 | awk '{print $4}')"
......
# pytest correctness_evaluation.py -n 32
import pytest
from tilelang import tvm as tvm
import tilelang.testing
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def _compile_and_check(
program,
trans_A,
trans_B,
in_dtype,
out_dtype,
):
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
})
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
print("assert_allclose")
def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=2,
num_threads=128,
):
if block_N >= 256 or block_M >= 256 or block_K >= 256:
num_stages = 0
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
_compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)
def matmul_rs(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(A_shared, A_frag)
T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B)
# T.gemm(A_frag, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_rs(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=2,
num_threads=128,
):
if block_N >= 256 or block_M >= 256 or block_K >= 256:
num_stages = 0
program = matmul_rs(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
_compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)
def matmul_sr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
B_frag_shape = B_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(B_shared, B_frag)
T.gemm_v2(A_shared, B_frag, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_sr(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=2,
num_threads=128,
):
if block_N >= 256 or block_M >= 256 or block_K >= 256:
num_stages = 0
program = matmul_sr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
_compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)
def matmul_rr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape
B_frag_shape = B_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(A_shared, A_frag)
T.copy(B_shared, B_frag)
T.gemm_v2(A_frag, B_frag, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_rr(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=2,
num_threads=128,
):
if block_N >= 256 or block_M >= 256 or block_K >= 256:
num_stages = 0
program = matmul_rr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
_compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)
M_VALUES = [64, 128, 256]
N_VALUES = [16, 32, 64, 128, 256, 512]
K_VALUES = [16, 32, 64, 128]
K_VALUES_8Bit = [32, 64, 128]
FALSE_TRUE_CASES = ([
pytest.param(
k,
"float16",
"float16",
"float16",
id=f"K{k}-float16-float16-float16",
) for k in K_VALUES
] + [pytest.param(
k,
"int8",
"int32",
"int32",
id="K32-int8-int32-int32",
) for k in K_VALUES_8Bit] + [
pytest.param(
k,
"float8_e5m2",
"float32",
"float32",
id="K32-float8_e5m2-float32-float32",
) for k in K_VALUES_8Bit
] + [
pytest.param(
k,
"float8_e4m3",
"float32",
"float32",
id="K32-float8_e4m3-float32-float32",
) for k in K_VALUES_8Bit
])
def _ensure_torch_dtypes(*dtype_names):
import torch
for name in set(dtype_names):
if not hasattr(torch, name):
pytest.skip(f"Torch does not expose dtype {name}")
def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k)
def run_gemm_rs_false_false(m, n, k):
run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
def run_gemm_rs_true_false(m, n, k):
run_gemm_rs(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
def run_gemm_rs_true_true(m, n, k):
run_gemm_rs(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
def run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
run_gemm_sr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k)
def run_gemm_sr_false_false(m, n, k):
run_gemm_sr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
def run_gemm_sr_true_false(m, n, k):
run_gemm_sr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
def run_gemm_sr_true_true(m, n, k):
run_gemm_sr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
def run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
run_gemm_rr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k)
def run_gemm_rr_false_false(m, n, k):
run_gemm_rr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
def run_gemm_rr_true_false(m, n, k):
run_gemm_rr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
def run_gemm_rr_true_true(m, n, k):
run_gemm_rr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
TRANS_CASES = [
pytest.param(False, False, id="nn"),
pytest.param(False, True, id="nt"),
pytest.param(True, False, id="tn"),
pytest.param(True, True, id="tt"),
]
@pytest.fixture(scope="module", autouse=True)
def _setup_tilelang_environment():
tilelang.disable_cache()
tilelang.testing.set_random_seed(42)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
import torch
required_torch_attrs = {
in_dtype,
out_dtype,
accum_dtype,
}
for attr in required_torch_attrs:
if not hasattr(torch, attr):
pytest.skip(f"Torch does not expose dtype {attr}")
run_gemm(
m,
n,
k * 3,
False,
True,
in_dtype,
out_dtype,
accum_dtype,
m,
n,
k,
)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_false_false(m, n, k):
run_gemm(
m,
n,
k * 3,
False,
False,
"float16",
"float16",
"float16",
m,
n,
k,
)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_true_false(m, n, k):
run_gemm(
m,
n,
k * 3,
True,
False,
"float16",
"float16",
"float16",
m,
n,
k,
)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_true_true(m, n, k):
run_gemm(
m,
n,
k * 3,
True,
True,
"float16",
"float16",
"float16",
m,
n,
k,
)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
_ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype)
run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_false_false(m, n, k):
_ensure_torch_dtypes("float16")
run_gemm_rs_false_false(m, n, k)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_true_false(m, n, k):
_ensure_torch_dtypes("float16")
run_gemm_rs_true_false(m, n, k)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_true_true(m, n, k):
_ensure_torch_dtypes("float16")
run_gemm_rs_true_true(m, n, k)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
_ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype)
run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_sr_false_false(m, n, k):
_ensure_torch_dtypes("float16")
run_gemm_sr_false_false(m, n, k)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_sr_true_false(m, n, k):
_ensure_torch_dtypes("float16")
run_gemm_sr_true_false(m, n, k)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_sr_true_true(m, n, k):
_ensure_torch_dtypes("float16")
run_gemm_sr_true_true(m, n, k)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
_ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype)
run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rr_false_false(m, n, k):
_ensure_torch_dtypes("float16")
run_gemm_rr_false_false(m, n, k)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rr_true_false(m, n, k):
_ensure_torch_dtypes("float16")
run_gemm_rr_true_false(m, n, k)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rr_true_true(m, n, k):
_ensure_torch_dtypes("float16")
run_gemm_rr_true_true(m, n, k)
if __name__ == "__main__":
tilelang.testing.main()
# # Test Pass
# for m in [64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
# for m in [64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} False False =============================")
# run_gemm(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
# for m in [64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} True False =============================")
# run_gemm(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m}, {n} {k} Pass")
# print(f"Test {n} Pass")
# # Test Pass
# for m in [64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} True True =============================")
# run_gemm(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m}, {n} {k} Pass")
# print(f"Test {n} Pass")
# Test Pass
# for m in [64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm_rs(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# run_gemm_rs(64, n, k, False, False, "float16", "float16", "float16", 64, n, k, 0, 256)
# print(f"Test {64} {n} {k} Pass")
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# run_gemm(64, n, k, False, False, "float16", "float16", "float16", 64, n, k, 0, 256)
# print(f"Test {64} {n} {k} Pass")
# pytest maint/gemm_v2/correctness_evaluation_sm70.py -n 32
import pytest
from tilelang import tvm as tvm
import tilelang.testing
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
# T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def _compile_and_check(
program,
trans_A,
trans_B,
in_dtype,
out_dtype,
):
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
})
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
print("assert_allclose")
def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
_compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)
def matmul_rs(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared.dyn")
A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(A_shared, A_frag)
T.gemm_v2(A_frag, B_shared, C_local, trans_A, trans_B)
# T.gemm(A_frag, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_rs(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_rs(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
_compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)
M_VALUES = [64, 128]
N_VALUES = [32, 64, 128]
K_VALUES = [16, 32, 64]
FALSE_TRUE_CASES = ([
pytest.param(
k,
"float16",
"float16",
"float16",
id=f"K{k}-float16-float16-float16",
) for k in K_VALUES
] + [
pytest.param(
k,
"float16",
"float16",
"float32",
id=f"K{k}-float16-float16-float32",
) for k in K_VALUES
])
def _ensure_torch_dtypes(*dtype_names):
import torch
for name in set(dtype_names):
if not hasattr(torch, name):
pytest.skip(f"Torch does not expose dtype {name}")
def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, 2, 128)
def run_gemm_rs_false_false(m, n, k):
run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
TRANS_CASES = [
pytest.param(False, False, id="nn"),
pytest.param(False, True, id="nt"),
pytest.param(True, False, id="tn"),
pytest.param(True, True, id="tt"),
]
@pytest.fixture(scope="module", autouse=True)
def _setup_tilelang_environment():
tilelang.disable_cache()
tilelang.testing.set_random_seed(42)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
import torch
required_torch_attrs = {
in_dtype,
out_dtype,
accum_dtype,
}
for attr in required_torch_attrs:
if not hasattr(torch, attr):
pytest.skip(f"Torch does not expose dtype {attr}")
run_gemm(
m,
n,
k * 3,
False,
True,
in_dtype,
out_dtype,
accum_dtype,
m,
n,
k,
2,
128,
)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_false_false(m, n, k):
run_gemm(
m,
n,
k * 3,
False,
False,
"float16",
"float16",
"float16",
m,
n,
k,
2,
128,
)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
_ensure_torch_dtypes(in_dtype, out_dtype, accum_dtype)
run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype)
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_false_false(m, n, k):
_ensure_torch_dtypes("float16")
run_gemm_rs_false_false(m, n, k)
if __name__ == "__main__":
tilelang.testing.main()
# # Test Pass
# for m in [64, 128]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64]:
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
# for m in [64, 128]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64]:
# print(f"======================= Test {m} {n} {k} False False =============================")
# run_gemm(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# pytest correctness_evaluation.py -n 32
import pytest
from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
mbar = T.alloc_barrier(1)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_tmem,
trans_A,
trans_B,
mbar=mbar,
wg_wait=-1,
clear_accum=k == 0)
T.mbarrier_wait_parity(mbar, k % 2)
T.copy(C_tmem, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
def _compile_and_check(
program,
trans_A,
trans_B,
in_dtype,
out_dtype,
):
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
print("assert_allclose")
def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=2,
num_threads=128,
):
if block_N >= 256 or block_M >= 256 or block_K >= 256:
num_stages = 0
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
_compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)
M_VALUES = [32, 64, 128, 256]
N_VALUES = [64, 128, 256, 512]
K_VALUES = [16, 32, 64, 128]
K_VALUES_8Bit = [32, 64, 128]
FALSE_TRUE_CASES = ([
pytest.param(
k,
"float16",
"float32",
"float32",
id=f"K{k}-float16-float-float",
) for k in K_VALUES
] + [
pytest.param(
k,
"float8_e5m2",
"float32",
"float32",
id="K32-float8_e5m2-float32-float32",
) for k in K_VALUES_8Bit
])
TRANS_CASES = [
pytest.param(False, True, id="nt"),
]
@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
import torch
required_torch_attrs = {
in_dtype,
out_dtype,
accum_dtype,
}
for attr in required_torch_attrs:
if not hasattr(torch, attr):
pytest.skip(f"Torch does not expose dtype {attr}")
run_gemm(
m,
n,
k * 3,
False,
True,
in_dtype,
out_dtype,
accum_dtype,
m,
n,
k,
)
if __name__ == "__main__":
# tilelang.testing.main()
# # Test Pass
# for m in [32, 64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# if m in [32, 64] and (n not in [64, 128, 256]):
# continue
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float", "float", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
# for m in [32, 64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [32, 64, 128]:
# if m in [32, 64] and (n not in [64, 128, 256]):
# continue
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float8_e5m2", "float", "float", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
tilelang.disable_cache()
run_gemm(32, 512, 16, False, True, "float16", "float32", "float32", 32, 512, 16, 0, 128)
run_gemm(32, 512, 32, False, True, "float16", "float32", "float32", 32, 512, 32, 0, 128)
run_gemm(32, 512, 64, False, True, "float16", "float32", "float32", 32, 512, 64, 0, 128)
run_gemm(64, 512, 16, False, True, "float16", "float32", "float32", 64, 512, 16, 0, 128)
run_gemm(64, 512, 16, False, True, "float16", "float32", "float32", 32, 512, 16, 0, 128)
run_gemm(128, 512, 16, False, True, "float16", "float32", "float32", 128, 512, 16, 0, 128)
# run_gemm(64, 512, 32, False, True, "float16", "float32", "float32", 64, 512, 32, 0, 128)
# run_gemm(64, 512, 64, False, True, "float16", "float32", "float32", 64, 512, 64, 0, 128)
# run_gemm(128, 512, 16, False, True, "float16", "float32", "float32", 128, 512, 16, 0, 128)
import tilelang
import tilelang.language as T
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()
use_v2 = args.use_v2
# @tilelang.jit(target="cuda")
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable rasterization for better L2 cache locality (Optional)
# T.use_swizzle(panel_size=10, enable=True)
# Clear local accumulation
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy tile of A
# This is a sugar syntax for parallelized copy
T.copy(A[by * block_M, ko * block_K], A_shared)
# Copy tile of B
T.copy(B[ko * block_K, bx * block_N], B_shared)
# Perform a tile-level GEMM on the shared buffers
# Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
if use_v2:
T.gemm_v2(A_shared, B_shared, C_local)
else:
T.gemm_v1(A_shared, B_shared, C_local)
# relu
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = T.max(C_local[i, j], 0)
# Copy result back to global memory
T.copy(C_local, C[by * block_M, bx * block_N])
return matmul_relu_kernel
M = 16384 # M = T.dynamic("m") if you want to use dynamic shape
N = 16384
K = 16384
block_M = 128
block_N = 128
block_K = 32
# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)
# Run the kernel through the Profiler
matmul_relu_kernel(a, b, c)
print(c)
# Reference multiplication using PyTorch
ref_c = torch.relu(a @ b)
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5.Profile latency with kernel
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
import tilelang
import tilelang.language as T
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()
use_v2 = args.use_v2
# @tilelang.jit(target="cuda")
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable rasterization for better L2 cache locality (Optional)
# T.use_swizzle(panel_size=10, enable=True)
# Clear local accumulation
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy tile of A
# This is a sugar syntax for parallelized copy
T.copy(A[by * block_M, ko * block_K], A_shared)
# Copy tile of B
T.copy(B[ko * block_K, bx * block_N], B_shared)
# Perform a tile-level GEMM on the shared buffers
# Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
if use_v2:
T.gemm_v2(A_shared, B_shared, C_local)
else:
T.gemm_v1(A_shared, B_shared, C_local)
# relu
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = T.max(C_local[i, j], 0)
# Copy result back to global memory
T.copy(C_local, C[by * block_M, bx * block_N])
return matmul_relu_kernel
M = 16384 # M = T.dynamic("m") if you want to use dynamic shape
N = 16384
K = 16384
block_M = 128
block_N = 128
block_K = 64
# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)
# Run the kernel through the Profiler
matmul_relu_kernel(a, b, c)
print(c)
# Reference multiplication using PyTorch
ref_c = torch.relu(a @ b)
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5.Profile latency with kernel
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=16, help='heads')
parser.add_argument('--seq_q', type=int, default=1024, help='query sequence length')
parser.add_argument('--seq_kv', type=int, default=1024, help='key/value sequence length')
parser.add_argument('--dim', type=int, default=256, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()
use_v2 = args.use_v2
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
if use_v2:
T.gemm_v2(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
else:
T.gemm_v1(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
# T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if use_v2:
T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
else:
T.gemm_v1(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv(
(bx + 1) * block_M +
past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return main
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_q = Q.size(2)
seq_kv = K.size(2)
mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
return output
def main(
batch: int = 1,
heads: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 64,
is_causal: bool = False,
tune: bool = False,
):
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if (not tune):
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=128)
print(kernel.get_kernel_source())
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print(f"Ref: {latency:.2f} ms")
print(f"Ref: {total_flops / latency * 1e-9:.2f} TFlops")
latency = profiler.do_bench(warmup=500)
print(f"Tile-lang: {latency:.2f} ms")
print(f"Tile-lang: {total_flops / latency * 1e-9:.2f} TFlops")
else:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
tilelang.disable_cache()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
./maint/scripts/docker_local_distribute.sh 2>&1 | tee docker_local_distribute.log
./maint/scripts/docker_pypi_distribute.sh 2>&1 | tee docker_pypi_distribute.log
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment