Commit 5cea760c authored by Lei Wang's avatar Lei Wang Committed by GitHub

[Example] Add Split-K and Stream-K Examples and move MLA from fld to mla (#110)

* Add DeepSeek MLA decode example with Flash Attention implementation

* Add GEMM SplitK and StreamK example implementations

This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques:
- `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang
- `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang

Both examples showcase different parallel decomposition strategies for matrix multiplication and are checked against PyTorch reference implementations; a short illustrative sketch of the Split-K idea follows this message.

* Refactor GEMM SplitK and StreamK example implementations

Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts:
- Remove unused import (Profiler) in splitk example
- Simplify line breaks and improve code readability
- Standardize indentation and remove unnecessary whitespace
- Optimize atomic add and copy operations for better clarity
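
For intuition, here is a minimal PyTorch sketch of the Split-K idea (illustrative only; the sizes are arbitrary and are not taken from the examples below): each partition reduces over its own slice of K, and the partial products are summed afterwards, which the TileLang kernel does with atomic adds. Stream-K instead hands each SM a contiguous range of (tile, K-iteration) pairs, so one tile's reduction may be shared by several SMs.

# Illustrative Split-K sketch in plain PyTorch (not part of the commit).
import torch

M, N, K, split_k = 256, 256, 1024, 4
A, B = torch.randn(M, K), torch.randn(K, N)
C = torch.zeros(M, N)
for s in range(split_k):
    ks = slice(s * (K // split_k), (s + 1) * (K // split_k))
    C += A[:, ks] @ B[ks, :]  # partial product for this K-slice
torch.testing.assert_close(C, A @ B, rtol=1e-3, atol=1e-3)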
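# ---------------------------------------------------------------------------
# DeepSeek MLA decode example: flash-attention decoding with the KV sequence
# split into num_split chunks (flash_attn_split) and merged by a log-sum-exp
# combine kernel.
# ---------------------------------------------------------------------------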
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T

num_split = 4


def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
    shape_q = [batch, heads, (dim + pe_dim)]
    shape_k = [batch, seqlen_kv, kv_head_num, (dim + pe_dim)]
    shape_v = [batch, seqlen_kv, kv_head_num, dim]
    shape_o = [batch, heads, dim]
    part_shape = [batch, heads, num_split, dim]
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // kv_head_num
    VALID_BLOCK_H = min(block_H, kv_group_num)
    assert kv_head_num == 1, "kv_head_num must be 1"

    @T.macro
    def flash_attn_split(
            Q: T.Buffer(shape_q, dtype),
            K: T.Buffer(shape_k, dtype),
            V: T.Buffer(shape_v, dtype),
            glse: T.Buffer([batch, heads, num_split], dtype),
            Output_partial: T.Buffer(part_shape, dtype),
    ):
        with T.Kernel(
                batch, heads // min(block_H, kv_group_num), num_split, threads=128) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_H, (dim + pe_dim)], dtype)
            K_shared = T.alloc_shared([block_N, (dim + pe_dim)], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_H, dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            bid = bx
            hid = by
            sid = bz
            cur_kv_head = hid // (kv_group_num // block_H)

            T.annotate_layout({
                O_shared: tilelang.layout.make_swizzled_layout(O_shared),
            })

            T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(
                    K[bid, (seqlen_kv // num_split) * sid +
                      k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
                      cur_kv_head, :], K_shared)
                T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] *= scores_scale[i]
                T.copy(
                    V[bid, (seqlen_kv // num_split) * sid +
                      k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
                      cur_kv_head, :], V_shared)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
            for i, j in T.Parallel(block_H, dim):
                acc_o[i, j] /= logsum[i]
            for i in T.Parallel(block_H):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, sid])
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
                                            sid, :])
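    # The combine kernel below merges the num_split partial results: it computes a
    # global log-sum-exp over the per-split LSE values and rescales each partial
    # output by exp2(lse_split - lse_global) before summing.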
    @T.macro
    def combine(
            glse: T.Buffer([batch, heads, num_split], dtype),
            Output_partial: T.Buffer(part_shape, dtype),
            Output: T.Buffer(shape_o, dtype),
    ):
        with T.Kernel(heads, batch, threads=128) as (by, bz):
            po_local = T.alloc_fragment([dim], dtype)
            o_accum_local = T.alloc_fragment([dim], accum_dtype)
            lse_local = T.alloc_fragment([num_split, 1], dtype)
            lse_local_split = T.alloc_local([1], accum_dtype)
            lse_logsum_local = T.alloc_local([1], accum_dtype)
            lse_max_local = T.alloc_fragment([1], accum_dtype)
            scale_local = T.alloc_local([1], accum_dtype)

            T.annotate_layout({
                lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
            })

            T.clear(lse_logsum_local)
            T.clear(o_accum_local)
            for k in T.Parallel(num_split):
                lse_local[k, 0] = glse[bz, by, k]
            T.reduce_max(lse_local, lse_max_local, dim=0, clear=True)
            for k in T.Pipelined(num_split, num_stages=1):
                lse_local_split[0] = glse[bz, by, k]
                lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
            lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
            for k in T.serial(num_split):
                for i in T.Parallel(dim):
                    po_local[i] = Output_partial[bz, by, k, i]
                lse_local_split[0] = glse[bz, by, k]
                scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
                for i in T.Parallel(dim):
                    o_accum_local[i] += po_local[i] * scale_local[0]
            for i in T.Parallel(dim):
                Output[bz, by, i] = o_accum_local[i]

    @T.prim_func
    def main(
            Q: T.Buffer(shape_q, dtype),
            K: T.Buffer(shape_k, dtype),
            V: T.Buffer(shape_v, dtype),
            glse: T.Buffer([batch, heads, num_split], dtype),
            Output_partial: T.Buffer(part_shape, dtype),  # [batch, heads, num_split, dim]
            Output: T.Buffer(shape_o, dtype),
    ):
        flash_attn_split(Q, K, V, glse, Output_partial)
        combine(glse, Output_partial, Output)

    return main
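# PyTorch references: ref_program computes full attention in a single pass, while
# flash_split_ref and reduce_ref mirror the split and combine kernels above.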
def ref_program(query, key, value, glse, Output_partial):
    """
    Inputs:
    - query (Tensor): [batch, heads, dim]
    - key (Tensor): [batch, seqlen_kv, kv_head_num, dim]
    - value (Tensor): [batch, seqlen_kv, kv_head_num, dim]
    Outputs:
    - output (Tensor): [batch, heads, dim]
    glse and Output_partial are unused here; they only mirror the kernel signature.
    """
    from einops import rearrange
    batch_size, query_heads, dim = query.shape  # [batch_size, query_heads, dim]
    _, seqlen_kv, kv_heads, _ = key.shape  # [batch_size, seqlen_kv, kv_heads, kv_dim]
    dim_v = value.shape[-1]
    assert kv_heads == 1, "kv_heads must be 1"

    query_expanded = rearrange(query, 'b h d -> b h 1 d')  # [batch_size, query_heads, 1, dim]
    key_expanded = key.expand(-1, -1, query_heads, -1)  # [batch_size, seqlen_kv, query_heads, dim]
    value_expanded = value.expand(-1, -1, query_heads, -1)  # [batch_size, seqlen_kv, query_heads, dim]
    key_expanded = rearrange(key_expanded, 'b n h d -> b h n d')  # [batch_size, query_heads, seqlen_kv, dim]
    value_expanded = rearrange(value_expanded, 'b n h d -> b h n d')  # [batch_size, query_heads, seqlen_kv, dim]

    scores = torch.matmul(query_expanded, key_expanded.transpose(-1, -2))  # [batch_size, query_heads, 1, seqlen_kv]
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    attention_weights = F.softmax(scores, dim=-1)  # [batch_size, query_heads, 1, seqlen_kv]
    output = torch.matmul(attention_weights, value_expanded)  # [batch_size, query_heads, 1, dim_v]
    return output.view(batch_size, query_heads, dim_v)


def flash_split_ref(Q, K, V):
    dim = 512
    pe_dim = 64
    batch = Q.size(0)
    nheads = Q.size(1)
    assert Q.size(2) == dim + pe_dim, "dim must be 576 = 512 + 64"
    block_N = 32
    seqlen_kv = K.size(1)
    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)

    acc_s = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float)
    acc_s_cast = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float16)
    acc_o = torch.empty((batch, nheads, dim), device="cuda", dtype=torch.float)
    scores_max = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
    scores_max_prev = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
    scores_scale = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
    scores_sum = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
    logsum = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
    gacc_o = torch.empty((num_split, batch, nheads, dim), device="cuda", dtype=torch.float)
    glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float)

    Q_ = Q * scale
    K_ = K.expand(-1, -1, nheads, -1)
    V_ = V.expand(-1, -1, nheads, -1)

    for ks in range(num_split):
        acc_o.fill_(0)
        logsum.fill_(0)
        scores_max.fill_(float('-inf'))
        scores_max_prev.fill_(float('-inf'))
        for i in range(int((seqlen_kv // num_split) / block_N)):
            acc_s.fill_(0)
            acc_s = torch.einsum('bhd,bkhd->bhk', Q_,
                                 K_[:, (seqlen_kv // num_split) * ks +
                                    i * block_N:(seqlen_kv // num_split) * ks +
                                    (i + 1) * block_N, :, :])  # [batch, nheads, block_N]
            scores_max_prev = scores_max
            scores_max = acc_s.max(dim=-1, keepdim=False).values  # [batch, nheads]
            scores_scale = torch.exp2(scores_max_prev - scores_max)  # [batch, nheads]
            acc_o *= scores_scale[:, :, None]
            acc_s = torch.exp2(acc_s - scores_max[:, :, None])
            acc_s_cast = acc_s.to(torch.float16)  # [batch, nheads, block_N]
            acc_o += torch.einsum(
                'bhk,bkhd->bhd', acc_s_cast,
                V_[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
                   (i + 1) * block_N, :, :])
            scores_sum = acc_s.sum(dim=-1, keepdim=False)
            logsum = logsum * scores_scale + scores_sum
        acc_o /= logsum[:, :, None]
        logsum = torch.log2(logsum) + scores_max
        gacc_o[ks, :, :, :] = acc_o
        glogsum[ks, :, :] = logsum

    return glogsum.to(torch.float16).permute(1, 2, 0), gacc_o.to(torch.float16).permute(1, 2, 0, 3)


def reduce_ref(Q, K, V, glse, Output_partial):
    o = torch.empty_like(Output_partial[:, :, 0, :]).fill_(0)
    lse_logsum = torch.empty_like(glse[:, :, 0]).fill_(0)
    lse_max = glse.max(dim=2, keepdim=False).values
    for ks in range(num_split):
        lse = glse[:, :, ks]
        lse_logsum += torch.exp2(lse - lse_max)
    lse_logsum = torch.log2(lse_logsum) + lse_max
    for ks in range(num_split):
        lse = glse[:, :, ks]
        scale = torch.exp2(lse - lse_logsum)
        o += Output_partial[:, :, ks, :] * scale[:, :, None]
    return o.to(torch.float16)


if __name__ == "__main__":
    BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 64, 128, 1, 8192, 512, 64
    qk_flops = 2 * BATCH * H_Q * KV_CTX * (D_HEAD + DPE)
    pv_flops = 2 * BATCH * H_Q * KV_CTX * D_HEAD
    total_flops = qk_flops + pv_flops
    BLOCK_N = 32  # if D_HEAD <= 128 else 32
    BLOCK_H = 64

    program = flashattn(BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE, BLOCK_N, BLOCK_H)
    mod, params = tilelang.lower(program)
    mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal)
    mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    latency = mod.do_bench(mod.func, warmup=500)
    print("Tile-lang: {:.2f} ms".format(latency))
    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
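# ---------------------------------------------------------------------------
# example_tilelang_gemm_splitk.py -- Split-K GEMM example
# ---------------------------------------------------------------------------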
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang
import tilelang.language as T
from tvm import DataType


def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float"):
    splitK = K // split_k

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((K, N), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
            A_shared = T.alloc_shared((block_M, block_K), dtype, "shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, "shared")
            C_shared = T.alloc_shared((block_M, block_N), dtype, "shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            if bz == 0:
                # fuse the zero-initialization kernel
                for i, j in T.Parallel(block_M, block_N):
                    m, n = by * block_M + i, bx * block_N + j
                    C[m, n] = T.cast(0, dtype)

            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=0):
                T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared)
                T.copy(B[bz * splitK + ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C_shared)

            if DataType(dtype).bits == 16:
                for i, j in T.Parallel(block_M, block_N // 2):
                    m, n = by * block_M + i, bx * block_N + j * 2
                    # vectorized atomic add of two fp16 values at once
                    T.atomic_addx2(C[m, n], C_shared[i, j * 2])
            else:
                for i, j in T.Parallel(block_M, block_N):
                    T.atomic_add(C[by * block_M + i, bx * block_N + j], C_shared[i, j])

    return main


program = matmul(1024, 1024, 1024, 128, 128, 32, 4)
kernel = tilelang.compile(program)
print(kernel.get_kernel_source())

import torch

a = torch.randn(1024, 1024).cuda().half()
b = torch.randn(1024, 1024).cuda().half()
c = torch.zeros(1024, 1024).cuda().half()

kernel(a, b, c)

ref_c = a @ b

print(c)
print(ref_c)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
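# ---------------------------------------------------------------------------
# example_tilelang_gemm_streamk.py -- Stream-K GEMM example
# ---------------------------------------------------------------------------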
import torch
import torch.backends
import tilelang
from tilelang import language as T
import math


def cdiv(a, b):
    return math.ceil(a / b)


# disable tf32
torch.backends.cuda.matmul.allow_tf32 = False

m = 256
n = 1024
k = 512
total_sm = 108

torch.random.manual_seed(0)
# uniform distribution from -1 to 1
A = torch.rand(m, k, device="cuda", dtype=torch.float16) * 2 - 1
B = torch.rand(n, k, device="cuda", dtype=torch.float16) * 2 - 1

streamk_programs = total_sm
BLOCK_SIZE_M = 16
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = 32
two_tiles = False
M, K = A.shape
N, K = B.shape
# accumulator types
# compute grid (work to do per SM on the first wave)
num_block_m = cdiv(M, BLOCK_SIZE_M)
num_block_n = cdiv(N, BLOCK_SIZE_N)
iters_per_tile = cdiv(K, BLOCK_SIZE_K)
total_tiles = num_block_m * num_block_n

# Two-tile SK + DP
streamk_tiles = total_tiles % streamk_programs
if total_tiles - streamk_tiles > streamk_programs:  # (total_tiles // total_programs > 1)
    streamk_tiles += streamk_programs

blocking_tiles = total_tiles - streamk_tiles
streamk_iters = streamk_tiles * iters_per_tile
streamk_full_tiles = streamk_iters // streamk_programs
streamk_partial_tiles = streamk_iters % streamk_programs

print(f"{total_tiles=} ")
print(f"{iters_per_tile=} ")

sm_partition_factor = max(blocking_tiles // total_sm, 1)
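# For the sizes above (m=256, n=1024, k=512, 108 SMs): num_block_m=16, num_block_n=8,
# total_tiles=128, iters_per_tile=16, streamk_tiles = 128 % 108 = 20, blocking_tiles=108,
# streamk_iters=320, streamk_full_tiles=2, streamk_partial_tiles=104 and
# sm_partition_factor=1, i.e. each SM handles one data-parallel tile after its
# share of the stream-k first wave.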
def tl_matmul_streamk(
    M,
    N,
    K,
    streamk_tiles,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    dtypeAB,
    dtypeC,
    accum_dtype,
    num_stages,
    threads,
):
    assert not trans_A

    A_shape = (M, K) if not trans_A else (K, M)
    B_shape = (K, N) if not trans_B else (N, K)
    A_shared_shape = (block_M, block_K) if not trans_A else (block_K, block_M)
    B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)

    @T.macro
    def compute_first_wave(
            pid: T.int32,
            A_buf: T.Buffer,
            A_buf_shared: T.Buffer,
            B_buf: T.Buffer,
            B_buf_shared: T.Buffer,
            C: T.Buffer,
            C_local: T.Buffer,
    ):
        start_iter = T.alloc_fragment((1,), "int32", "local")
        end_iter = T.alloc_fragment((1,), "int32", "local")

        start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles)
        last_iter = (pid + 1) * streamk_full_tiles + T.min(pid + 1, streamk_partial_tiles)

        while start_iter[0] < last_iter:
            end_iter[0] = T.min(
                start_iter[0] + (iters_per_tile - (start_iter[0] % iters_per_tile)),
                last_iter,
            )
            tile_id = start_iter[0] // iters_per_tile
            remain_iters = start_iter[0] % iters_per_tile
            pid_m = tile_id // T.ceildiv(N, block_N)
            pid_n = tile_id % T.ceildiv(N, block_N)

            T.clear(C_local)
            for k in T.Pipelined(end_iter[0] - start_iter[0], num_stages=num_stages):
                T.copy(
                    A_buf[pid_m * block_M, (k + (start_iter[0] % iters_per_tile)) * block_K],
                    A_buf_shared,
                )
                T.copy(
                    B_buf[pid_n * block_N, (k + (start_iter[0] % iters_per_tile)) * block_K],
                    B_buf_shared,
                )
                T.gemm(A_buf_shared, B_buf_shared, C_local, transpose_B=trans_B)

            # last iteration of the tile always happens before its start on another SM
            if remain_iters == 0 and (end_iter[0] % iters_per_tile == 0):
                T.copy(C_local, C[pid_m * block_M, pid_n * block_N])
            else:
                for i, j in T.Parallel(block_M, block_N):
                    T.atomic_add(C[pid_m * block_M + i, pid_n * block_N + j], C_local[i, j])

            start_iter[0] = end_iter[0]

    @T.macro
    def compute_full_tiles(
            pid: T.int32,
            A_buf: T.Buffer,
            A_shared: T.Buffer,
            B_buf: T.Buffer,
            B_shared: T.Buffer,
            C: T.Buffer,
            C_local: T.Buffer,
    ):
        for p in T.serial(sm_partition_factor):
            tile_id = pid + streamk_tiles + p * total_sm
            pid_m = tile_id // T.ceildiv(N, block_N)
            pid_n = tile_id % T.ceildiv(N, block_N)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
                T.copy(A_buf[pid_m * block_M, k * block_K], A_shared)
                T.copy(B_buf[pid_n * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=trans_B)
            T.copy(C_local, C[pid_m * block_M, pid_n * block_N])

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, dtypeAB),
            B: T.Buffer(B_shape, dtypeAB),
            C: T.Buffer((M, N), dtypeC),
    ):
        with T.Kernel(streamk_programs, threads=threads) as pid:
            A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
            B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
            A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB)
            B_shared_full_tiles = T.alloc_shared(B_shared_shape, dtypeAB)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            compute_first_wave(pid, A, A_shared, B, B_shared, C, C_local)

            if sm_partition_factor > 0:
                compute_full_tiles(pid, A, A_shared_full_tiles, B, B_shared_full_tiles, C, C_local)

    return main


_tl_matmul_streamk = tl_matmul_streamk(
    m,
    n,
    k,
    streamk_tiles,
    BLOCK_SIZE_M,
    BLOCK_SIZE_N,
    BLOCK_SIZE_K,
    False,
    True,
    "float16",
    "float16",
    "float32",
    2,
    64,
)

kernel = tilelang.compile(_tl_matmul_streamk)
print(kernel.get_kernel_source())

b_c = torch.zeros((m, n), device="cuda", dtype=torch.float16)
kernel(A, B, b_c)

C = torch.matmul(A, B.T)

print(b_c)
print(C)
torch.testing.assert_close(C, b_c, rtol=1e-2, atol=1e-2)