"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "d452a1665f802bfbe75372b6c942272b252c70a2"
Commit b7ca76f1 authored by Yu Cheng's avatar Yu Cheng Committed by GitHub
Browse files

[Dev] Update MLA decode kernel (#120)

parent 524991fe
...@@ -3,17 +3,13 @@ import torch.nn.functional as F ...@@ -3,17 +3,13 @@ import torch.nn.functional as F
import tilelang import tilelang
from tilelang.autotuner import * from tilelang.autotuner import *
import tilelang.language as T import tilelang.language as T
from einops import rearrange, einsum
num_split = 4 num_split = 1
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H): def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, (dim + pe_dim)]
shape_k = [batch, seqlen_kv, kv_head_num, (dim + pe_dim)]
shape_v = [batch, seqlen_kv, kv_head_num, dim]
shape_o = [batch, heads, dim]
part_shape = [batch, heads, num_split, dim]
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
kv_group_num = heads // kv_head_num kv_group_num = heads // kv_head_num
...@@ -22,19 +18,23 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -22,19 +18,23 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.macro @T.macro
def flash_attn_split( def flash_attn_split(
Q: T.Buffer(shape_q, dtype), Q: T.Buffer([batch, heads, dim], dtype),
K: T.Buffer(shape_k, dtype), Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
V: T.Buffer(shape_v, dtype), KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Buffer([batch, heads, num_split], dtype), glse: T.Buffer([batch, heads, num_split], dtype),
Output_partial: T.Buffer(part_shape, dtype), Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
): ):
with T.Kernel( with T.Kernel(
batch, heads // min(block_H, kv_group_num), num_split, threads=128) as (bx, by, bz): batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, (dim + pe_dim)], dtype) Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, (dim + pe_dim)], dtype) S_shared = T.alloc_shared([block_H, block_N], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype) Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
KV_shared = T.alloc_shared([block_N, dim], dtype)
K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
O_shared = T.alloc_shared([block_H, dim], dtype) O_shared = T.alloc_shared([block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype) acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype) scores_max = T.alloc_fragment([block_H], accum_dtype)
...@@ -53,20 +53,32 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -53,20 +53,32 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
}) })
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N) loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=1): for k in T.Pipelined(loop_range, num_stages=2):
kv_start = (seqlen_kv // num_split) * sid + k * block_N
kv_end = (seqlen_kv // num_split) * sid + (k + 1) * block_N
T.copy(
KV[bid, kv_start:kv_end, cur_kv_head, :],
KV_shared
)
T.copy( T.copy(
K[bid, (seqlen_kv // num_split) * sid + K_pe[bid, kv_start:kv_end, cur_kv_head, :],
k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N, K_pe_shared
cur_kv_head, :], K_shared) )
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.clear(acc_s_0)
T.gemm(Q_shared, KV_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_pe_shared, K_pe_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.copy(acc_s_0, S_shared)
T.copy(S_shared, acc_s)
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
...@@ -78,11 +90,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -78,11 +90,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.copy(acc_s, acc_s_cast) T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
T.copy( T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
V[bid, (seqlen_kv // num_split) * sid +
k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
...@@ -96,8 +104,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -96,8 +104,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.macro @T.macro
def combine( def combine(
glse: T.Buffer([batch, heads, num_split], dtype), glse: T.Buffer([batch, heads, num_split], dtype),
Output_partial: T.Buffer(part_shape, dtype), Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
Output: T.Buffer(shape_o, dtype), Output: T.Buffer([batch, heads, dim], dtype),
): ):
with T.Kernel(heads, batch, threads=128) as (by, bz): with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim], dtype) po_local = T.alloc_fragment([dim], dtype)
...@@ -133,50 +141,63 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_ ...@@ -133,50 +141,63 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.prim_func @T.prim_func
def main( def main(
Q: T.Buffer(shape_q, dtype), Q: T.Buffer([batch, heads, dim], dtype),
K: T.Buffer(shape_k, dtype), Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
V: T.Buffer(shape_v, dtype), KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Buffer([batch, heads, num_split], dtype), glse: T.Buffer([batch, heads, num_split], dtype),
Output_partial: T.Buffer(part_shape, dtype), # [batch, heads, num_split, dim] Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
Output: T.Buffer(shape_o, dtype), Output: T.Buffer([batch, heads, dim], dtype),
): ):
flash_attn_split(Q, K, V, glse, Output_partial) flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
combine(glse, Output_partial, Output) combine(glse, Output_partial, Output)
return main return main
def ref_program(query, key, value, glse, Output_partial):
def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
# """ # """
# Inputs: # Inputs:
# - query (Tensor): [batch, heads, dim] # - q (Tensor): [batch, heads, dim]
# - key (Tensor): [batch, seqlen_kv, kv_head_num, dim] # - q_pe (Tensor): [batch, heads, pe_dim]
# - value (Tensor): [batch, seqlen_kv, kv_head_num, dim] # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
# - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
# - glse (Tensor): [batch, heads, num_split]
# - Output_partial (Tensor): [batch, heads, num_split, dim]
# Outputs: # Outputs:
# - output (Tensor): [batch, heads, dim] # - output (Tensor): [batch, heads, dim]
# """ # """
from einops import rearrange dim = q.shape[-1]
batch_size, query_heads, dim = query.shape # [batch_size, query_heads, dim] pe_dim = q_pe.shape[-1]
_, seqlen_kv, kv_heads, _ = key.shape # [batch_size, seqlen_kv, kv_heads, kv_dim] num_head_groups = q.shape[1] // kv.shape[2]
dim_v = value.shape[-1] scale = (dim + pe_dim) ** 0.5
assert kv_heads == 1, "kv_heads must be 1" q = rearrange(
q, 'b (h g) d -> b g h d',
query_expanded = rearrange(query, 'b h d -> b h 1 d') # [batch_size, query_heads, 1, dim] g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
key_expanded = key.expand(-1, -1, query_heads, -1) # [batch_size, query_heads, seqlen_kv, dim]
value_expanded = value.expand(-1, -1, query_heads, q_pe = rearrange(
-1) # [batch_size, query_heads, seqlen_kv, dim] q_pe, 'b (h g) d -> b g h d',
key_expanded = rearrange(key_expanded, g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
'b n h d -> b h n d') # [batch_size, kv_head_num, seqlen_kv, dim]
value_expanded = rearrange(value_expanded, kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim]
'b n h d -> b h n d') # [batch_size, query_heads, seqlen_kv, dim]
k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim]
scores = torch.matmul(query_expanded,
key_expanded.transpose(-1, -2)) # [batch_size, query_heads, 1, seqlen_kv] query = torch.concat([q, q_pe], dim=-1)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) key = torch.concat([kv, k_pe], dim=-1)
attention_weights = F.softmax(scores, dim=-1) # [batch_size, query_heads, 1, seqlen_kv]
output = torch.matmul(attention_weights, value_expanded) # [batch_size, query_heads, 1, dim] scores = einsum(
return output.view(batch_size, query_heads, dim_v) query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out
def flash_split_ref(Q, K, V): def flash_split_ref(Q, K, V):
...@@ -251,7 +272,7 @@ def reduce_ref(Q, K, V, glse, Output_partial): ...@@ -251,7 +272,7 @@ def reduce_ref(Q, K, V, glse, Output_partial):
if __name__ == "__main__": if __name__ == "__main__":
BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 64, 128, 1, 8192, 512, 64 BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 128, 128, 1, 8192, 512, 64
qk_flops = 2 * BATCH * H_Q * KV_CTX * (D_HEAD + DPE) qk_flops = 2 * BATCH * H_Q * KV_CTX * (D_HEAD + DPE)
pv_flops = 2 * BATCH * H_Q * KV_CTX * D_HEAD pv_flops = 2 * BATCH * H_Q * KV_CTX * D_HEAD
total_flops = qk_flops + pv_flops total_flops = qk_flops + pv_flops
...@@ -260,8 +281,9 @@ if __name__ == "__main__": ...@@ -260,8 +281,9 @@ if __name__ == "__main__":
program = flashattn(BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE, BLOCK_N, BLOCK_H) program = flashattn(BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE, BLOCK_N, BLOCK_H)
mod, params = tilelang.lower(program) mod, params = tilelang.lower(program)
mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal) mod = tilelang.Profiler(mod, params, [6], tilelang.TensorSupplyType.Normal)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01) mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
latency = mod.do_bench(mod.func, warmup=500) print("All close")
latency = mod.do_bench(mod.func, n_warmup=10, n_repeat=10, profiler="torch")
print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment