Unverified commit 29051439, authored by Lei Wang, committed by GitHub

[Lint] Phase out Yapf format and embrace Ruff format (#1417)

parent e84b24bc
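The diff below is formatting-only: ruff format rewrites the same statements with different wrapping and quoting conventions than Yapf, so no behavior changes. As a quick orientation, here is a small self-contained Python sketch of the conventions that recur throughout the diff (the variables are illustrative, not taken from the codebase):

# Illustrative before/after of the style differences between yapf and ruff format seen in this diff.
dim, pe_dim = 512, 64

# yapf (before): no spaces around "**", single-quoted strings, no spaces around ":" in computed slices.
scale_yapf = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
label_yapf = 'batch size'
xs = list(range(16))
window_yapf = xs[2 * 2:2 * 3]

# ruff format (after): spaces around "**", double-quoted strings, spaces around ":" in non-trivial
# slices, and calls collapsed onto one line when they fit within the line-length limit.
scale_ruff = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504  # log2(e)
label_ruff = "batch size"
window_ruff = xs[2 * 2 : 2 * 3]

assert (scale_yapf, label_yapf, window_yapf) == (scale_ruff, label_ruff, window_ruff)  # formatting only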
@@ -9,7 +9,7 @@ import argparse
@tilelang.jit(out_idx=[6])
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
-scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
+scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504  # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // kv_head_num
@@ -19,11 +19,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.macro
def flash_attn(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid):
# smem_sQ
@@ -81,10 +81,12 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head = hid // (kv_group_num // block_H)
-T.annotate_layout({
-O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l),
-O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r),
-})
+T.annotate_layout(
+{
+O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l),
+O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r),
+}
+)
# barriers_Q
q_shared_ready_barrier = T.alloc_barrier(arrive_count=256)
@@ -108,9 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
tx = T.get_thread_binding()
-T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l)
-T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r)
-T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
+T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l)
+T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r)
+T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.barrier_arrive(q_shared_ready_barrier)
T.barrier_wait(q_shared_ready_barrier, 0)
@@ -123,25 +125,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.fill(acc_o_l, 0)
T.fill(logsum_0, 0)
-T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
+T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
T.barrier_arrive(kv_shared_1_l_is_ready)
-T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
+T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
T.barrier_arrive(kv_shared_1_r_is_ready)
-T.copy(K_pe[bid, block_N:2 * block_N, cur_kv_head, :], K_pe_shared_1)
+T.copy(K_pe[bid, block_N : 2 * block_N, cur_kv_head, :], K_pe_shared_1)
T.barrier_arrive(kv_shared_1_pe_is_ready)
for k in T.serial(loop_range):
T.barrier_wait(kv_shared_0_l_is_ready, k % 2)
-T.gemm(
-Q_shared_l,
-KV_shared_0_l,
-acc_s_0,
-transpose_B=True,
-clear_accum=True,
-wg_wait=-1)
+T.gemm(Q_shared_l, KV_shared_0_l, acc_s_0, transpose_B=True, clear_accum=True, wg_wait=-1)
T.barrier_wait(kv_shared_0_r_is_ready, k % 2)
T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1)
@@ -161,8 +156,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for i, j in T.Parallel(block_H, block_N):
acc_s_0[i, j] = T.exp2(acc_s_0[i, j] * scale - scores_max[i] * scale)
for i in T.Parallel(block_H):
-scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale -
-scores_max[i] * scale)
+scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s_0, scores_sum_0, dim=1)
@@ -182,9 +176,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.barrier_wait(scale_1_ready_barrier, k % 2)
if k < loop_range - 1:
-T.copy(
-KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N,
-cur_kv_head, :h_dim], KV_shared_0_l)
+T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :h_dim], KV_shared_0_l)
T.barrier_arrive(kv_shared_0_l_is_ready)
# Step 11.
@@ -204,15 +196,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.gemm(SP1_shared, KV_shared_1_l, acc_o_l)
if k < loop_range - 1:
-T.copy(
-KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N,
-cur_kv_head, :h_dim], KV_shared_1_l)
+T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
T.barrier_arrive(kv_shared_1_l_is_ready)
-T.copy(
-K_pe[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, :],
-K_pe_shared_1)
+T.copy(K_pe[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :], K_pe_shared_1)
T.barrier_arrive(kv_shared_1_pe_is_ready)
T.copy(logsum_0, logsum)
@@ -221,8 +208,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for i, j in T.Parallel(block_H, h_dim):
acc_o_l[i, j] /= logsum[i]
T.copy(acc_o_l, O_shared_l)
-T.copy(O_shared_l, Output[bid,
-hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim])
+T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim])
else:
T.copy(Q_pe_shared, Q_pe_local_1)
@@ -237,16 +223,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.barrier_arrive(kv_shared_0_pe_is_ready)
for k in T.serial(loop_range):
# Step 2.
T.barrier_wait(kv_shared_1_l_is_ready, k % 2)
-T.gemm(
-Q_shared_l,
-KV_shared_1_l,
-acc_s_1,
-transpose_B=True,
-clear_accum=True,
-wg_wait=-1)
+T.gemm(Q_shared_l, KV_shared_1_l, acc_s_1, transpose_B=True, clear_accum=True, wg_wait=-1)
T.barrier_wait(kv_shared_1_r_is_ready, k % 2)
T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1)
@@ -265,8 +244,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.copy(scores_max_1, scores_max)
for i in T.Parallel(block_H):
-scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale -
-scores_max[i] * scale)
+scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale - scores_max[i] * scale)
# Step 8.
for i, j in T.Parallel(block_H, block_N):
@@ -279,8 +257,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o_r[i, j] = acc_o_r[i, j] * (scores_scale_0[i] * scores_scale_1[i])
for i in T.Parallel(block_H):
-logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[
-i] + scores_sum_1[i]
+logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[i] + scores_sum_1[i]
T.barrier_arrive(scale_1_ready_barrier)
@@ -291,9 +268,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.barrier_arrive(s_shared_ready_barrier)
if k < loop_range - 1:
-T.copy(
-KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head,
-h_dim:], KV_shared_1_r)
+T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
T.barrier_arrive(kv_shared_1_r_is_ready)
T.barrier_wait(p0_1_1_ready_barrier, k % 2)
@@ -301,15 +276,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.gemm(SP0_shared, KV_shared_0_r, acc_o_r)
if k < loop_range - 1:
-T.copy(
-KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head,
-h_dim:], KV_shared_0_r)
+T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, h_dim:], KV_shared_0_r)
T.barrier_arrive(kv_shared_0_r_is_ready)
-T.copy(
-K_pe[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, :],
-K_pe_shared_0)
+T.copy(K_pe[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :], K_pe_shared_0)
T.barrier_arrive(kv_shared_0_pe_is_ready)
T.barrier_wait(lse_0_ready_barrier, 0)
@@ -319,18 +289,17 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for i, j in T.Parallel(block_H, h_dim):
acc_o_r[i, j] /= logsum[i]
T.copy(acc_o_r, O_shared_r)
-T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
-h_dim:])
+T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:])
@T.prim_func
def main_no_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn(Q, Q_pe, KV, K_pe, Output)
@@ -352,31 +321,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim = q.shape[-1]
pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2]
-scale = (dim + pe_dim)**0.5
+scale = (dim + pe_dim) ** 0.5
-q = rearrange(
-q, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]
+q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]
-q_pe = rearrange(
-q_pe, 'b (h g) d -> b g h d',
-g=num_head_groups)  # [batch_size, num_head_groups, groups, pe_dim]
+q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, groups, pe_dim]
-kv = rearrange(kv, 'b n h d -> b h n d')  # [batch_size, groups, seqlen_kv, dim]
+kv = rearrange(kv, "b n h d -> b h n d")  # [batch_size, groups, seqlen_kv, dim]
-k_pe = rearrange(k_pe, 'b n h d -> b h n d')  # [batch_size, num_head_groups, groups, pe_dim]
+k_pe = rearrange(k_pe, "b n h d -> b h n d")  # [batch_size, num_head_groups, groups, pe_dim]
query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1)
-scores = einsum(
-query, key,
-'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, groups, seqlen_kv]
+scores = einsum(query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, groups, seqlen_kv]
-attention = F.softmax(
-scores / scale, dim=-1)  # [batch_size, num_head_groups, groups, seqlen_kv]
+attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, groups, seqlen_kv]
-out = einsum(attention, kv,
-'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, groups, dim]
-out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+out = einsum(attention, kv, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, groups, dim]
+out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
return out
@@ -399,12 +361,12 @@ def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
-parser.add_argument('--batch', type=int, default=1, help='batch size')
-parser.add_argument('--heads', type=int, default=128, help='q heads number')
-parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
-parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
-parser.add_argument('--dim', type=int, default=512, help='head dim')
-parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
+parser.add_argument("--batch", type=int, default=1, help="batch size")
+parser.add_argument("--heads", type=int, default=128, help="q heads number")
+parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
+parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
+parser.add_argument("--dim", type=int, default=512, help="head dim")
+parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
@@ -8,7 +8,6 @@ tilelang.disable_cache()
# @tilelang.jit
@tilelang.jit(out_idx=[2])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
num_stages = 2
mbarrier_list = [128, 128] * num_stages
@@ -32,19 +31,13 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
for ko in range(T.ceildiv(K, block_K)):
with T.ws(1):
-T.mbarrier_wait_parity(
-mbarrier=ko % num_stages + num_stages,
-parity=((ko // num_stages) % num_stages) ^ 1)
-T.copy(A[by * block_M:(by + 1) * block_M, ko * block_K:(ko + 1) * block_K],
-A_shared[ko % num_stages, :, :])
-T.copy(B[ko * block_K:(ko + 1) * block_K, bx * block_N:(bx + 1) * block_N],
-B_shared[ko % num_stages, :, :])
+T.mbarrier_wait_parity(mbarrier=ko % num_stages + num_stages, parity=((ko // num_stages) % num_stages) ^ 1)
+T.copy(A[by * block_M : (by + 1) * block_M, ko * block_K : (ko + 1) * block_K], A_shared[ko % num_stages, :, :])
+T.copy(B[ko * block_K : (ko + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[ko % num_stages, :, :])
T.mbarrier_arrive(mbarrier=ko % num_stages)
with T.ws(0):
-T.mbarrier_wait_parity(
-mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages)
-T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :],
-C_local)
+T.mbarrier_wait_parity(mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages)
+T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :], C_local)
T.mbarrier_arrive(mbarrier=ko % num_stages + num_stages)
with T.ws(0):
...
@@ -5,20 +5,12 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@tilelang.jit(out_idx=[2])
-def matmul_warp_specialize_copy_0_gemm_1(M,
-N,
-K,
-block_M,
-block_N,
-block_K,
-dtype="float16",
-accum_dtype="float"):
+def matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
...
@@ -5,20 +5,12 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@tilelang.jit(out_idx=[2])
-def matmul_warp_specialize_copy_1_gemm_0(M,
-N,
-K,
-block_M,
-block_N,
-block_K,
-dtype="float16",
-accum_dtype="float"):
+def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
...
@@ -5,26 +5,20 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
-@tilelang.jit(
-out_idx=[2], pass_configs={
-tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-})
-def matmul_warp_specialize_copy_1_gemm_0(M,
-N,
-K,
-block_M,
-block_N,
-block_K,
-dtype="float16",
-accum_dtype="float"):
+@tilelang.jit(
+out_idx=[2],
+pass_configs={
+tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+},
+)
+def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
warp_group_num = 2
threads = 128 * warp_group_num
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
...
@@ -6,7 +6,6 @@ import tilelang.language as T
# @tilelang.jit
@tilelang.jit(out_idx=[2])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor[(M, K), dtype],
...
@@ -9,7 +9,7 @@
# bash format.sh --all
#
#
-# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
+# Ruff (format) + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
# Cause the script to exit if a single command fails
...
@@ -28,9 +28,9 @@ def matmul(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
@@ -66,7 +66,8 @@ def _compile_and_check(
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
-})
+},
+)
print(kernel.get_kernel_source())
@@ -151,9 +152,9 @@ def matmul_rs(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
@@ -238,9 +239,9 @@ def matmul_sr(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
@@ -326,9 +327,9 @@ def matmul_rr(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
@@ -394,37 +395,48 @@ M_VALUES = [64, 128, 256]
N_VALUES = [16, 32, 64, 128, 256, 512]
K_VALUES = [16, 32, 64, 128]
K_VALUES_8Bit = [32, 64, 128]
-FALSE_TRUE_CASES = ([
-pytest.param(
-k,
-"float16",
-"float16",
-"float16",
-id=f"K{k}-float16-float16-float16",
-) for k in K_VALUES
-] + [pytest.param(
-k,
-"int8",
-"int32",
-"int32",
-id="K32-int8-int32-int32",
-) for k in K_VALUES_8Bit] + [
-pytest.param(
-k,
-"float8_e5m2",
-"float32",
-"float32",
-id="K32-float8_e5m2-float32-float32",
-) for k in K_VALUES_8Bit
-] + [
-pytest.param(
-k,
-"float8_e4m3",
-"float32",
-"float32",
-id="K32-float8_e4m3-float32-float32",
-) for k in K_VALUES_8Bit
-])
+FALSE_TRUE_CASES = (
+[
+pytest.param(
+k,
+"float16",
+"float16",
+"float16",
+id=f"K{k}-float16-float16-float16",
+)
+for k in K_VALUES
+]
++ [
+pytest.param(
+k,
+"int8",
+"int32",
+"int32",
+id="K32-int8-int32-int32",
+)
+for k in K_VALUES_8Bit
+]
++ [
+pytest.param(
+k,
+"float8_e5m2",
+"float32",
+"float32",
+id="K32-float8_e5m2-float32-float32",
+)
+for k in K_VALUES_8Bit
+]
++ [
+pytest.param(
+k,
+"float8_e4m3",
+"float32",
+"float32",
+id="K32-float8_e4m3-float32-float32",
+)
+for k in K_VALUES_8Bit
+]
+)
def _ensure_torch_dtypes(*dtype_names):
...
@@ -28,9 +28,9 @@ def matmul(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
@@ -67,7 +67,8 @@ def _compile_and_check(
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
-})
+},
+)
print(kernel.get_kernel_source())
@@ -150,9 +151,9 @@ def matmul_rs(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared.dyn")
@@ -213,14 +214,15 @@ def run_gemm_rs(
M_VALUES = [64, 128]
N_VALUES = [32, 64, 128]
K_VALUES = [16, 32, 64]
-FALSE_TRUE_CASES = ([
+FALSE_TRUE_CASES = [
pytest.param(
k,
"float16",
"float16",
"float16",
id=f"K{k}-float16-float16-float16",
-) for k in K_VALUES
+)
+for k in K_VALUES
] + [
pytest.param(
k,
@@ -228,8 +230,9 @@ FALSE_TRUE_CASES = ([
"float16",
"float32",
id=f"K{k}-float16-float16-float32",
-) for k in K_VALUES
-])
+)
+for k in K_VALUES
+]
def _ensure_torch_dtypes(*dtype_names):
...
@@ -27,9 +27,9 @@ def matmul(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -42,15 +42,7 @@ def matmul(
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
-T.gemm(
-A_shared,
-B_shared,
-C_tmem,
-trans_A,
-trans_B,
-mbar=mbar,
-wg_wait=-1,
-clear_accum=k == 0)
+T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0)
T.mbarrier_wait_parity(mbar, k % 2)
T.copy(C_tmem, C_local)
@@ -74,7 +66,8 @@ def _compile_and_check(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-})
+},
+)
print(kernel.get_kernel_source())
@@ -138,14 +131,15 @@ M_VALUES = [32, 64, 128, 256]
N_VALUES = [64, 128, 256, 512]
K_VALUES = [16, 32, 64, 128]
K_VALUES_8Bit = [32, 64, 128]
-FALSE_TRUE_CASES = ([
+FALSE_TRUE_CASES = [
pytest.param(
k,
"float16",
"float32",
"float32",
id=f"K{k}-float16-float-float",
-) for k in K_VALUES
+)
+for k in K_VALUES
] + [
pytest.param(
k,
@@ -153,8 +147,9 @@ FALSE_TRUE_CASES = ([
"float32",
"float32",
id="K32-float8_e5m2-float32-float32",
-) for k in K_VALUES_8Bit
-])
+)
+for k in K_VALUES_8Bit
+]
TRANS_CASES = [
pytest.param(False, True, id="nt"),
...
@@ -14,12 +14,11 @@ use_v2 = args.use_v2
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
...
@@ -14,12 +14,11 @@ use_v2 = args.use_v2
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
...
@@ -8,13 +8,13 @@ import argparse
from functools import partial
parser = argparse.ArgumentParser()
-parser.add_argument('--batch', type=int, default=128, help='batch size')
-parser.add_argument('--heads', type=int, default=16, help='heads')
-parser.add_argument('--seq_q', type=int, default=1024, help='query sequence length')
-parser.add_argument('--seq_kv', type=int, default=1024, help='key/value sequence length')
-parser.add_argument('--dim', type=int, default=256, help='dim')
-parser.add_argument('--is_causal', action='store_true', help='causal')
-parser.add_argument('--tune', action='store_true', help='tune configs')
+parser.add_argument("--batch", type=int, default=128, help="batch size")
+parser.add_argument("--heads", type=int, default=16, help="heads")
+parser.add_argument("--seq_q", type=int, default=1024, help="query sequence length")
+parser.add_argument("--seq_kv", type=int, default=1024, help="key/value sequence length")
+parser.add_argument("--dim", type=int, default=256, help="dim")
+parser.add_argument("--is_causal", action="store_true", help="causal")
+parser.add_argument("--tune", action="store_true", help="tune configs")
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()
@@ -29,20 +29,13 @@ def get_configs():
@autotune(configs=get_configs(), warmup=10, rep=10)
-@tilelang.jit(
-out_idx=[3], pass_configs={
-tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-})
-def flashattn(batch,
-heads,
-seq_q,
-seq_kv,
-dim,
-is_causal,
-block_M=64,
-block_N=64,
-num_stages=0,
-threads=128):
-scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+@tilelang.jit(
+out_idx=[3],
+pass_configs={
+tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+},
+)
+def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128):
+scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
@@ -62,7 +55,7 @@ def flashattn(batch,
by: T.int32,
bz: T.int32,
):
-T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
+T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
@@ -85,7 +78,7 @@ def flashattn(batch,
by: T.int32,
bz: T.int32,
):
-T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
+T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
# T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if use_v2:
T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@@ -94,13 +87,13 @@ def flashattn(batch,
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
@@ -125,18 +118,18 @@ def flashattn(batch,
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
@@ -152,43 +145,42 @@ def flashattn(batch,
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
-T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
+T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
-loop_range = (
-T.min(
-T.ceildiv(seq_kv, block_N), T.ceildiv(
-(bx + 1) * block_M +
-past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
+loop_range = (
+T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
+if is_causal
+else T.ceildiv(seq_kv, block_N)
+)
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
-Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
-logsum)
+Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
-T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
+T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
-scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
+scores = torch.einsum("bhqd,bhkd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_q = Q.size(2)
seq_kv = K.size(2)
mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
mask = mask.unsqueeze(0).unsqueeze(0)
-scores = scores.masked_fill(mask == 0, float('-inf'))
+scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
-output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
+output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V)
return output
@@ -206,18 +198,8 @@ def main(
if is_causal:
total_flops *= 0.5
-if (not tune):
-kernel = flashattn(
-batch,
-heads,
-seq_q,
-seq_kv,
-dim,
-is_causal,
-block_M=64,
-block_N=64,
-num_stages=0,
-threads=128)
+if not tune:
+kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128)
print(kernel.get_kernel_source())
ref_program_processed = partial(ref_program, is_causal=is_causal)
...
@@ -3,6 +3,7 @@
Note: The adapter-level wrapper expects only inputs (A, B) because C is marked as output.
Calling with the wrong number of inputs raises a ValueError before host entry.
"""
import torch
from common import build_matmul_kernel
...
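The docstring above summarizes what this repro exercises. For orientation, a hedged sketch of such a check follows; build_matmul_kernel comes from the repro's common.py, whose real signature and defaults are not shown in this diff, so the arguments and shapes below are assumptions:

# Hypothetical sketch only, not the repro file itself: the adapter-level wrapper takes just (A, B)
# because C is declared as an output, so passing C (or any extra input) should raise ValueError
# before the host entry point is reached.
import torch
from common import build_matmul_kernel  # helper from the repro's common.py

kernel = build_matmul_kernel()  # assumed to fall back to its default M/N/K configuration
a = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")  # illustrative shapes
b = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
c = torch.empty(1024, 1024, dtype=torch.float16, device="cuda")

try:
    kernel(a, b, c)  # one argument too many: C is produced by the kernel, not passed in
except ValueError as err:
    print("rejected before host entry:", err)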
@@ -3,6 +3,7 @@
We pass an integer for A; wrapper forwards it to the host where a pointer is expected.
Expected: error like "Expect buffer A_handle to be pointer or tensor" (exact name depends on kernel param).
"""
import torch
from common import build_matmul_kernel
...
-"""Reproduce: ndim (rank) mismatch for A.
-"""
+"""Reproduce: ndim (rank) mismatch for A."""
import torch
from common import build_matmul_kernel
...
-"""Reproduce: dtype mismatch for A (float32 vs expected float16).
-"""
+"""Reproduce: dtype mismatch for A (float32 vs expected float16)."""
import torch
from common import build_matmul_kernel
...
-"""Reproduce: shape constant/symbol mismatch on A.
-"""
+"""Reproduce: shape constant/symbol mismatch on A."""
import torch
from common import build_matmul_kernel
...
-"""Reproduce: strides check failure (non-contiguous A via transpose).
-"""
+"""Reproduce: strides check failure (non-contiguous A via transpose)."""
import torch
from common import build_matmul_kernel
...
-"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel.
-"""
+"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel."""
import torch
from common import build_matmul_kernel
...