Commit b70683b3 authored by Yu Cheng, committed by GitHub

[Doc] Update MLA Documentation (#135)

parent cd94aca1
@@ -4,7 +4,7 @@ TileLang is a user-friendly AI programming language that significantly lowers th
 ## Introduction to MLA
-DeepSeek's MLA (Multi-Head Latent Attention) is a novel attention mechanism known for its hardware efficiency and significant improvements in model inference speed. In February 2025, [FlashMLA](https://github.com/deepseek-ai/FlashMLA) was open-sourced on GitHub. FlashMLA utilizes [CUTLASS](https://github.com/NVIDIA/cutlass) templates and incorporates optimization techniques from [FlashAttention](https://github.com/Dao-AILab/flash-attention), achieving impressive performance. Subsequently, various deep learning compilers (such as [Triton](https://github.com/triton-lang/triton)) and libraries (such as [FlashInfer](https://github.com/flashinfer-ai/flashinfer)) have released their own MLA implementations.
+DeepSeek's MLA (Multi-Head Latent Attention) is a novel attention mechanism known for its hardware efficiency and significant improvements in model inference speed. Several deep learning compilers (such as [Triton](https://github.com/triton-lang/triton)) and libraries (such as [FlashInfer](https://github.com/flashinfer-ai/flashinfer)) have developed their own implementations of MLA. In February 2025, [FlashMLA](https://github.com/deepseek-ai/FlashMLA) was open-sourced on GitHub. FlashMLA utilizes [CUTLASS](https://github.com/NVIDIA/cutlass) templates and incorporates optimization techniques from [FlashAttention](https://github.com/Dao-AILab/flash-attention), achieving impressive performance.
 ## Benchmark Results
......
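For orientation, the kernels changed below implement decode-time MLA attention: each query head carries a latent part `Q` of width `dim` and a decoupled RoPE part `Q_pe` of width `pe_dim`, while both keys and values come from a compressed latent cache `KV` (plus RoPE keys `K_pe`) shared across the query heads. A minimal PyTorch-style sketch of that computation, with illustrative tensor names and an assumed single latent KV head, is:

```python
import torch

def mla_decode_reference(q, q_pe, kv, k_pe, scale):
    # q:    [batch, heads, dim]            query against the compressed latent cache
    # q_pe: [batch, heads, pe_dim]         decoupled RoPE part of the query
    # kv:   [batch, seqlen_kv, 1, dim]     latent KV cache (single KV head, shared)
    # k_pe: [batch, seqlen_kv, 1, pe_dim]  decoupled RoPE keys
    kv, k_pe = kv.squeeze(2), k_pe.squeeze(2)
    # attention scores = latent ("nope") part + decoupled RoPE part
    scores = torch.einsum("bhd,bsd->bhs", q, kv) + torch.einsum("bhr,bsr->bhs", q_pe, k_pe)
    probs = torch.softmax(scores * scale, dim=-1)
    # the latent cache doubles as the value tensor
    return torch.einsum("bhs,bsd->bhd", probs, kv)
```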
@@ -40,31 +40,23 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
             scores_sum = T.alloc_fragment([block_H], accum_dtype)
             logsum = T.alloc_fragment([block_H], accum_dtype)
-            bid = bx
-            hid = by
-            cur_kv_head = hid // (kv_group_num // block_H)
+            cur_kv_head = by // (kv_group_num // block_H)
             T.use_swizzle(10)
             T.annotate_layout({
                 O_shared: tilelang.layout.make_swizzled_layout(O_shared),
                 S_shared: tilelang.layout.make_swizzled_layout(S_shared),
             })
-            T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
-            T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
+            T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
+            T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
             loop_range = T.ceildiv(seqlen_kv, block_N)
             for k in T.Pipelined(loop_range, num_stages=2):
-                kv_start = k * block_N
-                kv_end = (k + 1) * block_N
-                T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared)
-                T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
+                T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
+                T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
                 T.clear(acc_s_0)
                 T.gemm(
                     Q_shared, KV_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
@@ -92,9 +84,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                 T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
             for i, j in T.Parallel(block_H, dim):
                 acc_o[i, j] /= logsum[i]
             T.copy(acc_o, O_shared)
-            T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :])
+            T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :])
     @T.macro
     def flash_attn_split(
@@ -123,32 +114,26 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
             scores_sum = T.alloc_fragment([block_H], accum_dtype)
             logsum = T.alloc_fragment([block_H], accum_dtype)
-            bid = bx
-            hid = by
-            sid = bz
-            cur_kv_head = hid // (kv_group_num // block_H)
+            cur_kv_head = by // (kv_group_num // block_H)
             T.use_swizzle(10)
             T.annotate_layout({
                 O_shared: tilelang.layout.make_swizzled_layout(O_shared),
                 S_shared: tilelang.layout.make_swizzled_layout(S_shared),
             })
-            T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
-            T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
+            T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
+            T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
             loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
             for k in T.Pipelined(loop_range, num_stages=2):
-                kv_start = (seqlen_kv // num_split) * sid + k * block_N
-                kv_end = (seqlen_kv // num_split) * sid + (k + 1) * block_N
-                T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared)
-                T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
+                kv_start = (seqlen_kv // num_split) * bz + k * block_N
+                kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N
+                T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared)
+                T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
                 T.clear(acc_s_0)
                 T.gemm(
                     Q_shared, KV_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
@@ -178,11 +163,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                 acc_o[i, j] /= logsum[i]
             for i in T.Parallel(block_H):
                 logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
-            T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, sid])
+            T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz])
             T.copy(acc_o, O_shared)
-            T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
-                   sid, :])
+            T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :])
     @T.macro
     def combine(
......
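The body of `combine` is elided in this hunk. For reference, here is a hedged sketch of the split-KV reduction it is expected to perform, based on the per-split outputs written to `Output_partial` and the base-2 log-sum-exp values written to `glse` above; the function name, variable names, and shapes below are illustrative rather than taken from the kernel:

```python
import torch

def combine_reference(o_partial, lse):
    # o_partial: [batch, heads, num_split, dim]  per-split partial outputs
    # lse:       [batch, heads, num_split]       per-split log2-sum-exp (as stored in glse)
    lse_max = lse.max(dim=-1, keepdim=True).values
    weights = torch.exp2(lse - lse_max)          # base-2 weights to match T.log2 in the kernel
    weights = weights / weights.sum(dim=-1, keepdim=True)
    # weighted sum of the partial outputs reproduces full-sequence attention
    return torch.einsum("bhs,bhsd->bhd", weights, o_partial)
```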