Unverified commit a1149cab authored by Lei Wang, committed by GitHub

[Example] Optimize warp specialize flashmla example (#698)

* [Enhancement] Disable cache and append git commit ID to version in tilelang (#688)

  * Disabled caching in the quickstart example for improved performance.
  * Added a function that retrieves the current git commit ID and appends it to the version string if not already present, improving version tracking and debugging.

* revert quickstart

* optimize code.
parent ed1b96d5
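
The #688 change folded into this commit appends the current git commit ID to tilelang's version string. A minimal sketch of that mechanism (illustrative only; the function name and version values below are assumptions, not taken from the tilelang source):

import subprocess

__version__ = "0.1.0"  # illustrative base version

def get_git_commit_id():
    """Return the short commit ID of HEAD, or None outside a git checkout."""
    try:
        result = subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            capture_output=True, text=True, check=True)
        return result.stdout.strip()
    except (OSError, subprocess.CalledProcessError):
        return None

commit = get_git_commit_id()
if commit is not None and commit not in __version__:
    __version__ += f"+{commit}"  # e.g. "0.1.0+a1149ca"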
@@ -28,39 +28,58 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
         Output: T.Tensor([batch, heads, dim], dtype),
     ):
         with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid):
+            # smem_sQ
             Q_shared_l = T.alloc_shared([block_H, h_dim], dtype)
             Q_shared_r = T.alloc_shared([block_H, h_dim], dtype)
             Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
+            Q_pe_local_0 = T.alloc_fragment([block_H, pe_dim], dtype)
+            Q_pe_local_1 = T.alloc_fragment([block_H, pe_dim], dtype)
+            # smem_sK0
             KV_shared_0_l = T.alloc_shared([block_N, h_dim], dtype)
             KV_shared_0_r = T.alloc_shared([block_N, h_dim], dtype)
+            K_pe_shared_0 = T.alloc_shared([block_N, pe_dim], dtype)
+            # smem_sK1
             KV_shared_1_l = T.alloc_shared([block_N, h_dim], dtype)
             KV_shared_1_r = T.alloc_shared([block_N, h_dim], dtype)
-            K_pe_shared_0 = T.alloc_shared([block_N, pe_dim], dtype)
             K_pe_shared_1 = T.alloc_shared([block_N, pe_dim], dtype)
+            # smem_sP0
+            SP0_shared = T.alloc_shared([block_H, block_N], dtype)
+            # smem_sP1 reuses Q_pe_shared
+            SP1_shared = Q_pe_shared
+            # smem_sM
+            scores_max = T.alloc_shared([block_H], accum_dtype)
+            # smem_sScale0
+            scores_scale_0 = T.alloc_shared([block_H], accum_dtype)
+            # smem_sScale1
+            scores_scale_1 = T.alloc_shared([block_H], accum_dtype)
+            logsum = T.alloc_shared([block_H], accum_dtype)
             O_shared_l = Q_shared_l
             O_shared_r = Q_shared_r
-            S_shared = K_pe_shared_0
-            S_shared_ = K_pe_shared_1
             acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
+            acc_s_0_cast = T.alloc_fragment([block_H, block_N], dtype)
             acc_s_1 = T.alloc_fragment([block_H, block_N], accum_dtype)
+            acc_s_1_cast = T.alloc_fragment([block_H, block_N], dtype)
             acc_o_l = T.alloc_fragment([block_H, h_dim], accum_dtype)
             acc_o_r = T.alloc_fragment([block_H, h_dim], accum_dtype)
             scores_max_0 = T.alloc_fragment([block_H], accum_dtype)
             scores_max_1 = T.alloc_fragment([block_H], accum_dtype)
-            scores_max = T.alloc_shared([block_H], accum_dtype)
             scores_max_prev_0 = T.alloc_fragment([block_H], accum_dtype)
             scores_max_prev_1 = T.alloc_fragment([block_H], accum_dtype)
-            scores_scale_0 = T.alloc_shared([block_H], accum_dtype)
-            scores_scale_1 = T.alloc_shared([block_H], accum_dtype)
             scores_sum_0 = T.alloc_fragment([block_H], accum_dtype)
             scores_sum_1 = T.alloc_fragment([block_H], accum_dtype)
             logsum_0 = T.alloc_fragment([block_H], accum_dtype)
             logsum_1 = T.alloc_fragment([block_H], accum_dtype)
-            logsum = T.alloc_shared([block_H], accum_dtype)
             cur_kv_head = hid // (kv_group_num // block_H)
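
The allocation changes above save shared memory by aliasing: SP1_shared reuses the Q_pe_shared buffer, which is only safe because each warpgroup first snapshots Q_pe into the new register fragments Q_pe_local_0/Q_pe_local_1 (see the T.copy calls added later in this diff). A minimal numpy sketch of why the snapshot is needed (plain arrays standing in for shared memory; not TileLang code):

import numpy as np

# Two names, one buffer -- mirrors "SP1_shared = Q_pe_shared" above.
Q_pe_shared = np.arange(6, dtype=np.float32).reshape(2, 3)
SP1_shared = Q_pe_shared

Q_pe_local = Q_pe_shared.copy()   # like T.copy(Q_pe_shared, Q_pe_local_0)
SP1_shared[:] = 0.0               # later P1 writes clobber the Q_pe values...
assert Q_pe_local.sum() == 15.0   # ...but the register snapshot survives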
@@ -69,22 +88,25 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                 O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r),
             })
+            # barriers_Q
+            q_shared_ready_barrier = T.alloc_barrier(arrive_count=256)
+            # barriers_K0
             kv_shared_0_l_is_ready = T.alloc_barrier(arrive_count=128)
             kv_shared_0_r_is_ready = T.alloc_barrier(arrive_count=128)
             kv_shared_0_pe_is_ready = T.alloc_barrier(arrive_count=128)
+            # barriers_K1
             kv_shared_1_l_is_ready = T.alloc_barrier(arrive_count=128)
             kv_shared_1_r_is_ready = T.alloc_barrier(arrive_count=128)
             kv_shared_1_pe_is_ready = T.alloc_barrier(arrive_count=128)
+            # redundant barriers
             score_max_0_ready_barrier = T.alloc_barrier(arrive_count=128)
             scale_1_ready_barrier = T.alloc_barrier(arrive_count=128)
             p0_1_1_ready_barrier = T.alloc_barrier(arrive_count=128)
             lse_0_ready_barrier = T.alloc_barrier(arrive_count=128)
             lse_1_ready_barrier = T.alloc_barrier(arrive_count=128)
-            q_shared_ready_barrier = T.alloc_barrier(arrive_count=256)
-            k_pe_shared_1_free_barrier = T.alloc_barrier(arrive_count=128)
-            k_pe_shared_0_free_barrier = T.alloc_barrier(arrive_count=128)
             s_shared_ready_barrier = T.alloc_barrier(arrive_count=128)
-            k_shared_1_l_free_barrier = T.alloc_barrier(arrive_count=128)
             tx = T.get_thread_binding()
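
The k % 2 phase arguments threaded through the loop below implement a double-buffered producer/consumer handshake: each shared tile has a "ready" barrier whose expected parity flips every time the slot is reused. A rough CPU analogy using Python semaphores (illustrative only; mbarrier arrive/wait semantics are not literally semaphores):

import threading

NBUF = 2                      # two slots, like the _0 / _1 buffers above
ready = [threading.Semaphore(0) for _ in range(NBUF)]
free = [threading.Semaphore(1) for _ in range(NBUF)]
buf = [0] * NBUF
total = 0

def producer(n):
    for k in range(n):
        s = k % NBUF          # plays the role of the k % 2 phase
        free[s].acquire()     # wait until the consumer released slot s
        buf[s] = k            # stage the next tile
        ready[s].release()    # analogous to T.barrier_arrive(...)

def consumer(n):
    global total
    for k in range(n):
        s = k % NBUF
        ready[s].acquire()    # analogous to T.barrier_wait(..., k % 2)
        total += buf[s]
        free[s].release()     # slot may now be refilled

t = threading.Thread(target=producer, args=(8,))
t.start()
consumer(8)
t.join()
assert total == sum(range(8))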
@@ -93,11 +115,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
             T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
             T.barrier_arrive(q_shared_ready_barrier)
             T.barrier_wait(q_shared_ready_barrier, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
             loop_range = T.ceildiv(seqlen_kv, (block_N * 2))
             if tx < 128:
+                T.copy(Q_pe_shared, Q_pe_local_0)
                 T.fill(acc_o_l, 0)
                 T.fill(logsum_0, 0)
@@ -118,7 +142,6 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                         KV_shared_0_l,
                         acc_s_0,
                         transpose_B=True,
-                        policy=T.GemmWarpPolicy.FullCol,
                         clear_accum=True,
                         wg_wait=-1)
                     T.barrier_wait(kv_shared_0_r_is_ready, k % 2)
@@ -127,16 +150,14 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                         KV_shared_0_r,
                         acc_s_0,
                         transpose_B=True,
-                        policy=T.GemmWarpPolicy.FullCol,
                         wg_wait=-1)
                     T.barrier_wait(kv_shared_0_pe_is_ready, k % 2)
                     T.gemm(
-                        Q_pe_shared,
+                        Q_pe_local_0,
                         K_pe_shared_0,
                         acc_s_0,
                         transpose_B=True,
-                        policy=T.GemmWarpPolicy.FullCol,
                         wg_wait=-1)
                     T.wait_wgmma(0)
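
The three back-to-back T.gemm calls above accumulate one score tile by splitting the contraction (head) dimension into a left half, a right half, and the positional-encoding part; clear_accum=True on the first call zeroes the accumulator. A numpy equivalent for a single tile (shapes here are illustrative, not the kernel's actual configuration):

import numpy as np

rng = np.random.default_rng(0)
block_H, block_N, h_dim, pe_dim = 16, 32, 64, 8
Q_l, Q_r = rng.normal(size=(2, block_H, h_dim))
K_l, K_r = rng.normal(size=(2, block_N, h_dim))
Q_pe = rng.normal(size=(block_H, pe_dim))
K_pe = rng.normal(size=(block_N, pe_dim))

acc_s = Q_l @ K_l.T        # clear_accum=True: start a fresh accumulator
acc_s += Q_r @ K_r.T       # transpose_B=True corresponds to K.T here
acc_s += Q_pe @ K_pe.T

# One GEMM over the concatenated head dimension gives the same scores.
ref = np.hstack([Q_l, Q_r, Q_pe]) @ np.hstack([K_l, K_r, K_pe]).T
assert np.allclose(acc_s, ref)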
@@ -158,7 +179,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                     T.reduce_sum(acc_s_0, scores_sum_0, dim=1)
                     # Step 5.
-                    T.copy(acc_s_0, S_shared)
+                    T.copy(acc_s_0, acc_s_0_cast)
                     for i, j in T.Parallel(block_H, h_dim):
                         acc_o_l[i, j] *= scores_scale_0[i]
@@ -167,7 +188,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                         logsum_0[i] = logsum_0[i] * scores_scale_0[i] + scores_sum_0[i]
                     # Step 6.
-                    T.gemm(S_shared, KV_shared_0_l, acc_o_l, policy=T.GemmWarpPolicy.FullCol)
+                    T.gemm(acc_s_0_cast, KV_shared_0_l, acc_o_l)
                     T.barrier_arrive(score_max_0_ready_barrier)
                     T.barrier_wait(scale_1_ready_barrier, k % 2)
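
The scores_scale / logsum updates above are the standard online-softmax recurrence: rescale the running output and running denominator by exp(m_prev - m_new) before adding the new tile's contribution. A single-row numpy sketch of the recurrence (illustrative shapes; names mirror the kernel buffers, not its code):

import numpy as np

def online_softmax_step(acc_o, logsum, m_prev, scores, v):
    m_new = max(m_prev, scores.max())   # running max (scores_max)
    scale = np.exp(m_prev - m_new)      # scores_scale_*
    p = np.exp(scores - m_new)          # exponentiated tile (acc_s)
    acc_o = acc_o * scale + p @ v       # Step 5/6: rescale, then P @ V gemm
    logsum = logsum * scale + p.sum()   # logsum_* recurrence
    return acc_o, logsum, m_new

rng = np.random.default_rng(0)
s = rng.normal(size=(4, 8))             # 4 KV tiles of 8 scores each
v = rng.normal(size=(4, 8, 16))         # matching V tiles
acc_o, logsum, m = np.zeros(16), 0.0, -np.inf
for k in range(4):
    acc_o, logsum, m = online_softmax_step(acc_o, logsum, m, s[k], v[k])

# Tile-by-tile processing matches one softmax over all scores at the end.
p_all = np.exp(s.ravel() - s.max())
assert np.allclose(acc_o / logsum, (p_all / p_all.sum()) @ v.reshape(32, 16))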
@@ -180,7 +201,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                     # Step 11.
                     for i, j in T.Parallel(block_H, block_N):
-                        S_shared_[i, j] = acc_s_0[i, j] * scores_scale_1[i]
+                        SP0_shared[i, j] = acc_s_0[i, j] * scores_scale_1[i]
                     T.barrier_arrive(p0_1_1_ready_barrier)
@@ -192,19 +213,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                     T.barrier_wait(s_shared_ready_barrier, k % 2)
                     # Step 14.
-                    T.gemm(S_shared, KV_shared_1_l, acc_o_l, policy=T.GemmWarpPolicy.FullCol)
-                    T.barrier_arrive(k_pe_shared_0_free_barrier)
-                    T.barrier_arrive(k_shared_1_l_free_barrier)
+                    T.gemm(SP1_shared, KV_shared_1_l, acc_o_l)
                     if k < loop_range - 1:
-                        T.barrier_wait(k_shared_1_l_free_barrier, k % 2)
                         T.copy(
                             KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N,
                                cur_kv_head, :h_dim], KV_shared_1_l)
                         T.barrier_arrive(kv_shared_1_l_is_ready)
-                        T.barrier_wait(k_pe_shared_1_free_barrier, k % 2)
                         T.copy(
                             K_pe[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, :],
                             K_pe_shared_1)
@@ -220,6 +237,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                                hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim])
             else:
+                T.copy(Q_pe_shared, Q_pe_local_1)
                 T.fill(acc_o_r, 0)
                 T.fill(logsum_1, 0)
@@ -239,7 +257,6 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                         KV_shared_1_l,
                         acc_s_1,
                         transpose_B=True,
-                        policy=T.GemmWarpPolicy.FullCol,
                         clear_accum=True,
                         wg_wait=-1)
@@ -249,16 +266,14 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                         KV_shared_1_r,
                         acc_s_1,
                         transpose_B=True,
-                        policy=T.GemmWarpPolicy.FullCol,
                         wg_wait=-1)
                     T.barrier_wait(kv_shared_1_pe_is_ready, k % 2)
                     T.gemm(
-                        Q_pe_shared,
+                        Q_pe_local_1,
                         K_pe_shared_1,
                         acc_s_1,
                         transpose_B=True,
-                        policy=T.GemmWarpPolicy.FullCol,
                         wg_wait=-1)
                     T.wait_wgmma(0)
@@ -292,14 +307,14 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                     T.barrier_arrive(scale_1_ready_barrier)
                     # Step 10. compute O1 with KV_shared_1_r
-                    T.copy(acc_s_1, S_shared)
-                    T.barrier_arrive(s_shared_ready_barrier)
+                    T.copy(acc_s_1, acc_s_1_cast)
                     T.gemm(
-                        S_shared,
+                        acc_s_1_cast,
                         KV_shared_1_r,
                         acc_o_r,
-                        policy=T.GemmWarpPolicy.FullCol,
                         wg_wait=-1)
+                    T.copy(acc_s_1_cast, SP1_shared)
+                    T.barrier_arrive(s_shared_ready_barrier)
                     if k < loop_range - 1:
                         T.copy(
@@ -309,8 +324,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                     T.barrier_wait(p0_1_1_ready_barrier, k % 2)
                     # Step 12.
-                    T.gemm(S_shared_, KV_shared_0_r, acc_o_r, policy=T.GemmWarpPolicy.FullCol)
+                    T.gemm(SP0_shared, KV_shared_0_r, acc_o_r)
-                    T.barrier_arrive(k_pe_shared_1_free_barrier)
                     if k < loop_range - 1:
@@ -319,7 +333,6 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                                h_dim:], KV_shared_0_r)
                         T.barrier_arrive(kv_shared_0_r_is_ready)
-                        T.barrier_wait(k_pe_shared_0_free_barrier, k % 2)
                         T.copy(
                             K_pe[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, :],
                             K_pe_shared_0)
...