Unverified Commit 667632cc authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents d6dd2ddf a874e4e8
import torch
import tilelang
from tilelang.utils.sparse import compress_sm90
from tilelang.layout import make_metadata_layout
from tilelang.layout import make_cutlass_metadata_layout
from tilelang import language as T
import tilelang.testing
......@@ -24,32 +25,24 @@ def matmul_sp(
A_shared_shape = (block_M, block_K // 2)
B_shared_shape = (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // 8), 'uint8'),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // 8), "uint8"),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // 8), 'uint8')
E_shared = T.alloc_shared((block_M, block_K // 8), "uint8")
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K),
E_shared:
make_metadata_layout(
E_shared,
mma_dtype="float16",
arch="9.0",
backend="cutlass",
block_k=block_K),
})
T.annotate_layout(
{
E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="9.0", block_k=block_K),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="9.0", block_k=block_K),
}
)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // 8], E_shared)
......@@ -61,7 +54,7 @@ def matmul_sp(
return main
def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'):
def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device="cpu"):
if shape[-1] % 4 != 0:
raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.")
......@@ -106,9 +99,9 @@ def run_gemm_sp(
num_threads,
)
A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device='cuda')
A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device="cuda")
A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False)
B = torch.randn((K, N), device='cuda', dtype=torch.float16)
B = torch.randn((K, N), device="cuda", dtype=torch.float16)
C_sp = kernel(A_sparse, E, B).half()
C = torch.matmul(A, B)
......@@ -117,7 +110,7 @@ def run_gemm_sp(
def main():
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 2, 128)
run_gemm_sp(512, 1024, 768, T.float16, T.float16, T.float32, 128, 128, 128, 2, 128)
if __name__ == "__main__":
......
......@@ -22,19 +22,19 @@ def tl_topk(
blk_m,
threads=128,
):
dtype = "float32"
dtype = T.float32
@T.prim_func
def topk_kernel(
logits: T.Tensor([M, N], dtype),
topk_gates: T.Tensor([M, topk], dtype),
topk_indices: T.Tensor([M, topk], "int32"),
logits: T.Tensor([M, N], dtype),
topk_gates: T.Tensor([M, topk], dtype),
topk_indices: T.Tensor([M, topk], T.int32),
):
with T.Kernel(T.ceildiv(M, blk_m), threads=threads) as bx:
logits_frag = T.alloc_fragment([blk_m, N], dtype=dtype)
max_val = T.alloc_fragment([blk_m], dtype=dtype)
expand_max_idx = T.alloc_fragment([blk_m, N], "int32")
max_idx = T.alloc_fragment([blk_m], "int32")
expand_max_idx = T.alloc_fragment([blk_m, N], T.int32)
max_idx = T.alloc_fragment([blk_m], T.int32)
T.copy(logits[bx * blk_m, 0], logits_frag)
......@@ -43,15 +43,12 @@ def tl_topk(
T.reduce_max(logits_frag, max_val, dim=1, clear=True)
for i, j in T.Parallel(blk_m, N):
expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j,
expand_max_idx[i, j])
expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j, expand_max_idx[i, j])
T.reduce_max(expand_max_idx, max_idx, dim=1, clear=True)
for i, j in T.Parallel(blk_m, N):
logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0,
logits_frag[i, j])
logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0, logits_frag[i, j])
for i in T.Parallel(blk_m):
topk_gates[bx * blk_m + i, k] = max_val[i]
......@@ -61,7 +58,6 @@ def tl_topk(
def ref_program(logits, top_k):
top_k_gates, top_k_indices = logits.topk(top_k, dim=1)
return top_k_gates, top_k_indices.to(torch.int32)
......
import tilelang
import tilelang.language as T
# use pass_configs to enable layout visualization
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE: True,
tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg",
},
)
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def gemm(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
return gemm
def main():
kernel = matmul(128, 128, 128, 32, 32, 32)
import torch
a = torch.randn(128, 128).cuda().half()
b = torch.randn(128, 128).cuda().half()
c = kernel(a, b)
ref_c = a @ b
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("All check passed.")
# print the layout visualization result and save figures to ./tmp.
"""
C_local inferenced layout:
Shape: [32, 32] -> [8]
Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2
Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
"""
if __name__ == "__main__":
main()
......@@ -9,9 +9,9 @@ import argparse
@tilelang.jit(out_idx=[6])
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e)
dtype = T.float16
accum_dtype = T.float32
kv_group_num = heads // kv_head_num
VALID_BLOCK_H = min(block_H, kv_group_num)
assert kv_head_num == 1, "kv_head_num must be 1"
......@@ -19,11 +19,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.macro
def flash_attn(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid):
# smem_sQ
......@@ -81,10 +81,12 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head = hid // (kv_group_num // block_H)
T.annotate_layout({
O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l),
O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r),
})
T.annotate_layout(
{
O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l),
O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r),
}
)
# barriers_Q
q_shared_ready_barrier = T.alloc_barrier(arrive_count=256)
......@@ -108,9 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
tx = T.get_thread_binding()
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l)
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l)
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.barrier_arrive(q_shared_ready_barrier)
T.barrier_wait(q_shared_ready_barrier, 0)
......@@ -123,25 +125,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.fill(acc_o_l, 0)
T.fill(logsum_0, 0)
T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
T.barrier_arrive(kv_shared_1_l_is_ready)
T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
T.barrier_arrive(kv_shared_1_r_is_ready)
T.copy(K_pe[bid, block_N:2 * block_N, cur_kv_head, :], K_pe_shared_1)
T.copy(K_pe[bid, block_N : 2 * block_N, cur_kv_head, :], K_pe_shared_1)
T.barrier_arrive(kv_shared_1_pe_is_ready)
for k in T.serial(loop_range):
T.barrier_wait(kv_shared_0_l_is_ready, k % 2)
T.gemm(
Q_shared_l,
KV_shared_0_l,
acc_s_0,
transpose_B=True,
clear_accum=True,
wg_wait=-1)
T.gemm(Q_shared_l, KV_shared_0_l, acc_s_0, transpose_B=True, clear_accum=True, wg_wait=-1)
T.barrier_wait(kv_shared_0_r_is_ready, k % 2)
T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1)
......@@ -161,8 +156,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for i, j in T.Parallel(block_H, block_N):
acc_s_0[i, j] = T.exp2(acc_s_0[i, j] * scale - scores_max[i] * scale)
for i in T.Parallel(block_H):
scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale -
scores_max[i] * scale)
scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s_0, scores_sum_0, dim=1)
......@@ -182,9 +176,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.barrier_wait(scale_1_ready_barrier, k % 2)
if k < loop_range - 1:
T.copy(
KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N,
cur_kv_head, :h_dim], KV_shared_0_l)
T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :h_dim], KV_shared_0_l)
T.barrier_arrive(kv_shared_0_l_is_ready)
# Step 11.
......@@ -204,15 +196,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.gemm(SP1_shared, KV_shared_1_l, acc_o_l)
if k < loop_range - 1:
T.copy(
KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N,
cur_kv_head, :h_dim], KV_shared_1_l)
T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
T.barrier_arrive(kv_shared_1_l_is_ready)
T.copy(
K_pe[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, :],
K_pe_shared_1)
T.copy(K_pe[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :], K_pe_shared_1)
T.barrier_arrive(kv_shared_1_pe_is_ready)
T.copy(logsum_0, logsum)
......@@ -221,8 +208,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for i, j in T.Parallel(block_H, h_dim):
acc_o_l[i, j] /= logsum[i]
T.copy(acc_o_l, O_shared_l)
T.copy(O_shared_l, Output[bid,
hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim])
T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim])
else:
T.copy(Q_pe_shared, Q_pe_local_1)
......@@ -237,16 +223,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.barrier_arrive(kv_shared_0_pe_is_ready)
for k in T.serial(loop_range):
# Step 2.
T.barrier_wait(kv_shared_1_l_is_ready, k % 2)
T.gemm(
Q_shared_l,
KV_shared_1_l,
acc_s_1,
transpose_B=True,
clear_accum=True,
wg_wait=-1)
T.gemm(Q_shared_l, KV_shared_1_l, acc_s_1, transpose_B=True, clear_accum=True, wg_wait=-1)
T.barrier_wait(kv_shared_1_r_is_ready, k % 2)
T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1)
......@@ -265,8 +244,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.copy(scores_max_1, scores_max)
for i in T.Parallel(block_H):
scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale -
scores_max[i] * scale)
scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale - scores_max[i] * scale)
# Step 8.
for i, j in T.Parallel(block_H, block_N):
......@@ -279,8 +257,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o_r[i, j] = acc_o_r[i, j] * (scores_scale_0[i] * scores_scale_1[i])
for i in T.Parallel(block_H):
logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[
i] + scores_sum_1[i]
logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[i] + scores_sum_1[i]
T.barrier_arrive(scale_1_ready_barrier)
......@@ -291,9 +268,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.barrier_arrive(s_shared_ready_barrier)
if k < loop_range - 1:
T.copy(
KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head,
h_dim:], KV_shared_1_r)
T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
T.barrier_arrive(kv_shared_1_r_is_ready)
T.barrier_wait(p0_1_1_ready_barrier, k % 2)
......@@ -301,15 +276,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.gemm(SP0_shared, KV_shared_0_r, acc_o_r)
if k < loop_range - 1:
T.copy(
KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head,
h_dim:], KV_shared_0_r)
T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, h_dim:], KV_shared_0_r)
T.barrier_arrive(kv_shared_0_r_is_ready)
T.copy(
K_pe[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, :],
K_pe_shared_0)
T.copy(K_pe[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :], K_pe_shared_0)
T.barrier_arrive(kv_shared_0_pe_is_ready)
T.barrier_wait(lse_0_ready_barrier, 0)
......@@ -319,18 +289,17 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for i, j in T.Parallel(block_H, h_dim):
acc_o_r[i, j] /= logsum[i]
T.copy(acc_o_r, O_shared_r)
T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
h_dim:])
T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:])
@T.prim_func
def main_no_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn(Q, Q_pe, KV, K_pe, Output)
......@@ -352,31 +321,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim = q.shape[-1]
pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim)**0.5
q = rearrange(
q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
scale = (dim + pe_dim) ** 0.5
q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
q_pe = rearrange(
q_pe, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim]
kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim]
k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim]
query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1)
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
......@@ -399,12 +361,12 @@ def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=128, help="q heads number")
parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
parser.add_argument("--dim", type=int, default=512, help="head dim")
parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
......@@ -7,8 +7,7 @@ tilelang.disable_cache()
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@tilelang.jit(out_idx=[2])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
num_stages = 2
mbarrier_list = [128, 128] * num_stages
......@@ -32,19 +31,13 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
for ko in range(T.ceildiv(K, block_K)):
with T.ws(1):
T.mbarrier_wait_parity(
mbarrier=ko % num_stages + num_stages,
parity=((ko // num_stages) % num_stages) ^ 1)
T.copy(A[by * block_M:(by + 1) * block_M, ko * block_K:(ko + 1) * block_K],
A_shared[ko % num_stages, :, :])
T.copy(B[ko * block_K:(ko + 1) * block_K, bx * block_N:(bx + 1) * block_N],
B_shared[ko % num_stages, :, :])
T.mbarrier_wait_parity(mbarrier=ko % num_stages + num_stages, parity=((ko // num_stages) % num_stages) ^ 1)
T.copy(A[by * block_M : (by + 1) * block_M, ko * block_K : (ko + 1) * block_K], A_shared[ko % num_stages, :, :])
T.copy(B[ko * block_K : (ko + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[ko % num_stages, :, :])
T.mbarrier_arrive(mbarrier=ko % num_stages)
with T.ws(0):
T.mbarrier_wait_parity(
mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages)
T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :],
C_local)
T.mbarrier_wait_parity(mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages)
T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :], C_local)
T.mbarrier_arrive(mbarrier=ko % num_stages + num_stages)
with T.ws(0):
......
......@@ -5,20 +5,12 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@tilelang.jit(out_idx=[2])
def matmul_warp_specialize_copy_0_gemm_1(M,
N,
K,
block_M,
block_N,
block_K,
dtype="float16",
accum_dtype="float"):
def matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
......
......@@ -5,20 +5,12 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@tilelang.jit(out_idx=[2])
def matmul_warp_specialize_copy_1_gemm_0(M,
N,
K,
block_M,
block_N,
block_K,
dtype="float16",
accum_dtype="float"):
def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
......
......@@ -5,26 +5,20 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@tilelang.jit(
out_idx=[2], pass_configs={
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
})
def matmul_warp_specialize_copy_1_gemm_0(M,
N,
K,
block_M,
block_N,
block_K,
dtype="float16",
accum_dtype="float"):
},
)
def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
warp_group_num = 2
threads = 128 * warp_group_num
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
......
......@@ -5,8 +5,7 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@tilelang.jit(out_idx=[2])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor[(M, K), dtype],
......
......@@ -9,7 +9,7 @@
# bash format.sh --all
#
#
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# Ruff (format) + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
# Cause the script to exit if a single command fails
......
......@@ -2,6 +2,8 @@
import pytest
from tilelang import tvm as tvm
import tilelang.testing
from tilelang import language as T
import torch
def matmul(
......@@ -24,13 +26,11 @@ def matmul(
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
......@@ -66,20 +66,19 @@ def _compile_and_check(
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
})
},
)
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
if in_dtype == T.float32:
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
......@@ -147,13 +146,11 @@ def matmul_rs(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
......@@ -234,13 +231,11 @@ def matmul_sr(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
B_frag_shape = B_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
......@@ -322,13 +317,11 @@ def matmul_rr(
A_frag_shape = A_shared_shape
B_frag_shape = B_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
......@@ -394,37 +387,48 @@ M_VALUES = [64, 128, 256]
N_VALUES = [16, 32, 64, 128, 256, 512]
K_VALUES = [16, 32, 64, 128]
K_VALUES_8Bit = [32, 64, 128]
FALSE_TRUE_CASES = ([
pytest.param(
k,
"float16",
"float16",
"float16",
id=f"K{k}-float16-float16-float16",
) for k in K_VALUES
] + [pytest.param(
k,
"int8",
"int32",
"int32",
id="K32-int8-int32-int32",
) for k in K_VALUES_8Bit] + [
pytest.param(
k,
"float8_e5m2",
"float32",
"float32",
id="K32-float8_e5m2-float32-float32",
) for k in K_VALUES_8Bit
] + [
pytest.param(
k,
"float8_e4m3",
"float32",
"float32",
id="K32-float8_e4m3-float32-float32",
) for k in K_VALUES_8Bit
])
FALSE_TRUE_CASES = (
[
pytest.param(
k,
T.float16,
T.float16,
T.float16,
id=f"K{k}-float16-float16-float16",
)
for k in K_VALUES
]
+ [
pytest.param(
k,
T.int8,
T.int32,
T.int32,
id="K32-int8-int32-int32",
)
for k in K_VALUES_8Bit
]
+ [
pytest.param(
k,
T.float8_e5m2,
T.float32,
T.float32,
id="K32-float8_e5m2-float32-float32",
)
for k in K_VALUES_8Bit
]
+ [
pytest.param(
k,
T.float8_e4m3fn,
T.float32,
T.float32,
id="K32-float8_e4m3-float32-float32",
)
for k in K_VALUES_8Bit
]
)
def _ensure_torch_dtypes(*dtype_names):
......@@ -440,15 +444,15 @@ def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
def run_gemm_rs_false_false(m, n, k):
run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
run_gemm_rs(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k)
def run_gemm_rs_true_false(m, n, k):
run_gemm_rs(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
run_gemm_rs(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k)
def run_gemm_rs_true_true(m, n, k):
run_gemm_rs(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
run_gemm_rs(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k)
def run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
......@@ -456,15 +460,15 @@ def run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
def run_gemm_sr_false_false(m, n, k):
run_gemm_sr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
run_gemm_sr(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k)
def run_gemm_sr_true_false(m, n, k):
run_gemm_sr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
run_gemm_sr(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k)
def run_gemm_sr_true_true(m, n, k):
run_gemm_sr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
run_gemm_sr(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k)
def run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
......@@ -472,15 +476,15 @@ def run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
def run_gemm_rr_false_false(m, n, k):
run_gemm_rr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
run_gemm_rr(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k)
def run_gemm_rr_true_false(m, n, k):
run_gemm_rr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
run_gemm_rr(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k)
def run_gemm_rr_true_true(m, n, k):
run_gemm_rr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
run_gemm_rr(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k)
TRANS_CASES = [
......@@ -536,9 +540,9 @@ def test_gemm_false_false(m, n, k):
k * 3,
False,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
m,
n,
k,
......@@ -555,9 +559,9 @@ def test_gemm_true_false(m, n, k):
k * 3,
True,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
m,
n,
k,
......@@ -574,9 +578,9 @@ def test_gemm_true_true(m, n, k):
k * 3,
True,
True,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
m,
n,
k,
......@@ -595,7 +599,7 @@ def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_false_false(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_rs_false_false(m, n, k)
......@@ -603,7 +607,7 @@ def test_gemm_rs_false_false(m, n, k):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_true_false(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_rs_true_false(m, n, k)
......@@ -611,7 +615,7 @@ def test_gemm_rs_true_false(m, n, k):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_true_true(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_rs_true_true(m, n, k)
......@@ -627,7 +631,7 @@ def test_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_sr_false_false(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_sr_false_false(m, n, k)
......@@ -635,7 +639,7 @@ def test_gemm_sr_false_false(m, n, k):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_sr_true_false(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_sr_true_false(m, n, k)
......@@ -643,7 +647,7 @@ def test_gemm_sr_true_false(m, n, k):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_sr_true_true(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_sr_true_true(m, n, k)
......@@ -659,7 +663,7 @@ def test_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rr_false_false(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_rr_false_false(m, n, k)
......@@ -667,7 +671,7 @@ def test_gemm_rr_false_false(m, n, k):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rr_true_false(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_rr_true_false(m, n, k)
......@@ -675,7 +679,7 @@ def test_gemm_rr_true_false(m, n, k):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rr_true_true(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_rr_true_true(m, n, k)
......@@ -687,7 +691,7 @@ if __name__ == "__main__":
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# run_gemm(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
......@@ -695,7 +699,7 @@ if __name__ == "__main__":
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} False False =============================")
# run_gemm(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
# run_gemm(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
......@@ -703,7 +707,7 @@ if __name__ == "__main__":
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} True False =============================")
# run_gemm(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k, 2, 128)
# run_gemm(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k, 2, 128)
# print(f"Test {m}, {n} {k} Pass")
# print(f"Test {n} Pass")
......@@ -712,7 +716,7 @@ if __name__ == "__main__":
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} True True =============================")
# run_gemm(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k, 2, 128)
# run_gemm(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k, 2, 128)
# print(f"Test {m}, {n} {k} Pass")
# print(f"Test {n} Pass")
......@@ -721,15 +725,15 @@ if __name__ == "__main__":
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm_rs(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# run_gemm_rs(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# run_gemm_rs(64, n, k, False, False, "float16", "float16", "float16", 64, n, k, 0, 256)
# run_gemm_rs(64, n, k, False, False, T.float16, T.float16, T.float16, 64, n, k, 0, 256)
# print(f"Test {64} {n} {k} Pass")
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# run_gemm(64, n, k, False, False, "float16", "float16", "float16", 64, n, k, 0, 256)
# run_gemm(64, n, k, False, False, T.float16, T.float16, T.float16, 64, n, k, 0, 256)
# print(f"Test {64} {n} {k} Pass")
......@@ -2,6 +2,7 @@
import pytest
from tilelang import tvm as tvm
import tilelang.testing
from tilelang import language as T
def matmul(
......@@ -24,13 +25,11 @@ def matmul(
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
......@@ -67,7 +66,8 @@ def _compile_and_check(
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
})
},
)
print(kernel.get_kernel_source())
......@@ -80,7 +80,7 @@ def _compile_and_check(
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
if in_dtype == T.float32:
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
......@@ -146,13 +146,11 @@ def matmul_rs(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
......@@ -213,23 +211,25 @@ def run_gemm_rs(
M_VALUES = [64, 128]
N_VALUES = [32, 64, 128]
K_VALUES = [16, 32, 64]
FALSE_TRUE_CASES = ([
FALSE_TRUE_CASES = [
pytest.param(
k,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
id=f"K{k}-float16-float16-float16",
) for k in K_VALUES
)
for k in K_VALUES
] + [
pytest.param(
k,
"float16",
"float16",
"float32",
T.float16,
T.float16,
T.float32,
id=f"K{k}-float16-float16-float32",
) for k in K_VALUES
])
)
for k in K_VALUES
]
def _ensure_torch_dtypes(*dtype_names):
......@@ -245,7 +245,7 @@ def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
def run_gemm_rs_false_false(m, n, k):
run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
run_gemm_rs(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128)
TRANS_CASES = [
......@@ -303,9 +303,9 @@ def test_gemm_false_false(m, n, k):
k * 3,
False,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
m,
n,
k,
......@@ -326,7 +326,7 @@ def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_false_false(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_rs_false_false(m, n, k)
......@@ -338,7 +338,7 @@ if __name__ == "__main__":
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64]:
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# run_gemm(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
......@@ -346,5 +346,5 @@ if __name__ == "__main__":
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64]:
# print(f"======================= Test {m} {n} {k} False False =============================")
# run_gemm(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
# run_gemm(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
......@@ -27,9 +27,9 @@ def matmul(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -42,15 +42,7 @@ def matmul(
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_tmem,
trans_A,
trans_B,
mbar=mbar,
wg_wait=-1,
clear_accum=k == 0)
T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0)
T.mbarrier_wait_parity(mbar, k % 2)
T.copy(C_tmem, C_local)
......@@ -74,7 +66,8 @@ def _compile_and_check(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
print(kernel.get_kernel_source())
......@@ -87,7 +80,7 @@ def _compile_and_check(
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
if in_dtype == T.float32:
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
......@@ -138,23 +131,25 @@ M_VALUES = [32, 64, 128, 256]
N_VALUES = [64, 128, 256, 512]
K_VALUES = [16, 32, 64, 128]
K_VALUES_8Bit = [32, 64, 128]
FALSE_TRUE_CASES = ([
FALSE_TRUE_CASES = [
pytest.param(
k,
"float16",
"float32",
"float32",
T.float16,
T.float32,
T.float32,
id=f"K{k}-float16-float-float",
) for k in K_VALUES
)
for k in K_VALUES
] + [
pytest.param(
k,
"float8_e5m2",
"float32",
"float32",
T.float8_e5m2,
T.float32,
T.float32,
id="K32-float8_e5m2-float32-float32",
) for k in K_VALUES_8Bit
])
)
for k in K_VALUES_8Bit
]
TRANS_CASES = [
pytest.param(False, True, id="nt"),
......@@ -191,7 +186,7 @@ def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
if __name__ == "__main__":
# tilelang.testing.main()
tilelang.testing.main()
# # Test Pass
# for m in [32, 64, 128, 256]:
......@@ -200,27 +195,24 @@ if __name__ == "__main__":
# if m in [32, 64] and (n not in [64, 128, 256]):
# continue
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float", "float", m, n, k, 2, 128)
# run_gemm(m, n, k * 3, False, True, T.float16, T.float, T.float, m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
# for m in [32, 64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [32, 64, 128]:
# for n in [32, 64, 128]:
# for k in [16, 32, 64, 128]:
# if m in [32, 64] and (n not in [64, 128, 256]):
# continue
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float8_e5m2", "float", "float", m, n, k, 2, 128)
# run_gemm(m, n, k * 3, False, True, T.float16, T.float, T.float, m, n, k, 2, 256)
# print(f"Test {m} {n} {k} Pass")
tilelang.disable_cache()
run_gemm(32, 512, 16, False, True, "float16", "float32", "float32", 32, 512, 16, 0, 128)
run_gemm(32, 512, 32, False, True, "float16", "float32", "float32", 32, 512, 32, 0, 128)
run_gemm(32, 512, 64, False, True, "float16", "float32", "float32", 32, 512, 64, 0, 128)
run_gemm(64, 512, 16, False, True, "float16", "float32", "float32", 64, 512, 16, 0, 128)
run_gemm(64, 512, 16, False, True, "float16", "float32", "float32", 32, 512, 16, 0, 128)
run_gemm(128, 512, 16, False, True, "float16", "float32", "float32", 128, 512, 16, 0, 128)
# run_gemm(64, 512, 32, False, True, "float16", "float32", "float32", 64, 512, 32, 0, 128)
# run_gemm(64, 512, 64, False, True, "float16", "float32", "float32", 64, 512, 64, 0, 128)
# run_gemm(128, 512, 16, False, True, "float16", "float32", "float32", 128, 512, 16, 0, 128)
# # Test Pass
# for m in [32, 64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [32, 64, 128]:
# if m in [32, 64] and (n not in [64, 128, 256]):
# continue
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, T.float8_e5m2, T.float, T.float, m, n, k, 2, 128)
......@@ -13,13 +13,12 @@ use_v2 = args.use_v2
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
......
......@@ -13,13 +13,12 @@ use_v2 = args.use_v2
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
......
......@@ -8,13 +8,13 @@ import argparse
from functools import partial
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=16, help='heads')
parser.add_argument('--seq_q', type=int, default=1024, help='query sequence length')
parser.add_argument('--seq_kv', type=int, default=1024, help='key/value sequence length')
parser.add_argument('--dim', type=int, default=256, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--batch", type=int, default=128, help="batch size")
parser.add_argument("--heads", type=int, default=16, help="heads")
parser.add_argument("--seq_q", type=int, default=1024, help="query sequence length")
parser.add_argument("--seq_kv", type=int, default=1024, help="key/value sequence length")
parser.add_argument("--dim", type=int, default=256, help="dim")
parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument("--tune", action="store_true", help="tune configs")
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()
......@@ -29,24 +29,17 @@ def get_configs():
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3], pass_configs={
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
},
)
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
......@@ -62,7 +55,7 @@ def flashattn(batch,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
......@@ -85,7 +78,7 @@ def flashattn(batch,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
# T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if use_v2:
T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
......@@ -94,13 +87,13 @@ def flashattn(batch,
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -125,18 +118,18 @@ def flashattn(batch,
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
......@@ -152,43 +145,42 @@ def flashattn(batch,
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv(
(bx + 1) * block_M +
past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
if is_causal
else T.ceildiv(seq_kv, block_N)
)
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
def ref_program(Q, K, V, is_causal):
    """Reference (eager PyTorch) attention used to validate the fused kernel.

    Args:
        Q: query tensor, shape (batch, heads, seq_q, dim).
        K: key tensor, shape (batch, heads, seq_kv, dim).
        V: value tensor, shape (batch, heads, seq_kv, dim).
        is_causal: when True, mask out keys "in the future" of each query.

    Returns:
        Attention output, shape (batch, heads, seq_q, dim).
    """
    # NOTE: the merged diff left both the pre- and post-image of several
    # statements (duplicate einsum/masked_fill lines); only one copy is kept.
    dim = Q.size(-1)
    scores = torch.einsum("bhqd,bhkd->bhqk", Q, K)
    # Standard 1/sqrt(d) scaling before the softmax.
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_q = Q.size(2)
        seq_kv = K.size(2)
        # Lower-triangular mask shifted by (seq_kv - seq_q) so the query
        # window aligns with the tail of the key/value sequence.
        mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V)
    return output
......@@ -206,18 +198,8 @@ def main(
if is_causal:
total_flops *= 0.5
if (not tune):
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=128)
if not tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128)
print(kernel.get_kernel_source())
ref_program_processed = partial(ref_program, is_causal=is_causal)
......
"""Reproduce: Argument count mismatch.
Note: The adapter-level wrapper expects only inputs (A, B) because C is marked as output.
Calling with the wrong number of inputs raises a ValueError before host entry.
"""
import torch
from common import build_matmul_kernel
def main():
M = N = K = 256
fn = build_matmul_kernel(M, N, K, target="cuda")
a = torch.empty((M, K), device="cuda", dtype=torch.float16)
# Missing b
# Expected: ValueError with message about expected vs. actual inputs
fn(a)
if __name__ == "__main__":
main()
"""Reproduce: Pointer-type argument expected but scalar provided.
We pass an integer for A; wrapper forwards it to the host where a pointer is expected.
Expected: error like "Expect buffer A_handle to be pointer or tensor" (exact name depends on kernel param).
"""
import torch
from common import build_matmul_kernel
def main():
M = N = K = 256
fn = build_matmul_kernel(M, N, K, target="cuda")
# Wrong type for A (int instead of tensor)
a = 1
b = torch.empty((K, N), device="cuda", dtype=torch.float16)
fn(a, b)
if __name__ == "__main__":
main()
"""Reproduce: ndim (rank) mismatch for A."""
import torch
from common import build_matmul_kernel
def main():
M = N = K = 128
fn = build_matmul_kernel(M, N, K, target="cuda")
# A has rank 3 instead of 2
a = torch.empty((M, K, 1), device="cuda", dtype=torch.float16)
b = torch.empty((K, N), device="cuda", dtype=torch.float16)
fn(a, b)
if __name__ == "__main__":
main()
"""Reproduce: dtype mismatch for A (float32 vs expected float16)."""
import torch
from common import build_matmul_kernel
def main():
M = N = K = 128
fn = build_matmul_kernel(M, N, K, target="cuda")
print(fn.get_host_source())
a = torch.empty((M, K), device="cuda", dtype=torch.float32) # should be float16
b = torch.empty((K, N), device="cuda", dtype=torch.float16)
fn(a, b)
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment