Unverified commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
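
The diff below is almost entirely mechanical restyling from yapf to ruff format. As orientation, here is a minimal illustrative snippet (not taken from the patch) of the rewrites that account for most of the changed lines, assuming ruff format's default settings; the comments note the previous yapf rendering:

```python
# Illustrative toy example only (not code from this repository).
def restyled(x, dim, pe_dim, block, k):
    # yapf kept the power operator tight: (1.0 / (dim + pe_dim))**0.5
    # ruff format pads ** when an operand is a compound expression
    scale = (1.0 / (dim + pe_dim)) ** 0.5
    # yapf: x[k * block:(k + 1) * block]
    # ruff format pads the slice colon when the bounds are compound expressions
    tile = x[k * block : (k + 1) * block]
    # ruff format also normalizes string literals to double quotes and
    # collapses wrapped calls or signatures that fit on a single line
    return scale, tile, "done"
```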
@@ -9,7 +9,7 @@ import argparse
@tilelang.jit(out_idx=[6])
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // kv_head_num
@@ -81,10 +81,12 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head = hid // (kv_group_num // block_H)
T.annotate_layout({
T.annotate_layout(
{
O_shared_l: tilelang.layout.make_swizzled_layout(O_shared_l),
O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r),
})
}
)
# barriers_Q
q_shared_ready_barrier = T.alloc_barrier(arrive_count=256)
@@ -108,9 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
tx = T.get_thread_binding()
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l)
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim], Q_shared_l)
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:], Q_shared_r)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.barrier_arrive(q_shared_ready_barrier)
T.barrier_wait(q_shared_ready_barrier, 0)
@@ -123,25 +125,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.fill(acc_o_l, 0)
T.fill(logsum_0, 0)
T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
T.barrier_arrive(kv_shared_1_l_is_ready)
T.copy(KV[bid, block_N:2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
T.copy(KV[bid, block_N : 2 * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
T.barrier_arrive(kv_shared_1_r_is_ready)
T.copy(K_pe[bid, block_N:2 * block_N, cur_kv_head, :], K_pe_shared_1)
T.copy(K_pe[bid, block_N : 2 * block_N, cur_kv_head, :], K_pe_shared_1)
T.barrier_arrive(kv_shared_1_pe_is_ready)
for k in T.serial(loop_range):
T.barrier_wait(kv_shared_0_l_is_ready, k % 2)
T.gemm(
Q_shared_l,
KV_shared_0_l,
acc_s_0,
transpose_B=True,
clear_accum=True,
wg_wait=-1)
T.gemm(Q_shared_l, KV_shared_0_l, acc_s_0, transpose_B=True, clear_accum=True, wg_wait=-1)
T.barrier_wait(kv_shared_0_r_is_ready, k % 2)
T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1)
@@ -161,8 +156,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for i, j in T.Parallel(block_H, block_N):
acc_s_0[i, j] = T.exp2(acc_s_0[i, j] * scale - scores_max[i] * scale)
for i in T.Parallel(block_H):
scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale -
scores_max[i] * scale)
scores_scale_0[i] = T.exp2(scores_max_prev_0[i] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s_0, scores_sum_0, dim=1)
@@ -182,9 +176,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.barrier_wait(scale_1_ready_barrier, k % 2)
if k < loop_range - 1:
T.copy(
KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N,
cur_kv_head, :h_dim], KV_shared_0_l)
T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :h_dim], KV_shared_0_l)
T.barrier_arrive(kv_shared_0_l_is_ready)
# Step 11.
@@ -204,15 +196,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.gemm(SP1_shared, KV_shared_1_l, acc_o_l)
if k < loop_range - 1:
T.copy(
KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N,
cur_kv_head, :h_dim], KV_shared_1_l)
T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :h_dim], KV_shared_1_l)
T.barrier_arrive(kv_shared_1_l_is_ready)
T.copy(
K_pe[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, :],
K_pe_shared_1)
T.copy(K_pe[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, :], K_pe_shared_1)
T.barrier_arrive(kv_shared_1_pe_is_ready)
T.copy(logsum_0, logsum)
@@ -221,8 +208,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for i, j in T.Parallel(block_H, h_dim):
acc_o_l[i, j] /= logsum[i]
T.copy(acc_o_l, O_shared_l)
T.copy(O_shared_l, Output[bid,
hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim])
T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :h_dim])
else:
T.copy(Q_pe_shared, Q_pe_local_1)
@@ -237,16 +223,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.barrier_arrive(kv_shared_0_pe_is_ready)
for k in T.serial(loop_range):
# Step 2.
T.barrier_wait(kv_shared_1_l_is_ready, k % 2)
T.gemm(
Q_shared_l,
KV_shared_1_l,
acc_s_1,
transpose_B=True,
clear_accum=True,
wg_wait=-1)
T.gemm(Q_shared_l, KV_shared_1_l, acc_s_1, transpose_B=True, clear_accum=True, wg_wait=-1)
T.barrier_wait(kv_shared_1_r_is_ready, k % 2)
T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1)
@@ -265,8 +244,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.copy(scores_max_1, scores_max)
for i in T.Parallel(block_H):
scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale -
scores_max[i] * scale)
scores_scale_1[i] = T.exp2(scores_max_prev_1[i] * scale - scores_max[i] * scale)
# Step 8.
for i, j in T.Parallel(block_H, block_N):
@@ -279,8 +257,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o_r[i, j] = acc_o_r[i, j] * (scores_scale_0[i] * scores_scale_1[i])
for i in T.Parallel(block_H):
logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[
i] + scores_sum_1[i]
logsum_1[i] = logsum_1[i] * scores_scale_1[i] * scores_scale_0[i] + scores_sum_1[i]
T.barrier_arrive(scale_1_ready_barrier)
@@ -291,9 +268,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.barrier_arrive(s_shared_ready_barrier)
if k < loop_range - 1:
T.copy(
KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head,
h_dim:], KV_shared_1_r)
T.copy(KV[bid, (2 * k + 3) * block_N : (2 * k + 4) * block_N, cur_kv_head, h_dim:], KV_shared_1_r)
T.barrier_arrive(kv_shared_1_r_is_ready)
T.barrier_wait(p0_1_1_ready_barrier, k % 2)
@@ -301,15 +276,10 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.gemm(SP0_shared, KV_shared_0_r, acc_o_r)
if k < loop_range - 1:
T.copy(
KV[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head,
h_dim:], KV_shared_0_r)
T.copy(KV[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, h_dim:], KV_shared_0_r)
T.barrier_arrive(kv_shared_0_r_is_ready)
T.copy(
K_pe[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, :],
K_pe_shared_0)
T.copy(K_pe[bid, (2 * k + 2) * block_N : (2 * k + 3) * block_N, cur_kv_head, :], K_pe_shared_0)
T.barrier_arrive(kv_shared_0_pe_is_ready)
T.barrier_wait(lse_0_ready_barrier, 0)
@@ -319,8 +289,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for i, j in T.Parallel(block_H, h_dim):
acc_o_r[i, j] /= logsum[i]
T.copy(acc_o_r, O_shared_r)
T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
h_dim:])
T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, h_dim:])
@T.prim_func
def main_no_split(
@@ -352,31 +321,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim = q.shape[-1]
pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim)**0.5
q = rearrange(
q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
scale = (dim + pe_dim) ** 0.5
q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
q_pe = rearrange(
q_pe, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim]
kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim]
k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim]
query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1)
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
@@ -399,12 +361,12 @@ def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=128, help="q heads number")
parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
parser.add_argument("--dim", type=int, default=512, help="head dim")
parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
@@ -8,7 +8,6 @@ tilelang.disable_cache()
# @tilelang.jit
@tilelang.jit(out_idx=[2])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
num_stages = 2
mbarrier_list = [128, 128] * num_stages
@@ -32,19 +31,13 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
for ko in range(T.ceildiv(K, block_K)):
with T.ws(1):
T.mbarrier_wait_parity(
mbarrier=ko % num_stages + num_stages,
parity=((ko // num_stages) % num_stages) ^ 1)
T.copy(A[by * block_M:(by + 1) * block_M, ko * block_K:(ko + 1) * block_K],
A_shared[ko % num_stages, :, :])
T.copy(B[ko * block_K:(ko + 1) * block_K, bx * block_N:(bx + 1) * block_N],
B_shared[ko % num_stages, :, :])
T.mbarrier_wait_parity(mbarrier=ko % num_stages + num_stages, parity=((ko // num_stages) % num_stages) ^ 1)
T.copy(A[by * block_M : (by + 1) * block_M, ko * block_K : (ko + 1) * block_K], A_shared[ko % num_stages, :, :])
T.copy(B[ko * block_K : (ko + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[ko % num_stages, :, :])
T.mbarrier_arrive(mbarrier=ko % num_stages)
with T.ws(0):
T.mbarrier_wait_parity(
mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages)
T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :],
C_local)
T.mbarrier_wait_parity(mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages)
T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :], C_local)
T.mbarrier_arrive(mbarrier=ko % num_stages + num_stages)
with T.ws(0):
......
@@ -5,15 +5,7 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@tilelang.jit(out_idx=[2])
def matmul_warp_specialize_copy_0_gemm_1(M,
N,
K,
block_M,
block_N,
block_K,
dtype="float16",
accum_dtype="float"):
def matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
......
@@ -5,15 +5,7 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@tilelang.jit(out_idx=[2])
def matmul_warp_specialize_copy_1_gemm_0(M,
N,
K,
block_M,
block_N,
block_K,
dtype="float16",
accum_dtype="float"):
def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
......
@@ -5,18 +5,12 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@tilelang.jit(
out_idx=[2], pass_configs={
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
})
def matmul_warp_specialize_copy_1_gemm_0(M,
N,
K,
block_M,
block_N,
block_K,
dtype="float16",
accum_dtype="float"):
},
)
def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
warp_group_num = 2
threads = 128 * warp_group_num
......
@@ -6,7 +6,6 @@ import tilelang.language as T
# @tilelang.jit
@tilelang.jit(out_idx=[2])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor[(M, K), dtype],
......
@@ -9,7 +9,7 @@
# bash format.sh --all
#
#
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# Ruff (format) + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
# Cause the script to exit if a single command fails
......
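
The format.sh hunk above only swaps the tool name in its description from YAPF to Ruff; the script body itself is elided. The sketch below is a hypothetical Python rendering of the behaviour that description implies (format the Python files changed since the merge-base), assuming a base branch of origin/main and the stock `ruff format` CLI; it is not the actual contents of format.sh:

```python
# Hypothetical sketch, not the contents of format.sh.
import subprocess


def format_changed(base: str = "origin/main") -> None:
    """Run ruff format over Python files changed since the merge-base with `base`."""
    merge_base = subprocess.run(
        ["git", "merge-base", base, "HEAD"],
        check=True, capture_output=True, text=True,
    ).stdout.strip()
    changed = subprocess.run(
        ["git", "diff", "--name-only", "--diff-filter=ACM", merge_base, "HEAD"],
        check=True, capture_output=True, text=True,
    ).stdout.splitlines()
    py_files = [f for f in changed if f.endswith(".py")]
    if py_files:
        # `ruff format <paths>` rewrites the listed files in place.
        subprocess.run(["ruff", "format", *py_files], check=True)
```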
@@ -66,7 +66,8 @@ def _compile_and_check(
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
})
},
)
print(kernel.get_kernel_source())
@@ -394,37 +395,48 @@ M_VALUES = [64, 128, 256]
N_VALUES = [16, 32, 64, 128, 256, 512]
K_VALUES = [16, 32, 64, 128]
K_VALUES_8Bit = [32, 64, 128]
FALSE_TRUE_CASES = ([
FALSE_TRUE_CASES = (
[
pytest.param(
k,
"float16",
"float16",
"float16",
id=f"K{k}-float16-float16-float16",
) for k in K_VALUES
] + [pytest.param(
)
for k in K_VALUES
]
+ [
pytest.param(
k,
"int8",
"int32",
"int32",
id="K32-int8-int32-int32",
) for k in K_VALUES_8Bit] + [
)
for k in K_VALUES_8Bit
]
+ [
pytest.param(
k,
"float8_e5m2",
"float32",
"float32",
id="K32-float8_e5m2-float32-float32",
) for k in K_VALUES_8Bit
] + [
)
for k in K_VALUES_8Bit
]
+ [
pytest.param(
k,
"float8_e4m3",
"float32",
"float32",
id="K32-float8_e4m3-float32-float32",
) for k in K_VALUES_8Bit
])
)
for k in K_VALUES_8Bit
]
)
def _ensure_torch_dtypes(*dtype_names):
......
@@ -67,7 +67,8 @@ def _compile_and_check(
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
# tilelang.PassConfigKey.TIR_USE_ASYNC_COPY: False,
})
},
)
print(kernel.get_kernel_source())
@@ -213,14 +214,15 @@ def run_gemm_rs(
M_VALUES = [64, 128]
N_VALUES = [32, 64, 128]
K_VALUES = [16, 32, 64]
FALSE_TRUE_CASES = ([
FALSE_TRUE_CASES = [
pytest.param(
k,
"float16",
"float16",
"float16",
id=f"K{k}-float16-float16-float16",
) for k in K_VALUES
)
for k in K_VALUES
] + [
pytest.param(
k,
@@ -228,8 +230,9 @@ FALSE_TRUE_CASES = ([
"float16",
"float32",
id=f"K{k}-float16-float16-float32",
) for k in K_VALUES
])
)
for k in K_VALUES
]
def _ensure_torch_dtypes(*dtype_names):
......
@@ -42,15 +42,7 @@ def matmul(
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_tmem,
trans_A,
trans_B,
mbar=mbar,
wg_wait=-1,
clear_accum=k == 0)
T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0)
T.mbarrier_wait_parity(mbar, k % 2)
T.copy(C_tmem, C_local)
@@ -74,7 +66,8 @@ def _compile_and_check(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
print(kernel.get_kernel_source())
@@ -138,14 +131,15 @@ M_VALUES = [32, 64, 128, 256]
N_VALUES = [64, 128, 256, 512]
K_VALUES = [16, 32, 64, 128]
K_VALUES_8Bit = [32, 64, 128]
FALSE_TRUE_CASES = ([
FALSE_TRUE_CASES = [
pytest.param(
k,
"float16",
"float32",
"float32",
id=f"K{k}-float16-float-float",
) for k in K_VALUES
)
for k in K_VALUES
] + [
pytest.param(
k,
@@ -153,8 +147,9 @@ FALSE_TRUE_CASES = ([
"float32",
"float32",
id="K32-float8_e5m2-float32-float32",
) for k in K_VALUES_8Bit
])
)
for k in K_VALUES_8Bit
]
TRANS_CASES = [
pytest.param(False, True, id="nt"),
......
@@ -14,7 +14,6 @@ use_v2 = args.use_v2
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
......
@@ -14,7 +14,6 @@ use_v2 = args.use_v2
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
......
@@ -8,13 +8,13 @@ import argparse
from functools import partial
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=16, help='heads')
parser.add_argument('--seq_q', type=int, default=1024, help='query sequence length')
parser.add_argument('--seq_kv', type=int, default=1024, help='key/value sequence length')
parser.add_argument('--dim', type=int, default=256, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--batch", type=int, default=128, help="batch size")
parser.add_argument("--heads", type=int, default=16, help="heads")
parser.add_argument("--seq_q", type=int, default=1024, help="query sequence length")
parser.add_argument("--seq_kv", type=int, default=1024, help="key/value sequence length")
parser.add_argument("--dim", type=int, default=256, help="dim")
parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument("--tune", action="store_true", help="tune configs")
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()
@@ -29,20 +29,13 @@ def get_configs():
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3], pass_configs={
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
},
)
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
@@ -62,7 +55,7 @@ def flashattn(batch,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
@@ -85,7 +78,7 @@ def flashattn(batch,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
# T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if use_v2:
T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@@ -152,43 +145,42 @@ def flashattn(batch,
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv(
(bx + 1) * block_M +
past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
if is_causal
else T.ceildiv(seq_kv, block_N)
)
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
scores = torch.einsum("bhqd,bhkd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_q = Q.size(2)
seq_kv = K.size(2)
mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V)
return output
@@ -206,18 +198,8 @@ def main(
if is_causal:
total_flops *= 0.5
if (not tune):
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=128)
if not tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=128)
print(kernel.get_kernel_source())
ref_program_processed = partial(ref_program, is_causal=is_causal)
......
@@ -3,6 +3,7 @@
Note: The adapter-level wrapper expects only inputs (A, B) because C is marked as output.
Calling with the wrong number of inputs raises a ValueError before host entry.
"""
import torch
from common import build_matmul_kernel
......
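
The note in the hunk above is easiest to see from a call site. Below is a hypothetical usage sketch; the signature of build_matmul_kernel, the 128x128 shapes, and the float16 dtype are assumptions for illustration rather than values taken from the test:

```python
# Hypothetical sketch; build_matmul_kernel's real signature and shapes may differ.
import torch
from common import build_matmul_kernel

kernel = build_matmul_kernel()  # C is marked as output, so the wrapper takes only A and B
a = torch.randn(128, 128, dtype=torch.float16, device="cuda")
b = torch.randn(128, 128, dtype=torch.float16, device="cuda")

c = kernel(a, b)  # OK: the wrapper allocates and returns C

try:
    kernel(a, b, c)  # one input too many
except ValueError as exc:
    print("rejected before host entry:", exc)
```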
@@ -3,6 +3,7 @@
We pass an integer for A; wrapper forwards it to the host where a pointer is expected.
Expected: error like "Expect buffer A_handle to be pointer or tensor" (exact name depends on kernel param).
"""
import torch
from common import build_matmul_kernel
......
"""Reproduce: ndim (rank) mismatch for A.
"""
"""Reproduce: ndim (rank) mismatch for A."""
import torch
from common import build_matmul_kernel
......
"""Reproduce: dtype mismatch for A (float32 vs expected float16).
"""
"""Reproduce: dtype mismatch for A (float32 vs expected float16)."""
import torch
from common import build_matmul_kernel
......
"""Reproduce: shape constant/symbol mismatch on A.
"""
"""Reproduce: shape constant/symbol mismatch on A."""
import torch
from common import build_matmul_kernel
......
"""Reproduce: strides check failure (non-contiguous A via transpose).
"""
"""Reproduce: strides check failure (non-contiguous A via transpose)."""
import torch
from common import build_matmul_kernel
......
"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel.
"""
"""Reproduce: device_type mismatch by passing CPU tensors to a CUDA kernel."""
import torch
from common import build_matmul_kernel
......