"vscode:/vscode.git/clone" did not exist on "192a399506b99691f1d8fee9cb484cce64ae6284"
Commit e3065f0b authored by Lei Wang, committed by LeiWang1999

[AMD] Implement Deepseek MLA for AMD (#363)

* [Bugfix] Correct dynamic shared memory size error handling in HIP wrapper

- Updated the error handling logic in `PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP` to check if the dynamic shared memory size exceeds the maximum limit of 65536.
- Improved error message clarity by specifying the function name and the attempted size, ensuring better debugging information.
- Ensured the function returns 0 upon successful setting of the dynamic shared memory size (a minimal sketch of this guard follows below).
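
As a rough illustration of the guard described above, here is a minimal Python sketch of the same logic. `MAX_DYNAMIC_SMEM_HIP` and `set_dynamic_smem` are assumed names used only for illustration, not the wrapper's actual identifiers; the real check is emitted as HIP C++ inside the generated wrapper.

```python
# Illustrative sketch only; names below are assumptions, not the wrapper's real API.
MAX_DYNAMIC_SMEM_HIP = 65536  # byte limit checked before setting the attribute


def set_dynamic_smem(func_name: str, requested_bytes: int) -> int:
    """Mimic the guard: reject oversized requests with a descriptive error, return 0 on success."""
    if requested_bytes > MAX_DYNAMIC_SMEM_HIP:
        raise RuntimeError(
            f"{func_name}: requested dynamic shared memory size {requested_bytes} "
            f"exceeds the maximum limit of {MAX_DYNAMIC_SMEM_HIP}")
    # The real wrapper would call into the HIP runtime here to set the attribute.
    return 0
```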

* [Add] Implement example for MLA decoding with AMD support

- Introduced a new example script `example_mla_decode_amd.py` demonstrating the flash attention mechanism on AMD hardware.
- Implemented functions for attention calculation, including support for split processing and combining of partial outputs (a sketch of the combine step follows this list).
- Added command-line argument parsing for customizable input parameters such as batch size, number of heads, and dimensions.
- Included a reference implementation for validation against the Tile-AI output, ensuring correctness of the implementation.
- Enhanced performance profiling and output comparison for debugging and optimization purposes.
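
Each split writes a partial output and a base-2 log-sum-exp (LSE) value; the combine kernel then merges the splits with LSE weighting. The following PyTorch sketch mirrors that combination step for illustration only; `combine_splits` is an assumed name, not the kernel's API.

```python
import torch


def combine_splits(output_partial: torch.Tensor, glse: torch.Tensor) -> torch.Tensor:
    """Merge per-split partial outputs using base-2 log-sum-exp weights.

    output_partial: [batch, heads, num_split, dim]; glse: [batch, heads, num_split].
    Illustrative reference for the combine kernel's logic, not the kernel itself.
    """
    lse_max = glse.max(dim=-1, keepdim=True).values
    lse_total = torch.log2(torch.exp2(glse - lse_max).sum(dim=-1, keepdim=True)) + lse_max
    weights = torch.exp2(glse - lse_total)            # per-split weights, sum to 1
    return (output_partial * weights.unsqueeze(-1)).sum(dim=2)
```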

* lint fix
example_mla_decode_amd.py (new file):
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse

tilelang.disable_cache()


def flashattn(batch,
              heads,
              kv_head_num,
              seqlen_kv,
              dim,
              pe_dim,
              block_N,
              block_H,
              num_split,
              threads=128):
    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // kv_head_num
    VALID_BLOCK_H = min(block_H, kv_group_num)
    assert kv_head_num == 1, "kv_head_num must be 1"

    @T.macro
    def flash_attn(
            Q: T.Tensor([batch, heads, dim], dtype),
            Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
            KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            Output: T.Tensor([batch, heads, dim], dtype),
    ):
        with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=threads) as (bx, by):
            Q_shared = T.alloc_fragment([block_H, dim], dtype)
            Q_pe_shared = T.alloc_fragment([block_H, pe_dim], dtype)
            KV_shared = T.alloc_shared([block_N, dim], dtype)
            K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            cur_kv_head = by // (kv_group_num // block_H)
            T.use_swizzle(10)

            T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
            T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # Online-softmax loop over KV tiles: scores from the content dim and the rope dim
            # are accumulated into acc_s before rescaling the running output.
            loop_range = T.ceildiv(seqlen_kv, block_N)
            for k in T.Pipelined(loop_range, num_stages=2):
                T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
                T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
                T.clear(acc_s)
                T.gemm(
                    Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.gemm(
                    Q_pe_shared,
                    K_pe_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullRow)
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                # T.copy(acc_s, S_shared)
                T.copy(acc_s, acc_s_cast)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] *= scores_scale[i]
                T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
            # Final normalization by the accumulated softmax denominator.
            for i, j in T.Parallel(block_H, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :])

    @T.macro
    def flash_attn_split(
            Q: T.Tensor([batch, heads, dim], dtype),
            Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
            KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            glse: T.Tensor([batch, heads, num_split], dtype),
            Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
    ):
        with T.Kernel(
                batch, heads // min(block_H, kv_group_num), num_split,
                threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_fragment([block_H, dim], dtype)
            Q_pe_shared = T.alloc_fragment([block_H, pe_dim], dtype)
            KV_shared = T.alloc_shared([block_N, dim], dtype)
            K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            cur_kv_head = by // (kv_group_num // block_H)
            T.use_swizzle(10)

            T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
            T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # Each split (bz) processes a contiguous slice of the KV sequence.
            loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
            for k in T.Pipelined(loop_range, num_stages=0):
                kv_start = (seqlen_kv // num_split) * bz + k * block_N
                kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N
                T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared)
                T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
                T.clear(acc_s)
                T.gemm(
                    Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                T.gemm(
                    Q_pe_shared,
                    K_pe_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullCol)
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                T.copy(acc_s, acc_s_cast)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] *= scores_scale[i]
                T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
            for i, j in T.Parallel(block_H, dim):
                acc_o[i, j] /= logsum[i]
            # Store the per-split log-sum-exp (base 2) of the scaled scores so the combine
            # kernel can merge the splits.
            for i in T.Parallel(block_H):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz])
            T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :])

    @T.macro
    def combine(
            glse: T.Tensor([batch, heads, num_split], dtype),
            Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
            Output: T.Tensor([batch, heads, dim], dtype),
    ):
        with T.Kernel(heads, batch, threads=128) as (by, bz):
            po_local = T.alloc_fragment([dim], dtype)
            o_accum_local = T.alloc_fragment([dim], accum_dtype)
            lse_local_split = T.alloc_local([1], accum_dtype)
            lse_logsum_local = T.alloc_local([1], accum_dtype)
            lse_max_local = T.alloc_local([1], accum_dtype)
            scale_local = T.alloc_local([1], accum_dtype)

            T.annotate_layout({
                lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
            })

            # Merge per-split partial outputs with log-sum-exp weighting (base 2).
            T.clear(lse_logsum_local)
            T.clear(o_accum_local)
            lse_max_local[0] = -T.infinity(accum_dtype)
            for k in T.serial(num_split):
                lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
            for k in T.Pipelined(num_split, num_stages=1):
                lse_local_split[0] = glse[bz, by, k]
                lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
            lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
            for k in T.serial(num_split):
                for i in T.Parallel(dim):
                    po_local[i] = Output_partial[bz, by, k, i]
                lse_local_split[0] = glse[bz, by, k]
                scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
                for i in T.Parallel(dim):
                    o_accum_local[i] += po_local[i] * scale_local[0]
            for i in T.Parallel(dim):
                Output[bz, by, i] = o_accum_local[i]

    @T.prim_func
    def main_split(
            Q: T.Tensor([batch, heads, dim], dtype),
            Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
            KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            glse: T.Tensor([batch, heads, num_split], dtype),
            Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
            Output: T.Tensor([batch, heads, dim], dtype),
    ):
        flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
        combine(glse, Output_partial, Output)

    @T.prim_func
    def main_no_split(
            Q: T.Tensor([batch, heads, dim], dtype),
            Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
            KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            glse: T.Tensor([batch, heads, num_split], dtype),
            Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
            Output: T.Tensor([batch, heads, dim], dtype),
    ):
        flash_attn(Q, Q_pe, KV, K_pe, Output)

    if num_split > 1:
        return main_split
    else:
        return main_no_split


def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
    # """
    # Inputs:
    # - q (Tensor): [batch, heads, dim]
    # - q_pe (Tensor): [batch, heads, pe_dim]
    # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
    # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
    # - glse (Tensor): [batch, heads, num_split]
    # - Output_partial (Tensor): [batch, heads, num_split, dim]
    # Outputs:
    # - output (Tensor): [batch, heads, dim]
    # """
    dim = q.shape[-1]
    pe_dim = q_pe.shape[-1]
    num_head_groups = q.shape[1] // kv.shape[2]
    scale = (dim + pe_dim)**0.5
    q = rearrange(
        q, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, kv_head_num, dim]
    q_pe = rearrange(
        q_pe, 'b (h g) d -> b g h d',
        g=num_head_groups)  # [batch_size, num_head_groups, kv_head_num, pe_dim]
    kv = rearrange(kv, 'b n h d -> b h n d')  # [batch_size, kv_head_num, seqlen_kv, dim]
    k_pe = rearrange(k_pe, 'b n h d -> b h n d')  # [batch_size, kv_head_num, seqlen_kv, pe_dim]
    query = torch.concat([q, q_pe], dim=-1)
    key = torch.concat([kv, k_pe], dim=-1)
    scores = einsum(
        query, key,
        'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, kv_head_num, seqlen_kv]
    attention = F.softmax(
        scores / scale, dim=-1)  # [batch_size, num_head_groups, kv_head_num, seqlen_kv]
    out = einsum(attention, kv,
                 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, kv_head_num, dim]
    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=1, help='batch size')
    parser.add_argument('--heads', type=int, default=128, help='q heads number')
    parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
    parser.add_argument('--kv_ctx', type=int, default=1024, help='kv context length')
    parser.add_argument('--dim', type=int, default=256, help='head dim')
    parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
    args = parser.parse_args()
    batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim

    qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
    pv_flops = 2 * batch * heads * kv_ctx * dim
    total_flops = qk_flops + pv_flops

    BLOCK_N = 32
    BLOCK_H = 64
    num_split = 1

    program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
    kernel = tilelang.compile(program, out_idx=[6])
    print(kernel.get_kernel_source())
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
    input_tensors = profiler._get_inputs()
    tilelang_output = kernel(*input_tensors)
    ref_output = ref_program(*input_tensors)
    print(f"Tilelang output: {tilelang_output}")
    print(f"Ref output: {ref_output}")
    torch.testing.assert_allclose(tilelang_output, ref_output, rtol=0.01, atol=0.01)

    latency = profiler.do_bench(warmup=500)
    print(f"Latency: {latency} ms")
    print(f"TFlops: {total_flops / latency * 1e-9} TFlops")
@@ -146,7 +146,8 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n,
   ICHECK(warp_m % 16 == 0);
   ICHECK(block_k % 16 == 0);
   // Only support 8-bit and 16-bit
-  ICHECK(element_size == 8 || element_size == 16);
+  ICHECK(element_size == 8 || element_size == 16)
+      << "element bitwidth=" << element_size;
   if (element_size == 8) {
     auto base_layout = makeGemmFragment8x16()->Repeat({2, 2}, false, false);
     auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
@@ -175,6 +176,8 @@ Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
   ICHECK(block_n % warp_n == 0);
   ICHECK(warp_m % 16 == 0);
   ICHECK(block_k % 16 == 0);
+  ICHECK(element_size == 8 || element_size == 16)
+      << "element bitwidth=" << element_size;
   if (transposed) {
     auto base_layout =
         makeGemmFragmentAB16x16CDNATransposed()->Repeat({1, 1}, false, false);
...
@@ -35,7 +35,6 @@ extern "C" const char* get_last_error() {{
 extern "C" int init() {{
     error_buf[0] = '\\0';
     {0}
-    return 0;
 }}
 """
...