Commit 2add9fa3 authored by wangkx1

add tilelang

parent f5bc26c2
# Reference: fla/ops/gated_delta_rule/wy_fast.py
import tilelang
import tilelang.language as T
import sys # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
torch.random.manual_seed(1)
def prepare_input(B, S, H, DK, DV, chunk_size, input_dtype, output_dtype, gate_dtype=torch.float32):
BS = chunk_size
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
Beta = torch.randn(B, S, H, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
A = torch.randn(B, S, H, BS, dtype=output_dtype).cuda()
return K, V, Beta, G, A
def prepare_output(
B,
S,
H,
DK,
DV,
output_dtype,
):
W = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
U = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
return W, U
@tilelang.jit(out_idx=[-2, -1])
def tilelang_recompute_w_u_fwd(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
chunk_size,
# kernel config
block_S=64,
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
Beta_shape = (B, S, H)
assert chunk_size == block_S, "chunk_size must be equal to block_S"
BS = chunk_size
G_shape = (B, S, H)
A_shape = (B, S, H, BS)
@T.prim_func
def kernel(
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
A: T.Tensor(A_shape, dtype=output_dtype),
W: T.Tensor(K_shape, dtype=output_dtype),
U: T.Tensor(V_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype, scope="shared")
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared")
A_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype)
W_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
U_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
W_shared = T.alloc_shared((block_S, block_DK), dtype=output_dtype)
U_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype)
W_Beta_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
U_Beta_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
T.annotate_layout(
{
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
V_shared: tilelang.layout.make_swizzled_layout(V_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
W_shared: tilelang.layout.make_swizzled_layout(W_shared),
U_shared: tilelang.layout.make_swizzled_layout(U_shared),
W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared),
U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared),
}
)
T.disable_warp_group_reg_alloc()
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh])
T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared)
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared)
for i_s, i_v2 in T.Parallel(block_S, block_DV):
U_Beta_shared[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s]
T.gemm(A_shared, U_Beta_shared, U_fragment, clear_accum=True)
# First copy to smem, then copy to gmem to reduce U2RU instructions
T.copy(U_fragment, U_shared)
T.copy(U_shared, U[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV])
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
W_Beta_shared[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s]
T.gemm(A_shared, W_Beta_shared, W_fragment, clear_accum=True)
# First copy to smem, then copy to gmem to reduce U2RU instructions
T.copy(W_fragment, W_shared)
T.copy(W_shared, W[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK])
return kernel
def run_test(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
block_DK,
block_DV,
threads,
num_stages,
):
K, V, Beta, G, A = prepare_input(
B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype)
)
W_ref, U_ref = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
# reference
W_ref, U_ref = recompute_w_u_fwd(K, V, Beta, G, A, None)
# tilelang
block_S = chunk_size
kernel = tilelang_recompute_w_u_fwd(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
chunk_size,
block_S=block_S,
block_DK=block_DK,
block_DV=block_DV,
threads=threads,
num_stages=num_stages,
)
print(kernel.get_kernel_source())
W_tilelang, U_tilelang = kernel(K, V, Beta, G, A)
try:
torch.testing.assert_close(W_tilelang, W_ref, rtol=1e-2, atol=1e-2)
print("tilelang recompute w passed √")
except Exception as e:
print("tilelang recompute w failed ✗")
print(e)
try:
torch.testing.assert_close(U_tilelang, U_ref, rtol=1e-2, atol=1e-2)
print("tilelang recompute u passed √")
except Exception as e:
print("tilelang recompute u failed ✗")
print(e)
def main():
run_test(
B=1,
S=32768,
H=32,
DK=128,
DV=128,
chunk_size=64,
input_dtype=T.bfloat16,
output_dtype=T.bfloat16,
gate_dtype=T.float32,
accum_dtype=T.float32,
block_DK=64,
block_DV=32,
threads=128,
num_stages=3,
)
if __name__ == "__main__":
main()
# Reference: fla/ops/gated_delta_rule/wy_fast.py
import sys # noqa: F401
import tilelang
import tilelang.language as T
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id 00000000
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
import torch.nn.functional as F
torch.random.manual_seed(0)
torch.set_printoptions(profile="full")
def prepare_input_fake(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
BS = chunk_size
K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
Beta = torch.ones(B, S, H, dtype=input_dtype).cuda()
G = torch.ones(B, S, H, dtype=gate_dtype).cuda()
A = torch.ones(B, S, H, BS, dtype=input_dtype).cuda()
dw = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
du = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
return K, V, Beta, G, A, dw, du
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
BS = chunk_size
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = F.normalize(K, dim=-1, p=2)
V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
V = F.normalize(V, dim=-1, p=2)
Beta = torch.randn(B, S, H, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
A = torch.randn(B, S, H, BS, dtype=input_dtype).cuda()
dw = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
du = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
return K, V, Beta, G, A, dw, du
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
gate_dtype,
state_dtype,
):
dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
dv = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
dbeta = torch.empty(B, S, H, dtype=output_dtype).cuda()
dg = torch.empty(B, S, H, dtype=gate_dtype).cuda()
return dk, dv, dbeta, dg
@tilelang.jit(
out_idx=[-5, -4, -3, -2, -1],
pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True},
)
def tilelang_wy_fast_bwd(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
# kernel config
block_DK=64,
block_DV=64,
threads=128,
num_stages=0,
):
block_S = chunk_size
BS = block_S
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
Beta_shape = (B, S, H)
G_shape = (B, S, H)
A_shape = (B, S, H, BS)
dw_shape = (B, S, H, DK)
du_shape = (B, S, H, DV)
dk_shape = (B, S, H, DK)
dv_shape = (B, S, H, DV)
dbeta_shape = (B, S, H)
dg_shape = (B, S, H)
dA_shape = (B, S, H, BS)
@T.prim_func
def kernel(
# input
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
A: T.Tensor(A_shape, dtype=input_dtype),
dw: T.Tensor(dw_shape, dtype=input_dtype),
du: T.Tensor(du_shape, dtype=input_dtype),
# output
dA: T.Tensor(dA_shape, dtype=input_dtype),
dk: T.Tensor(dk_shape, dtype=output_dtype),
dv: T.Tensor(dv_shape, dtype=output_dtype),
dbeta: T.Tensor(dbeta_shape, dtype=output_dtype),
dg: T.Tensor(dg_shape, dtype=gate_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
A_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype)
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
K_shared_beta_g = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
V_shared_beta = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype)
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype)
G_shared_exp = T.alloc_shared((block_S,), dtype=gate_dtype)
dw_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
du_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
dA_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dk_fragment_beta_g = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dv_fragment_beta = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dbeta_fragment_k = T.alloc_fragment((block_S,), dtype=accum_dtype)
dbeta_fragment_v = T.alloc_fragment((block_S,), dtype=accum_dtype)
dbeta_fragment_reduce_tmpk = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dbeta_fragment_reduce_tmpv = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dg_fragment = T.alloc_fragment((block_S,), dtype=gate_dtype)
dg_fragment_reduce_tmp = T.alloc_fragment((block_S, block_DK), dtype=gate_dtype)
T.use_swizzle(10)
T.clear(dA_fragment)
T.clear(dk_fragment)
T.clear(dk_fragment_beta_g)
T.clear(dv_fragment)
T.clear(dv_fragment_beta)
T.clear(dbeta_fragment_k)
T.clear(dbeta_fragment_v)
T.clear(dg_fragment)
T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared)
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
G_shared_exp[i_s] = T.exp(G[bb, bs * block_S + i_s, bh])
# Update dk
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
K_shared_beta_g[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
T.copy(dw[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dw_shared)
T.gemm(dw_shared, K_shared_beta_g, dA_fragment, transpose_B=True)
T.gemm(A_shared, dw_shared, dk_fragment_beta_g, clear_accum=True, transpose_A=True)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_fragment[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False)
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dg_fragment_reduce_tmp[i_s, i_k2] = (
dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
)
T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=1, clear=False)
# correct dk
T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK])
# Update dv
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared)
for i_s, i_v2 in T.Parallel(block_S, block_DV):
V_shared_beta[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s]
T.copy(du[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], du_shared)
T.gemm(du_shared, V_shared_beta, dA_fragment, transpose_B=True)
T.gemm(A_shared, du_shared, dv_fragment_beta, clear_accum=True, transpose_A=True)
for i_s, i_v2 in T.Parallel(block_S, block_DV):
dv_fragment[i_s, i_v2] = dv_fragment_beta[i_s, i_v2] * Beta_shared[i_s]
# for i_s, i_v2 in T.Parallel(block_S, block_DV):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2]
for i_s, i_v2 in T.Parallel(block_S, block_DV):
dbeta_fragment_reduce_tmpv[i_s, i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2]
T.reduce_sum(dbeta_fragment_reduce_tmpv, dbeta_fragment_v, dim=1, clear=False)
T.copy(dv_fragment, dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV])
# Temporarily store dbeta, dg and dA
for i_s in T.Parallel(block_S):
dbeta[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + dbeta_fragment_v[i_s]
dg[bb, bs * block_S + i_s, bh] = dg_fragment[i_s]
# correct dA
T.copy(dA_fragment, dA[bb, bs * block_S : (bs + 1) * block_S, bh, :])
return kernel
@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True})
def tilelang_wy_fast_bwd_split(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
# kernel config
block_DK=64,
block_DV=64,
threads=128,
num_stages=0,
):
block_S = chunk_size
BS = block_S
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
Beta_shape = (B, S, H)
G_shape = (B, S, H)
A_shape = (B, S, H, BS)
dw_shape = (B, S, H, DK)
du_shape = (B, S, H, DV)
dk_shape = (B, S, H, DK)
dv_shape = (B, S, H, DV)
dbeta_shape = (B, S, H)
dA_shape = (B, S, H, BS)
@T.prim_func
def kernel(
# input
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
A: T.Tensor(A_shape, dtype=input_dtype),
dw: T.Tensor(dw_shape, dtype=input_dtype),
du: T.Tensor(du_shape, dtype=input_dtype),
dA: T.Tensor(dA_shape, dtype=input_dtype),
dk: T.Tensor(dk_shape, dtype=output_dtype),
dv: T.Tensor(dv_shape, dtype=output_dtype),
dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype),
dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype),
dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
A_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype)
A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
K_shared_beta = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
dA_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype)
dA_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
dA_A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
dA_A_fragment_1 = T.alloc_fragment((block_S,), dtype=accum_dtype)
dA_A_fragment_2 = T.alloc_fragment((block_S,), dtype=accum_dtype)
dk_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
dk_shared_beta = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dk_fragment_beta = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype)
dbeta_fragment_reduce_tmpk = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dbeta_fragment_k = T.alloc_fragment((block_S,), dtype=accum_dtype)
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype)
G_shared_exp = T.alloc_shared((block_S,), dtype=gate_dtype)
T.clear(dbeta_fragment_reduce_tmpk)
T.clear(dbeta_fragment_k)
T.clear(dA_A_fragment_1)
T.clear(dA_A_fragment_2)
T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared)
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
for i_s in T.Parallel(block_S):
G_shared_exp[i_s] = T.exp(G_shared[i_s])
# Load intermediate results
# for i_s in T.Parallel(block_S):
# dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh]
# dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh]
T.copy(dA[bb, bs * block_S : (bs + 1) * block_S, bh, :], dA_shared)
# T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :])
# Update dA
T.copy(dA_shared, dA_fragment)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 <= i_s2): # noqa: SIM117
with T.Then():
dA_fragment[i_s1, i_s2] = 0
T.copy(dA_fragment, dA_shared)
T.gemm(dA_shared, A_shared, dA_fragment, clear_accum=True, transpose_B=True)
T.copy(dA_fragment, dA_shared)
T.gemm(A_shared, dA_shared, dA_fragment, clear_accum=True, transpose_A=True)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 <= i_s2):
with T.Then():
dA_fragment[i_s1, i_s2] = 0
with T.Else():
dA_fragment[i_s1, i_s2] = -dA_fragment[i_s1, i_s2]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
with T.Then():
dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh])
with T.Else():
dA_fragment[i_s1, i_s2] = 0
T.copy(dA_fragment, dA_shared)
# acceptable dA diff
# T.copy(dA_fragment, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :])
# Update dk using previous dk
T.clear(A_fragment)
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
T.copy(dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dk_shared)
T.copy(dk_shared, dk_fragment)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
K_shared_beta[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s]
T.gemm(K_shared_beta, K_shared, A_fragment, transpose_B=True)
T.gemm(dA_shared, K_shared, dk_fragment_beta, clear_accum=True)
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2]
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2]
T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False)
T.gemm(dA_shared, K_shared_beta, dk_fragment, transpose_A=True)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_shared_beta[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * Beta_shared[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_fragment[i_s, i_k2] = dk_fragment[i_s, i_k2] + dk_shared_beta[i_s, i_k2]
T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK])
# Update dg and dbeta
T.copy(A_fragment, A_shared)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
dA_A_fragment[i_s1, i_s2] = dA_fragment[i_s1, i_s2] * A_fragment[i_s1, i_s2]
# Note: reduce operations are currently not supported in shared memory
# FIXME: reduce will cause incorrect result when dim != -1
T.reduce_sum(dA_A_fragment, dA_A_fragment_1, dim=1)
T.reduce_sum(dA_A_fragment, dA_A_fragment_2, dim=0)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
dg_A_positive[bb, bs * block_S + i_s1, bh, i_s2] = dA_A_fragment[i_s1, i_s2]
dg_A_negative[bb, bs * block_S + i_s2, bh, i_s1] = dA_A_fragment[i_s1, i_s2]
for i_s in T.Parallel(block_S):
dbeta_k[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s]
return kernel
def run_test(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK=64,
block_DV=64,
threads=128,
num_stages=0,
):
K, V, Beta, G, A, dw, du = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
BS = chunk_size
dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda()
dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda()
dg_tilelang_A_positive = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
# ref
dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr(K, V, G, Beta, A, dw, du, cu_seqlens=None)
# tilelang
kernel = tilelang_wy_fast_bwd(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK,
block_DV,
threads,
num_stages,
)
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du)
torch.cuda.synchronize()
kernel_split = tilelang_wy_fast_bwd_split(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK,
block_DV,
threads,
num_stages,
)
kernel_split(
K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative
)
torch.cuda.synchronize()
dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1)
from test_utils import assert_similar
assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False)
assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False)
assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False)
assert_similar(dg_ref, dg_tilelang, eps=1e-5, name="dg", raise_assert=False)
def main():
DK = 128
DV = 128
run_test(
B=1,
S=32768,
H=8,
DK=DK,
DV=DV,
input_dtype=T.bfloat16,
output_dtype=T.bfloat16,
accum_dtype=T.float32,
gate_dtype=T.float32,
state_dtype=T.float32,
chunk_size=64,
block_DK=32,
block_DV=32,
threads=128,
num_stages=0,
)
if __name__ == "__main__":
main()
import torch
import tilelang.testing
from tilelang import language as T
B = 1
S = 1024  # kept small; for testing only
H = 32
DK = 128
DV = 128
input_dtype = T.bfloat16
output_dtype = T.bfloat16
accum_dtype = T.float32
gate_dtype = T.float32
state_dtype = T.float32
chunk_size = 64
use_g = True
use_initial_state = True
store_final_state = True
use_final_state_gradient = True
save_new_value = True
block_DK = 64
block_DV = 32
threads = 128
num_stages = 1
def test_example_wy_fast_compilation():
from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input
K, V, Beta, G, A = prepare_input(
B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype)
)
# tilelang
block_S = chunk_size
kernel = tilelang_recompute_w_u_fwd(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
chunk_size,
block_S=block_S,
block_DK=block_DK,
block_DV=block_DV,
threads=threads,
num_stages=num_stages,
)
print(kernel.get_kernel_source())
W_tilelang, U_tilelang = kernel(K, V, Beta, G, A)
def test_example_wy_fast_bwd_split_compilation():
from example_wy_fast_bwd_split import tilelang_wy_fast_bwd, tilelang_wy_fast_bwd_split, prepare_input, prepare_output
K, V, Beta, G, A, dw, du = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
BS = chunk_size
dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda()
dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda()
dg_tilelang_A_positive = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
# tilelang
kernel = tilelang_wy_fast_bwd(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK,
block_DV,
threads,
num_stages,
)
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du)
torch.cuda.synchronize()
kernel_split = tilelang_wy_fast_bwd_split(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK,
block_DV,
threads,
num_stages,
)
kernel_split(
K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative
)
torch.cuda.synchronize()
dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1)
def test_example_chunk_o_compilation():
from example_chunk_o import tilelang_chunk_fwd_o, prepare_input
Q, K, V, HIDDEN, G = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
)
scale = 1.0 / DK**0.5
block_S = chunk_size
kernel = tilelang_chunk_fwd_o(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
chunk_size,
scale,
use_g,
block_S,
block_DK,
block_DV,
threads,
num_stages,
)
O_tilelang = kernel(Q, K, V, HIDDEN, G) # noqa: F841
def test_example_chunk_o_bwd_compilation():
from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input
Q, K, V, h, G, dO, dh, dv, W = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
kernel = tilelang_chunk_o_bwd_dqkwg(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
1.0,
use_g,
True,
block_DK,
block_DV,
threads,
num_stages,
)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) # noqa: F841
if use_g:
dg_tilelang = dg_tilelang.sum(dim=0)
def test_example_chunk_scaled_dot_kkt_compilation():
from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype))
block_S = chunk_size
kernel = tilelang_chunk_scaled_dot_kkt_fwd(
B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages
)
A_tilelang = kernel(K, Beta, G) # noqa: F841
def test_example_cumsum_compilation():
from example_cumsum import tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output
G = prepare_cumsum_input(B, S, H, getattr(torch, gate_dtype))
G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, gate_dtype))
block_S = chunk_size
kernel = tilelang_chunk_local_cumsum_scalar(
B=B,
S=S,
H=H,
chunk_size=chunk_size,
reverse=False,
head_first=False,
input_dtype=gate_dtype,
output_dtype=gate_dtype,
block_S=block_S,
threads=threads,
use_fragment=False,
)
G_new_tilelang = kernel(G) # noqa: F841
def test_example_chunk_delta_h_compilation():
from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input
K, W, U, G, initial_state = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
)
kernel = tilelang_chunk_gated_delta_rule_fwd_h(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
use_g,
use_initial_state,
store_final_state,
save_new_value,
block_DK,
block_DV,
threads,
num_stages,
)
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) # noqa: F841
def test_example_chunk_delta_bwd_compilation():
from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input
Q, K, W, G, h0, dht, dO, dv = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
1.0,
use_g,
use_initial_state,
use_final_state_gradient,
block_DV,
threads,
num_stages,
)
dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) # noqa: F841
if __name__ == "__main__":
tilelang.testing.main()
import torch
def print_red_warning(message):
print(f"\033[31mWARNING: {message}\033[0m")
def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print_red_warning(f"{name} all zero")
return 1
sim = 2 * (x * y).sum() / denominator
return sim
def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
x_mask = torch.isfinite(x)
y_mask = torch.isfinite(y)
if not torch.all(x_mask == y_mask):
print_red_warning(f"{name} Error: isfinite mask mismatch")
if raise_assert:
raise AssertionError
if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
print_red_warning(f"{name} Error: nonfinite value mismatch")
if raise_assert:
raise AssertionError
x = x.masked_fill(~x_mask, 0)
y = y.masked_fill(~y_mask, 0)
sim = calc_sim(x, y, name)
diff = 1.0 - sim
if not (0 <= diff <= eps):
print_red_warning(f"{name} Error: {diff}")
if raise_assert:
raise AssertionError
else:
print(f"{name} {data} passed")
# TileLang GEMM (Matrix Multiplication) Examples
TileLang is a domain-specific language designed to simplify the process of writing GPU kernels. It provides high-level abstractions for memory allocation, scheduling, and tiling, which are critical for achieving maximum performance on modern hardware architectures like NVIDIA GPUs. This README demonstrates how to write and optimize a matrix multiplication (GEMM) kernel using TileLang.
## Table of Contents
- [Table of Contents](#table-of-contents)
- [Getting Started](#getting-started)
- [Prerequisites](#prerequisites)
- [Installation](#installation)
- [Simple GEMM Example](#simple-gemm-example)
- [Code Walkthrough](#code-walkthrough)
- [Compiling and Profiling](#compiling-and-profiling)
- [Advanced GEMM Features](#advanced-gemm-features)
- [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling)
- [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining)
- [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality)
- [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
- [Verifying Correctness](#verifying-correctness)
- [Fine-grained MMA Computations](#fine-grained-mma-computations)
- [Example Workflow](#example-workflow)
- [Summary](#summary)
- [References](#references)
---
## Getting Started
### Prerequisites
- **Python 3.8+**
- **NVIDIA GPU** with a recent CUDA toolkit installed
- **PyTorch** (optional, for easy correctness verification)
- **tilelang**
- **bitblas** (optional; used for swizzle layout utilities in the advanced examples)
### Installation
```bash
pip install tilelang bitblas
```
*(Adjust accordingly if you are installing from source or using a different environment.)*
---
## Simple GEMM Example
Below is a basic matrix multiplication (GEMM) example demonstrating how TileLang handles buffer allocation, tiling, and kernel dispatch. For simplicity, we'll multiply two 1024×1024 matrices using 128 threads/block.
```python
import tilelang
from tilelang import Profiler
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Define a grid with enough blocks to cover M×N
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
# Allocate shared memory for the current tile of A and B
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
# Allocate a local (register) fragment for partial accumulations
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Initialize the local accumulation buffer to zero
T.clear(C_local)
# Loop over the K dimension in block_K chunks, using a 3-stage pipeline
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy from global memory to shared memory
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
# Perform a matrix multiply-accumulate on the tile
T.gemm(A_shared, B_shared, C_local)
# Copy the accumulated result from local memory (C_local) to global memory (C)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
```
### Code Walkthrough
1. **Define the Kernel Launch Configuration:**
```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
```
This creates a grid of blocks (ceildiv(N, block_N) in x-dimension, ceildiv(M, block_M) in y-dimension), each with 128 threads.
2. **Shared Memory Allocation:**
```python
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
```
Tiles of \(A\) and \(B\) are loaded into these shared memory buffers for faster access.
3. **Local Fragment Accumulation:**
```python
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
```
Partial results are stored in registers (or local memory) to reduce writes to global memory.
4. **Pipelined Loading and GEMM:**
```python
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(...)
T.gemm(...)
```
Loads blocks of \(A\) and \(B\) in a pipelined fashion (up to 3 stages in flight), overlapping data transfer with computation.
5. **Copy Out the Results:**
```python
T.copy(C_local, C[by * block_M, bx * block_N])
```
Writes the final computed tile from registers/shared memory to global memory.
### Compiling and Profiling
```python
func = matmul(1024, 1024, 1024, 128, 128, 32)
print(func) # Prints an IR-like representation of the TileLang kernel
artifact = tilelang.lower(func)
profiler = Profiler(artifact.rt_mod, artifact.params, result_idx=[2])
import torch
a = torch.randn(1024, 1024).cuda().half()
b = torch.randn(1024, 1024).cuda().half()
c = profiler(a, b)
ref_c = a @ b
# Validate results
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
# Get CUDA Kernel Source
print(artifact.kernel_source)
```
---
## Advanced GEMM Features
### Custom Memory Layout / Swizzling
**Swizzling** rearranges data in shared memory or global memory to mitigate bank conflicts, improve cache utilization, and better match the GPU’s warp execution pattern. TileLang provides helper functions like `make_swizzle_layout` to annotate how buffers should be laid out in memory.
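A minimal, self-contained sketch (the `swizzled_copy` name and tile sizes are illustrative, not part of the shipped examples): the annotation is attached to a shared buffer right after it is allocated.

```python
import tilelang
import tilelang.language as T
# Same helper the enhanced example below imports for MMA-friendly swizzling.
from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout

@tilelang.jit(out_idx=[-1])
def swizzled_copy(M=128, K=64, dtype=T.float16):
    @T.prim_func
    def kernel(A: T.Tensor((M, K), dtype), B: T.Tensor((M, K), dtype)):
        with T.Kernel(1, 1, threads=128) as (bx, by):
            A_shared = T.alloc_shared((M, K), dtype)
            # Ask the compiler to store A_shared in a swizzled, bank-conflict-free layout.
            T.annotate_layout({A_shared: make_swizzle_layout(A_shared)})
            T.copy(A[0, 0], A_shared)  # global -> shared (swizzled)
            T.copy(A_shared, B[0, 0])  # shared -> global
    return kernel
```

Calling `swizzled_copy()` compiles the kernel; invoking it on a `(128, 64)` half-precision CUDA tensor round-trips the data through the swizzled shared buffer.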
### Parallel Copy and Auto-Pipelining
- **Parallel Copy** allows you to distribute the copy of a block tile across all threads in a block, speeding up the transfer from global memory to shared memory.
- **Auto-Pipelining** uses multiple stages to overlap copying with computation, reducing idle cycles (see the sketch below).
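A minimal, self-contained sketch of the two features working together (the `tiled_copy` name and block sizes are illustrative only): each iteration of the pipelined loop stages one tile through shared memory using a thread-parallel copy.

```python
import tilelang
import tilelang.language as T

@tilelang.jit(out_idx=[-1])
def tiled_copy(M=1024, N=1024, block_M=128, block_N=64, dtype=T.float16):
    @T.prim_func
    def kernel(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, block_M), 1, threads=128) as (bm, _):
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            # Auto-pipelining: with num_stages=2, the load of the next tile
            # overlaps with the store of the current one.
            for kn in T.Pipelined(T.ceildiv(N, block_N), num_stages=2):
                # Parallel copy: the (block_M, block_N) tile is distributed
                # across all 128 threads of the block.
                for i, j in T.Parallel(block_M, block_N):
                    A_shared[i, j] = A[bm * block_M + i, kn * block_N + j]
                T.copy(A_shared, B[bm * block_M, kn * block_N])
    return kernel
```

The enhanced GEMM example below applies the same pattern to the `B` operand while using `T.copy` for `A`.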
### Rasterization for L2 Cache Locality
Enabling **swizzle (rasterization)** at the kernel level can improve data reuse and reduce cache thrashing in L2. This is especially important when matrices are large.
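In kernel code this is a single call placed near the top of the `T.Kernel` body, exactly as the enhanced example below does. A hedged fragment for orientation (not a complete kernel):

```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
    # Re-order the block launch into panels of 10 block-rows so that neighboring
    # blocks touch the same A/B tiles while they are still resident in L2.
    T.use_swizzle(panel_size=10, enable=True)
    ...
```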
---
## Enhanced GEMM Example with Annotations
Below is a more advanced snippet that showcases how to apply memory layouts, enable swizzling, and parallelize the copy operations to maximize performance:
```python
import tilelang.language as T
# `make_mma_swizzle_layout` is a python-defined layout function
# that helps align data for MMA (Matrix Multiply-Accumulate) operations.
from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
# Allocate shared and local fragments
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Annotate memory layout
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Enable swizzle-based rasterization for better L2 locality
T.use_swizzle(panel_size=10, enable=True)
# Clear the local accumulation buffer
T.clear(C_local)
# Pipelined iteration over K dimension
for idx in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy tile of A
T.copy(A[by * block_M, idx * block_K], A_shared)
# Parallel copy tile of B
for ko, j in T.Parallel(block_K, block_N):
B_shared[ko, j] = B[idx * block_K + ko, bx * block_N + j]
# Perform local GEMM on the shared-memory tiles
T.gemm(A_shared, B_shared, C_local)
# Copy the result tile back
T.copy(C_local, C[by * block_M, bx * block_N])
return main
```
**Key Differences vs. Basic Example**
1. **`T.annotate_layout(...)`**: Annotates how data should be organized in shared memory (swizzling).
2. **`T.use_swizzle(...)`**: Enables swizzle-based rasterization.
3. **Parallel Copy Loop** with `T.Parallel(...)`: Distributes global-to-shared copy across all threads, potentially vectorizing load/store instructions.
---
## Verifying Correctness
Once you compile and load your kernel into a runtime module (`rt_mod`), you can use tools like **PyTorch** to easily create random matrices on the GPU, run your TileLang kernel, and compare the results to a reference implementation (e.g., `torch.matmul` or `@` operator).
```python
import torch
# Suppose your compiled kernel is in rt_mod
profiler = Profiler(rt_mod, params, result_idx=[2])
A = torch.randn(1024, 1024).cuda().half()
B = torch.randn(1024, 1024).cuda().half()
C_tilelang = profiler(A, B)
C_ref = A @ B
torch.testing.assert_close(C_tilelang, C_ref, rtol=1e-2, atol=1e-2)
print("Results match!")
```
---
## Fine-grained MMA Computations
For advanced users who require full control over warp-level matrix multiplication operations, TileLang allows you to specify fine-grained MMA (Matrix Multiply-Accumulate) computations in a manner similar to writing raw CUDA. While higher-level abstractions like `T.gemm(...)` or automatic MMA emitters are sufficient for many use cases, specialized workloads (for example, dequantized GEMM, which may require fine-grained layout transformations in the shared-to-register stage) may benefit from explicit control over each MMA instruction, the data layout, and the synchronization points.
### Example Workflow
```python
import tilelang.language as T
# Imports assumed for this sketch; the complete example file later in this commit
# defines its own `make_swizzle_layout` helper instead of the alias used here.
from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter
from tilelang.transform import simplify_prim_func

@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
T.float16,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == T.int32:
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
chunk = 32
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
# Perform STMatrix
mma_emitter.stmatrix(
C_local,
C_shared,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
    return main
```
1. **Set Up Tile Sizes and Thread Bindings**
Just like in CUDA, you typically start by defining how many warps or threads per block you want and how the matrix is subdivided. In TileLang, this is done via `T.Kernel(...)` and `T.thread_binding(...)`, which ensure that the correct number of threads is active and that each thread is bound to a specific role (e.g., warp ID or lane ID).
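A short fragment mirroring the values used in the workflow example above (a sketch of the setup only; the warp- and lane-level roles are derived later, e.g. inside the MMA emitter):

```python
block_row_warps, block_col_warps = 2, 2            # 2 x 2 warps per block
warp_row_tiles, warp_col_tiles = 32, 32            # output elements each warp owns
block_M = block_row_warps * warp_row_tiles         # 64
block_N = block_col_warps * warp_col_tiles         # 64
threads = 32 * block_row_warps * block_col_warps   # 4 warps -> 128 threads

with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
    # Each block owns one (block_M, block_N) tile of C; warp and lane IDs are
    # derived from the thread index to assign per-thread roles.
    ...
```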
2. **Allocate Warp-local Fragments**
Instead of using a single shared buffer for partial sums, you allocate local buffers (register fragments) to hold sub-blocks of matrices \(A\) and \(B\). In TileLang, this is done with something like:
```python
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
```
Each of these `local` allocations represents a region of per-thread storage, which collectively forms the warp’s register tiles.
3. **Load Data via `ldmatrix`**
Fine-grained loading instructions allow you to specify exactly how data moves from shared memory to the warp-level fragments. In the example below, `mma_emitter.ldmatrix_a()` and `.ldmatrix_b()` are higher-level wrappers around warp-synchronous intrinsics. You can write your own load logic as well:
```python
for ki in T.serial(0, (block_K // micro_size_k)):
# Warp-synchronous load for A
mma_emitter.ldmatrix_a(A_local, A_shared, ki)
# Warp-synchronous load for B
mma_emitter.ldmatrix_b(B_local, B_shared, ki)
```
Internally, these calls orchestrate how each thread in the warp issues the correct load instructions, performs address calculations, and stores the data into registers.
4. **Perform the MMA Instruction**
After loading sub-tiles (fragments), the warp executes the `mma` instruction. This operation is essentially:
\[
C_{\text{local}} \;+=\; A_{\text{local}} \;\times\; B_{\text{local}}
\]
where each thread in the warp calculates a small portion of the final tile. For instance:
```python
mma_emitter.mma(A_local, B_local, C_local)
```
Under the hood, this translates into Tensor Core instructions (e.g., `mma.sync` in PTX), which process multiple data elements per warp in parallel.
5. **Store Results via `stmatrix`**
Finally, you write the results from the warp-level fragments back to shared memory or global memory. This step might happen multiple times in a loop or just once at the end. The code snippet:
```python
mma_emitter.stmatrix(C_local, C_shared)
```
orchestrates the warp-synchronous stores, ensuring each thread places the correct fragment element into the correct location of the shared or global buffer.
### Summary
By combining warp-synchronous intrinsics (`ldmatrix`, `mma`, `stmatrix`) with manual thread bindings and memory allocations, you can replicate the control and performance of raw CUDA at the TileLang level. This approach is best suited for expert users who are comfortable with GPU warp-level programming, since it does require a deep understanding of hardware concurrency, memory hierarchies, and scheduling. However, the payoff can be significant for performance-critical paths, where every byte of bandwidth and every cycle of latency must be carefully orchestrated.
---
## References
- [NVIDIA CUTLASS Library](https://github.com/NVIDIA/cutlass): A collection of high-performance CUDA C++ template abstractions for GEMM.
- [NVIDIA CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html): Official documentation for CUDA.
- [PyTorch Documentation](https://pytorch.org/docs): For verifying correctness via CPU or GPU-based matmul.
import tilelang
import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def gemm(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
return gemm
def main():
kernel = matmul(1024, 1024, 1024, 128, 128, 32)
import torch
a = torch.randn(1024, 1024).cuda().half()
b = torch.randn(1024, 1024).cuda().half()
c = kernel(a, b)
ref_c = a @ b
print("c:")
print(c)
print("ref_c:")
print(ref_c)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("All check passed.")
# Get CUDA Source
print("CUDA Source:")
print(kernel.get_kernel_source())
# benchmark
profiler = kernel.get_profiler()
latency = profiler.do_bench(backend="cupti")
# latency = profiler.do_bench()
print(f"tilelang Latency: {latency}ms")
if __name__ == "__main__":
main()
import argparse
import itertools
import tilelang as tl
import tilelang.language as T
from tilelang.autotuner import AutoTuner
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.carver.roller.rasterization import NoRasterization
import torch
def ref_program(A, B):
"""
Compute the matrix product of A and the transpose of B.
A and B are expected to be 2-D tensors where A has shape (M, K) and B has shape (N, K). The result is a tensor with shape (M, N) equal to A @ B.T, using the inputs' dtypes.
"""
return A @ B.T
def get_configs(M, N, K, with_roller=False, topk=20):
"""
Generate a list of kernel tuning configuration dictionaries for a tiled matrix-multiply.
When with_roller is True this queries the MatmulTemplate roller to produce up to `topk` recommended
configurations (device-specific TensorCore-friendly tilings). Each returned dict contains:
- block_M, block_N, block_K: tile sizes
- num_stages: pipeline staging (0 means no explicit staging)
- thread_num: total threads used for the block
- enable_rasteration: whether a rasterization/swizzle layout was recommended (note spelling)
When with_roller is False this returns the Cartesian product of a fixed set of candidate
parameters; the returned dicts use the backward-compatible key name "enable_rasteration" for that flag.
Parameters:
M, N, K (int): GEMM dimensions used to generate valid tile sizes.
with_roller (bool): If True, use MatmulTemplate's roller to generate device-aware hints;
otherwise use a predefined candidate grid.
topk (int): Maximum number of roller hints to request when with_roller is True.
Returns:
List[dict]: A list of configuration dictionaries as described above.
Raises:
ValueError: if with_roller is True but the roller returns no hints.
"""
if with_roller:
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
carve_template = MatmulTemplate(
M=M,
N=N,
K=K,
in_dtype=T.float16,
out_dtype=T.float16,
accum_dtype=T.float32,
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
roller_hints = carve_template.recommend_hints(topk=topk)
if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")
configs = []
for hint in roller_hints:
config = {}
block_m, block_n = hint.block
warp_m, warp_n = hint.warp
# block_rows, block_cols represent the warp partitioning
block_rows, block_cols = block_m // warp_m, block_n // warp_n
config["block_M"] = block_m
config["block_N"] = block_n
config["block_K"] = hint.rstep[0]
config["num_stages"] = hint.pipeline_stage if hint.pipeline_stage > 1 else 0
config["thread_num"] = block_rows * block_cols * 32
config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
configs.append(config)
else:
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasterization,
)
)
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat
}
for c in _configs
]
return configs
def get_best_config(M, N, K, with_roller=False):
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
enable_rasteration=None,
):
dtype = T.bfloat16
accum_dtype = T.float32
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
autotuner = (
AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller))
.set_compile_args(
out_idx=[-1],
target="auto",
)
.set_profile_args(
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=False,
)
)
return autotuner.run(warmup=3, rep=20)
def get_heuristic_config() -> dict:
# Get CUDA device properties
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
device = torch.cuda.current_device()
sm_major, sm_minor = torch.cuda.get_device_capability(device)
sm_version = sm_major * 10 + sm_minor
print(f"CUDA device capability: {sm_version}")
if sm_version in {80}:
return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True}
elif sm_version in {90}:
return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True}
else:
return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True}
@tl.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def gemm_autotune(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return gemm_autotune
def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False, with_roller: bool = False):
if use_autotune:
result = get_best_config(M, N, K, with_roller)
print(result.config)
kernel = result.kernel
else:
config = get_heuristic_config()
kernel = matmul(M, N, K, **config)
# benchmark
profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto)
tilelang_latency = profiler.do_bench()
ref_latency = profiler.do_bench(ref_program)
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
print(f"TileLang latency: {tilelang_latency}")
print(f"Ref latency: {ref_latency}")
print(f"TileLang TFlops: {2 * M * N * K / tilelang_latency * 1e-9}")
print(f"Ref TFlops: {2 * M * N * K / ref_latency * 1e-9}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K")
parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")
parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space")
args = parser.parse_args()
main(args.m, args.n, args.k, args.use_autotune, args.with_roller)
from tilelang import tvm as tvm
from tvm import DataType
import tilelang
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
def make_swizzle_layout(shared_buf):
dtype = shared_buf.dtype
shape = shared_buf.shape
can_swizzle = shape[-1] * DataType(dtype).bits == 512
if not can_swizzle:
return T.Layout(shape, lambda *args: args)
def transform_func(i, j):
new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
return [new_warp_i, new_warp_j]
return T.Layout(shape, transform_func)
@tilelang.jit(out_idx=[2])
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
T.float16,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == T.int32:
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 64
warp_col_tiles = 64
# chunk = 32 if in_dtype == T.float16 else 64
chunk = 32
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
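    # Stage C in shared memory as a 4D tile layout:
    # (row tile index, col tile index, row within the 16x16 micro tile, col within the micro tile)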
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def gemm_intrinsics(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(A_local, A_shared, ki)
# Load B into fragment
mma_emitter.ldmatrix_b(B_local, B_shared, ki)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
# Perform STMatrix
mma_emitter.stmatrix(C_local, C_shared)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
return gemm_intrinsics
def ref_program(A, B):
return A @ B.T
def main(M=4096, N=4096, K=4096):
in_dtype, out_dtype, accum_dtype = T.float16, T.float16, T.float32
kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
src_code = kernel.get_kernel_source()
# src_code is the generated cuda source
assert src_code is not None
profiler = kernel.get_profiler()
latency = profiler.do_bench(profiler.func, warmup=25)
print(latency)
# Ensure that the latency is not None
assert latency is not None
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
if __name__ == "__main__":
main(M=4096, N=4096, K=4096)
from tilelang import tvm as tvm
from tvm import DataType
import tilelang
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mmac_macro_generator import (
MatrixCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
from tilelang import disable_cache
disable_cache()
def make_swizzle_layout(shared_buf):
dtype = shared_buf.dtype
shape = shared_buf.shape
can_swizzle = shape[-1] * DataType(dtype).bits == 512
if not can_swizzle:
return T.Layout(shape, lambda *args: args)
def transform_func(i, j):
new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
return [new_warp_i, new_warp_j]
return T.Layout(shape, transform_func)
@tilelang.jit(out_idx=[2])
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 64
warp_col_tiles = 64
# chunk = 32 if in_dtype == "float16" else 64
chunk = 32
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 64
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMAC Wrapper to Auto Generate Code for MMAC
mmac_emitter = MatrixCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def gemm_intrinsics(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mmac_emitter.ldmatrix_a(A_local, A_shared, ki)
# Load B into fragment
mmac_emitter.ldmatrix_b(B_local, B_shared, ki)
# Perform Matrix Multiplication
mmac_emitter.mmac(A_local, B_local, C_local)
# Perform STMatrix
mmac_emitter.stmatrix(C_local, C_shared)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
j // micro_size_y,
i // micro_size_x,
i % micro_size_x,
j % micro_size_y,
]
return gemm_intrinsics
def ref_program(A, B):
return A @ B.T
def main():
M, N, K = 16384, 16384, 16384
in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32"
kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
src_code = kernel.get_kernel_source()
# src_code is the generated cuda source
assert src_code is not None
profiler = kernel.get_profiler()
latency = profiler.do_bench(profiler.func, warmup=25)
print(latency)
print(kernel.get_kernel_source())
# Ensure that the latency is not None
assert latency is not None
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
if __name__ == "__main__":
main()
import tilelang
import tilelang.language as T
from tilelang.carver.arch import driver
import argparse
@tilelang.jit(out_idx=[-1])
def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, num_stages, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(10)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[bx * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, by * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[bx * block_M, by * block_N])
return main
@tilelang.jit(out_idx=[-1])
def matmul_persistent(
M, N, K, block_M, block_N, block_K, threads, num_stages, dtype=T.float16, accum_dtype=T.float32, use_persistent_primitive=True
):
sm_num = driver.get_num_sms()
m_blocks = T.ceildiv(M, block_M)
n_blocks = T.ceildiv(N, block_N)
waves = T.ceildiv(m_blocks * n_blocks, sm_num)
group_size = 8
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(sm_num, threads=threads) as (block_id):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
for w in T.serial(waves):
tile_id = sm_num * w + block_id
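                # Remap the flat tile_id into block coordinates (bx, by): `group_size` consecutive
                # tiles share the same bx (the same row block of A and C), which improves L2 reuse.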
bx = (tile_id // group_size) % m_blocks
by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size
if bx * block_M < M and by * block_N < N:
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[bx * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, by * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[bx * block_M, by * block_N])
@T.prim_func
def main_persistent_primitive(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(sm_num, threads=threads) as (block_id):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
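            # T.Persistent distributes all (bx, by) output tiles across the sm_num resident blocks
            # (block_id selects this block's share), replacing the manual wave loop used above.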
for bx, by in T.Persistent([T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id):
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[bx * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, by * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[bx * block_M, by * block_N])
return main_persistent_primitive if use_persistent_primitive else main
def ref_program(A, B):
return A @ B
def main(M=4096, N=4096, K=4096):
total_flops = 2 * M * N * K
BLOCK_M = 128
BLOCK_N = 256
BLOCK_K = 64
threads = 256
num_stages = 3
persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
persistent_profiler = persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("Persistent GEMM: All check passed.")
persistent_latency = persistent_profiler.do_bench(warmup=500)
print(f"Persistent GEMM Latency: {persistent_latency} ms")
print(f"Persistent GEMM TFlops: {total_flops / persistent_latency * 1e-9} TFlops")
non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
non_persistent_profiler = non_persistent_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
non_persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("Non-Persistent GEMM: All check passed.")
non_persistent_latency = non_persistent_profiler.do_bench(warmup=500)
print(f"Non-Persistent GEMM Latency: {non_persistent_latency} ms")
print(f"Non-Persistent GEMM TFlops: {total_flops / non_persistent_latency * 1e-9} TFlops")
print(f"Persistent GEMM Speedup: {non_persistent_latency / persistent_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--M", type=int, default=8192, help="M dimension")
parser.add_argument("--N", type=int, default=8192, help="N dimension")
parser.add_argument("--K", type=int, default=8192, help="K dimension")
args = parser.parse_args()
M, N, K = args.M, args.N, args.K
main(M, N, K)
import tilelang
import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def gemm_schedule(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable rasterization for better L2 Cache Locality
T.use_swizzle(panel_size=10)
# Clear the local buffer
T.clear(C_local)
# Auto pipeline the computation
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[by * block_M, ko * block_K], A_shared)
# Instead of using
                # T.copy(B[ko * block_K, bx * block_N], B_shared)
# we can also use Parallel to auto map the thread
# bindings and vectorize the copy operation.
for k, j in T.Parallel(block_K, block_N):
B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
return gemm_schedule
def main():
kernel = matmul(1024, 1024, 1024, 128, 128, 32)
import torch
a = torch.randn(1024, 1024).cuda().half()
b = torch.randn(1024, 1024).cuda().half()
c = kernel(a, b)
ref_c = a @ b
print("c:")
print(c)
print("ref_c:")
print(ref_c)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("All check passed.")
# Get CUDA Source
print("CUDA Source:")
print(kernel.get_kernel_source())
if __name__ == "__main__":
main()
import tilelang.testing
import example_gemm_autotune
import example_gemm_intrinsics
import example_gemm_schedule
import example_gemm
def test_example_gemm_autotune():
# enable roller for fast tuning
example_gemm_autotune.main(M=1024, N=1024, K=1024, with_roller=True)
def test_example_gemm_intrinsics():
example_gemm_intrinsics.main(M=1024, N=1024, K=1024)
def test_example_gemm_schedule():
example_gemm_schedule.main()
def test_example_gemm():
example_gemm.main()
if __name__ == "__main__":
tilelang.testing.main()
**Notes**: FP8 GEMM is currently supported only through MMA intrinsics rather than `T.gemm`, because the CUTLASS version bundled with TileLang is too old; the bundled CUTLASS should be updated in the future.
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import torch_assert_close
import itertools
def ref_program(A, B):
return (A.half() @ B.half().T).to(dtype=torch.float32)
def manual_check_prog(C, C_ref):
torch_assert_close(C[0], C_ref[0], rtol=0.01, atol=0.1)
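# Custom input supplier for the autotuner: float8_e4m3fnuz inputs are built through a scaled
# float16 intermediate so the values stay in a representable range for the fp8 format.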
def supply_prog(args):
a_param, b_param = args
M, K = a_param.shape
N, _ = b_param.shape
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
return [a, b]
def get_configs():
block_Ms = [32, 64, 128]
block_Ns = [32, 64, 128]
block_Ks = [64, 128]
num_stages = [0]
num_threads = [256]
k_packs = [1, 2]
gemm_types = ["ss", "rs"]
valid_configs = []
for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types):
valid_configs.append(
{
"block_M": m,
"block_N": n,
"block_K": k,
"num_stages": stages,
"num_threads": t,
"k_pack": kp,
"gemm_type": gemm_type,
}
)
return valid_configs
@tilelang.autotune(
configs=get_configs(), cache_input_tensors=True, ref_prog=ref_program, manual_check_prog=manual_check_prog, supply_prog=supply_prog
)
@tilelang.jit(out_idx=[-1])
def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type):
dtype = T.float8_e4m3fnuz
accum_dtype = T.float32
@T.prim_func
def gemm_fp8_rs(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
A_local = T.alloc_fragment((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_local)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(A_local, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)
T.copy(C_local, C[by * block_M, bx * block_N])
@T.prim_func
def gemm_fp8_ss(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(A_shared, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)
T.copy(C_local, C[by * block_M, bx * block_N])
if gemm_type == "ss":
return gemm_fp8_ss
elif gemm_type == "rs":
return gemm_fp8_rs
else:
raise ValueError(f"Invalid gemm_type: {gemm_type}")
def test_gemm_fp8(M, N, K):
kernel = fp8_matmul(M, N, K)
a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz)
c = kernel(a, b)
ref_c = ref_program(a, b)
torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("passed~")
if __name__ == "__main__":
test_gemm_fp8(512, 512, 512)
import torch
import tilelang
import tilelang.language as T
def calc_diff(x, y):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32):
@T.prim_func
def gemm_fp8(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
T.copy(C_local, C[by * block_M, bx * block_N])
return gemm_fp8
def test_gemm_fp8(M, N, K, dtype):
torch_dtype = T.dtype(dtype).as_torch()
kernel = matmul(M, N, K, 128, 128, 64, dtype)
a = torch.randn(M, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype)
b = torch.randn(N, K, dtype=torch.float16, device="cuda").to(dtype=torch_dtype)
c = kernel(a, b)
ref_c = (a.half() @ b.half().T).to(dtype=torch_dtype)
print(c)
print(ref_c)
diff = calc_diff(c, ref_c)
print(f"diff: {diff}")
assert diff < 1e-3
def main():
test_gemm_fp8(1024, 1024, 1024, T.float8_e4m3fn)
test_gemm_fp8(1024, 1024, 1024, T.float8_e5m2)
if __name__ == "__main__":
main()
import torch
import tilelang
import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32):
# for fp8 gemm, do one promote after 4 wgmma inst, i.e. block_K = 128.
# if block_K < 128, promote after 128/block_K iters.
# if block_K > 128, promote after every iter.
update_interval = 128 // block_K if block_K < 128 else 1
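    # Example: block_K = 64 -> update_interval = 2 (promote every 2 iterations);
    #          block_K >= 128 -> update_interval = 1 (promote every iteration).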
@T.prim_func
def gemm_fp8_2xAcc(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
T.clear(C_local_accum)
K_iters = T.ceildiv(K, block_K)
for k in T.Pipelined(K_iters, num_stages=3):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
# Promote to enable 2xAcc
if (k + 1) % update_interval == 0:
for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j]
T.clear(C_local)
# Tail processing
if K_iters % update_interval != 0:
for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j]
# TMA store
T.copy(C_local_accum, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return gemm_fp8_2xAcc
def calc_diff(x, y):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
def test_gemm_fp8(M, N, K, dtype):
torch_dtype = T.dtype(dtype).as_torch()
kernel = matmul(M, N, K, 128, 128, 64, dtype)
a = torch.rand(M, K, dtype=torch.float16, device="cuda")
a = (100 * (2 * a - 1)).to(dtype=torch_dtype)
b = torch.rand(N, K, dtype=torch.float16, device="cuda")
b = (100 * (2 * b - 1)).to(dtype=torch_dtype)
c = kernel(a, b)
ref_c = a.float() @ b.float().T
diff = calc_diff(c, ref_c)
print(f"diff: {diff}")
assert diff < 1e-3
def main():
test_gemm_fp8(1024, 1024, 8192, T.float8_e4m3fn)
test_gemm_fp8(1024, 1024, 8192, T.float8_e5m2)
if __name__ == "__main__":
main()
import torch
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type
tilelang.testing.set_random_seed(0)
def make_swizzle_layout(shared_buf):
dtype = shared_buf.dtype
shape = shared_buf.shape
can_swizzle = shape[-1] * DataType(dtype).bits == 512
if not can_swizzle:
return T.Layout(shape, lambda *args: args)
def transform_func(i, j):
new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
return [new_warp_i, new_warp_j]
return T.Layout(shape, transform_func)
@tilelang.jit(out_idx=[2])
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
T.float16,
T.float8_e4m3fn,
T.float8_e5m2,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in [
T.float8_e4m3fn,
T.float8_e5m2,
        T.float8_e4m3fnuz,
T.float8_e5m2fnuz,
]
if out_dtype == T.int32 or is_float8:
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
chunk = 32 if in_dtype == T.float16 else 64
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def gemm_fp8_intrinsic(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
# Perform STMatrix
mma_emitter.stmatrix(
C_local,
C_shared,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
return gemm_fp8_intrinsic
def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
src_code = kernel.get_kernel_source()
print(src_code)
# src_code is the generated cuda source
assert src_code is not None
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
accum_dtype = map_torch_type(accum_dtype)
if in_dtype in {torch.int8, torch.int32}:
A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda()
B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda()
elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
A = torch.randn(M, K).to(in_dtype).cuda()
B = torch.randn(N, K).to(in_dtype).cuda()
else:
A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
B = torch.randn(N, K).to(in_dtype).cuda() - 0.5
C = torch.zeros(M, N, device="cuda", dtype=accum_dtype)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
C = profiler(A, B)
latency = profiler.do_bench(warmup=25)
# Ensure that the latency is not None
assert latency is not None
# Get Reference Result
ref_c = torch.matmul(A.to(accum_dtype), B.T.to(accum_dtype)).to(out_dtype)
print(C)
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
def main():
assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32)
assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32)
if __name__ == "__main__":
main()
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
mbar = T.alloc_barrier(1)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
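                # Launch TCGEN5MMA asynchronously into Tensor Memory (wg_wait=-1); the
                # mbarrier_wait_parity below waits for this iteration's MMA to complete
                # before the loop continues.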
T.gemm_v2(
A_shared,
B_shared,
C_tmem,
trans_A,
trans_B,
mbar=mbar,
wg_wait=-1,
clear_accum=(k == 0),
)
T.mbarrier_wait_parity(mbar, k % 2)
T.copy(C_tmem, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
def calc_diff(x, y):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 64, 256, 32
trans_A, trans_B = False, True
num_stages = 2
threads = 256
for tvm_fp8_dtype in [T.float8_e4m3fn, T.float8_e5m2]:
    for tvm_acc_dtype in [T.float16, T.float32]:
torch_fp8_dtype = map_torch_type(tvm_fp8_dtype)
torch_acc_dtype = map_torch_type(tvm_acc_dtype)
print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}")
in_dtype, out_dtype, accum_dtype = tvm_fp8_dtype, tvm_acc_dtype, tvm_acc_dtype
func = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
)
jit_kernel = tilelang.compile(
func,
out_idx=[2],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT: True,
},
)
# jit_kernel.export_ptx("./dump.ptx")
# jit_kernel.export_sources("./dump.cu")
a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype)
b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype)
c = jit_kernel(a, b)
ref_c = (a.to(torch.half) @ b.T.to(torch.half)).float()
c = c.float()
diff = calc_diff(c, ref_c)
# assert diff < 1e-3, f"{diff}"
print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] diff = {diff}")
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms")
print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS")
import tilelang.testing
import example_tilelang_gemm_fp8_2xAcc
import example_tilelang_gemm_fp8_intrinsic
import example_tilelang_gemm_fp8
def test_example_tilelang_gemm_fp8_2xAcc():
example_tilelang_gemm_fp8_2xAcc.main()
def test_example_tilelang_gemm_fp8_intrinsic():
example_tilelang_gemm_fp8_intrinsic.main()
def test_example_tilelang_gemm_fp8():
example_tilelang_gemm_fp8.main()
if __name__ == "__main__":
tilelang.testing.main()
# TileLang SM100 Support (Preview)
This directory contains examples for TileLang's experimental SM100 architecture support. **This is a preview version** with limited functionality.
## Current Limitations (Manual Implementation Required)
### 1. Manual TCGEN5MMA Management
Users must manually handle TCGEN5MMA operations using:
- `T.alloc_tmem()` - Allocate Tensor Memory
- `T.gemm()` with `wg_wait=-1` - Launch TCGEN5MMA without waiting
- Manual synchronization with mbarrier
### 2. Manual mbarrier Synchronization
TCGEN5MMA is asynchronous and requires explicit synchronization:
```python
mbar = T.alloc_barrier(1) # expect-arrive-count = 1
T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k==0)
T.mbarrier_wait_parity(mbar, k%2) # Manual phase calculation required
```
## Examples
### TCGEN5MMA Example (`gemm_tcgen5mma.py`)
Demonstrates TCGEN5MMA operations with:
- Tensor Memory allocation
- Manual mbarrier synchronization
- TCGEN5MMA gemm operations
### Traditional MMA Example (`gemm_mma.py`)
Shows standard MMA operations that work across architectures for comparison.
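For comparison, the sketch below shows what the conventional register-accumulator path looks like with `T.gemm`; tile sizes and dtypes are illustrative, and `gemm_mma.py` may differ in details:
```python
import tilelang
import tilelang.language as T

@tilelang.jit(out_idx=[-1])
def matmul_mma(M, N, K, block_M=128, block_N=128, block_K=64, dtype=T.bfloat16, accum_dtype=T.float32):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                # Regular MMA accumulation into registers; no Tensor Memory or mbarrier needed
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
            T.copy(C_local, C[by * block_M, bx * block_N])
    return main
```
Unlike the TCGEN5MMA version, the accumulator lives in registers (`C_local`), so no Tensor Memory allocation or mbarrier phase tracking is required.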
## Code Example
The following code is based on `gemm_tcgen5mma.py`, demonstrating TCGEN5MMA matrix multiplication:
```python
import torch
import tilelang
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor((M, K), T.bfloat16),
B: T.Tensor((N, K), T.bfloat16),
C: T.Tensor((M, N), T.bfloat16),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
# 1. Allocate memory buffers
A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) # A matrix shared memory
B_shared = T.alloc_shared((block_N, block_K), T.bfloat16) # B matrix shared memory
C_tmem = T.alloc_tmem([block_M, block_N], T.float) # TCGEN5MMA output to Tensor Memory
mbar = T.alloc_barrier(1) # mbarrier synchronization primitive
C_local = T.alloc_fragment((block_M, block_N), T.float) # Register storage
C_shared = T.alloc_shared((block_M, block_N), T.bfloat16) # Output shared memory
# 2. Main computation loop
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
# Data loading: global memory to shared memory
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
# TCGEN5MMA computation: asynchronous launch, output to Tensor Memory
T.gemm(A_shared, B_shared, C_tmem, trans_A=False, trans_B=True,
mbar=mbar, wg_wait=-1, clear_accum=k==0)
# Critical: wait for TCGEN5MMA completion
T.mbarrier_wait_parity(mbar, k%2)
# 3. Output processing (only subset of threads)
T.copy(C_tmem, C_local) # Tensor Memory → registers
T.copy(C_local, C_shared) # registers → shared memory
# 4. Write back to global memory
T.copy(C_shared, C[by * block_M, bx * block_N])
```
### Compilation and Usage
```python
# Parameter setup
M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 128, 256, 128
# Compile kernel
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, # Required
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # Required
})
# Run test
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
c = jit_kernel(a, b)
# Verify correctness
ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
# Performance benchmark
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Performance: {2 * M * N * K / (latency/1e3) / 1e12:.2f} TFLOPS")
```