Commit 3d7b2dc5 authored by Yu Cheng, committed by GitHub

[Dev][Doc] Enhance Flash Attention Implementation in GQA Decoding Example and Fix Typo (#139)

- Add non-split flash attention macro for more flexible kernel generation
- Implement `main_no_split` function to handle single-split scenarios
- Modify kernel selection logic to dynamically choose between the split and non-split implementations (see the usage sketch below)
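
For illustration, a minimal usage sketch of the new dispatch. The concrete shape values below are placeholders, and the call pattern follows the `__main__` section of the example (`flashattn(...)` returning a configurable kernel builder is an assumption based on that section):

```python
# Usage sketch (not part of the commit); shapes are placeholder values.
program = flashattn(batch=1, heads=32, groups=8, seqlen_kv=8192, dim=128)(
    block_N=128, block_H=64, num_split=1, num_stages=2, threads=128)
# num_split == 1 -> flashattn returns main_no_split (single flash_attn kernel)
# num_split  > 1 -> flashattn returns main_split (flash_attn_split + combine)
kernel = tilelang.compile(program, out_idx=[6])
```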
parent 3960d3d0
@@ -116,7 +116,7 @@ Here, `T.annotate_layout` allows users to specify any desired layout for a buffe
### Warp-Specialization
The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (Tensor Memory Accelerator), while the remaining warpgroups serve as consumers performing computations. However, this programming pattern is complex, requiring developers to manually manage the execution logic for producers and consumers, including synchronization through the `mbarrier` objects.
In TileLang, users are completely shielded from these implementation details. The frontend script is automatically transformed into a warp-specialized form, where TileLang handles all producer-consumer synchronization automatically, enabling efficient computation.
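
For context, this is roughly what such a frontend script looks like. The sketch below is illustrative rather than taken from the repository (tile sizes and names are placeholders, and the usual `import tilelang.language as T` convention is assumed): the user writes only pipelined copies and a `T.gemm`, and the warp-specialization pass derives the producer/consumer split and the `mbarrier` synchronization from it.

```python
import tilelang
import tilelang.language as T


def matmul(M, N, K, block_M=128, block_N=128, block_K=64,
           dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Buffer([M, K], dtype),
            B: T.Buffer([K, N], dtype),
            C: T.Buffer([M, N], dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared([block_M, block_K], dtype)
            B_shared = T.alloc_shared([block_K, block_N], dtype)
            C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
            T.clear(C_local)
            # No producer/consumer roles and no mbarriers in the source: on Hopper,
            # the pipelined copies and the gemm are split across warpgroups by the compiler.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M:(by + 1) * block_M, k * block_K:(k + 1) * block_K], A_shared)
                T.copy(B[k * block_K:(k + 1) * block_K, bx * block_N:(bx + 1) * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N])

    return main
```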
...
@@ -40,6 +40,75 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
        part_shape = [batch, heads, num_split, dim]
        valid_block_H = min(block_H, kv_group_num)
        @T.macro
        def flash_attn(
                Q: T.Buffer(shape_q, dtype),
                K: T.Buffer(shape_k, dtype),
                V: T.Buffer(shape_v, dtype),
                mask: T.Buffer([batch, seqlen_kv, groups], "uint8"),
                Output: T.Buffer([batch, heads, dim], dtype),
        ):
            # Non-split variant: a single kernel attends over the whole KV sequence
            # and writes the final output directly (no glse / partial-output buffers).
            with T.Kernel(
                    batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_H, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                O_shared = T.alloc_shared([valid_block_H, dim], dtype)
                acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
                acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
                mask_local = T.alloc_fragment([block_N], "uint8")
                acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
                scores_max = T.alloc_fragment([block_H], accum_dtype)
                scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
                scores_scale = T.alloc_fragment([block_H], accum_dtype)
                scores_sum = T.alloc_fragment([block_H], accum_dtype)
                logsum = T.alloc_fragment([block_H], accum_dtype)

                bid = bx
                hid = by
                cur_kv_head = hid // (kv_group_num // valid_block_H)

                T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))

                loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared)
                    T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local)
                    T.clear(acc_s)
                    T.gemm(
                        Q_shared,
                        K_shared,
                        acc_s,
                        transpose_B=True,
                        policy=T.GemmWarpPolicy.FullRow)
                    # Masked-out KV positions get -inf scores (zero weight after softmax).
                    for i, j in T.Parallel(block_H, block_N):
                        acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j],
                                                     -T.infinity(accum_dtype))
                    # Online softmax: refresh the running row max and rescale prior state.
                    T.copy(scores_max, scores_max_prev)
                    T.fill(scores_max, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                    for i in T.Parallel(block_H):
                        scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                    for i, j in T.Parallel(block_H, block_N):
                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                    T.reduce_sum(acc_s, scores_sum, dim=1)
                    for i in T.Parallel(block_H):
                        logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                    T.copy(acc_s, acc_s_cast)
                    for i, j in T.Parallel(block_H, dim):
                        acc_o[i, j] *= scores_scale[i]
                    T.copy(V[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], V_shared)
                    T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                # Normalize by the softmax denominator and write back the valid head rows.
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] /= logsum[i]
                for i in T.Parallel(block_H):
                    logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
                T.copy(acc_o[:valid_block_H, :], O_shared)
                T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])
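
The pipelined loop above is the standard online-softmax recurrence written in the base-2 exponential domain (`T.exp2` / `T.log2`, which lower to fast hardware instructions). As a sketch of one iteration per query row $i$, assuming `scale` folds the $1/\sqrt{\text{dim}}$ softmax scaling together with a $\log_2 e$ factor (that definition sits outside this hunk):

$$
m_i \leftarrow \max\big(m_i^{\text{prev}},\ \max_j S_{ij}\big), \qquad
P_{ij} = 2^{(S_{ij} - m_i)\cdot \text{scale}}, \qquad
r_i = 2^{(m_i^{\text{prev}} - m_i)\cdot \text{scale}}
$$

$$
\ell_i \leftarrow \ell_i \cdot r_i + \sum_j P_{ij}, \qquad
O_i \leftarrow O_i \cdot r_i + \sum_j P_{ij}\,V_j
$$

After the loop, dividing `acc_o` by `logsum` recovers $\mathrm{softmax}(S_i/\sqrt{\text{dim}})\,V$, since $2^{x\cdot\log_2 e}=e^{x}$; the final `logsum` update keeps the per-row log-sum-exp in the same base-2 form, mirroring the quantity the split path presumably stores in `glse` for the combine step.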
        @T.macro
        def flash_attn_split(
                Q: T.Buffer(shape_q, dtype),
@@ -168,7 +237,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
                    Output[bz, by, i] = o_accum_local[i]

        @T.prim_func
        def main_split(
                Q: T.Buffer(shape_q, dtype),
                K: T.Buffer(shape_k, dtype),
                V: T.Buffer(shape_v, dtype),
@@ -180,7 +249,22 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
            flash_attn_split(Q, K, V, mask, glse, Output_partial)
            combine(glse, Output_partial, Output)

        @T.prim_func
        def main_no_split(
                Q: T.Buffer(shape_q, dtype),
                K: T.Buffer(shape_k, dtype),
                V: T.Buffer(shape_v, dtype),
                mask: T.Buffer([batch, seqlen_kv, groups], "uint8"),
                glse: T.Buffer([batch, heads, num_split], dtype),
                Output_partial: T.Buffer(part_shape, dtype),
                Output: T.Buffer(shape_o, dtype),
        ):
            # glse and Output_partial are accepted only to keep the signature
            # identical to main_split; the non-split path never touches them.
            flash_attn(Q, K, V, mask, Output)

        # Return the prim_func that matches the split configuration.
        if num_split > 1:
            return main_split
        else:
            return main_no_split
    if tune:
@@ -349,6 +433,7 @@ if __name__ == "__main__":
            block_N=128, block_H=64, num_split=8, num_stages=2, threads=128)
    kernel = tilelang.compile(program, out_idx=[6])
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    print("All checks pass.")
    latency = profiler.do_bench(ref_program, warmup=500)
    print("Ref: {:.2f} ms".format(latency))
...