Commit 34de04a6 authored by Yu Cheng, committed by GitHub

[CI][Test] Add test cases for tilelang kernel convolution (#51)

* [CI][Test] Add test cases for tilelang kernel convolution
parent 47ecc791
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from functools import partial
num_split = 4
def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_casual, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape_q = [batch, seqlen_q, heads, dim]
shape_kv = [batch, seqlen_kv, heads, dim]
part_shape = [batch, seqlen_q, heads, num_split, dim]
dtype = "float16"
accum_dtype = "float"
@T.macro
def MMA0(
K: T.Buffer(shape_kv, dtype),
Q_shared: T.Buffer([block_M, dim], dtype),
K_shared: T.Buffer([block_N, dim], dtype),
acc_s: T.Buffer([block_M, block_N], accum_dtype),
k: T.int32,
mid: T.int32,
hid: T.int32,
bid: T.int32,
sid: T.int32,
):
T.copy(
K[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid +
(k + 1) * block_N, hid, :], K_shared)
# TODO: Handle causal split case
if is_casual:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Buffer(shape_kv, dtype),
V_shared: T.Buffer([block_N, dim], dtype),
acc_s_cast: T.Buffer([block_M, block_N], dtype),
acc_o: T.Buffer([block_M, dim], accum_dtype),
k: T.int32,
hid: T.int32,
bid: T.int32,
sid: T.int32,
):
T.copy(
V[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid +
(k + 1) * block_N, hid, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.Buffer([block_M, block_N], accum_dtype),
acc_s_cast: T.Buffer([block_M, block_N], dtype),
scores_max: T.Buffer([block_M], accum_dtype),
scores_max_prev: T.Buffer([block_M], accum_dtype),
scores_scale: T.Buffer([block_M], accum_dtype),
scores_sum: T.Buffer([block_M], accum_dtype),
logsum: T.Buffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# For causal softmax, we need to set scores_max to 0 if it is -inf.
# This step is called Check_inf in the FlashAttention3 code, and it only needs to be
# done in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
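# Worked form of the identity used above: with scale = log2(e) / sqrt(dim),
#   exp((x - m) / sqrt(dim)) == exp2(x * scale - m * scale),
# so each element needs only one fused multiply-add before the exp2.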
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.Buffer([block_M, dim], accum_dtype),
scores_scale: T.Buffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.macro
def flash_attn_split(
Q: T.Buffer(shape_q, dtype),
K: T.Buffer(shape_kv, dtype),
V: T.Buffer(shape_kv, dtype),
glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype),
Output_partial: T.Buffer(part_shape, dtype),
):
with T.Kernel(
T.ceildiv(seqlen_q, block_M), heads * batch, num_split,
threads=128) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
mid = bx
hid = by % heads
bid = by // heads
sid = bz
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
# TODO: Handle causal split case
loop_range = (
T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv(
(mid + 1) * block_M, block_N)) if is_casual else T.ceildiv(
(seqlen_kv // num_split), block_N))
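# Each split visits seqlen_kv // num_split keys in block_N-sized chunks; the causal
# branch instead stops at the query block's diagonal, but split-aware causal masking
# is still a TODO (see above) and the checks in __main__ run with casual=False.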
for k in T.Pipelined(loop_range, num_stages=2):
MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
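# After this loop, logsum holds the base-2 log-sum-exp of the scaled scores for this
# split; it is written to glse so that combine() can weight the per-split partial outputs.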
T.copy(logsum, glse[bid, hid, sid, mid * block_M:(mid + 1) * block_M])
T.copy(acc_o, O_shared)
T.copy(O_shared, Output_partial[bid, mid * block_M:(mid + 1) * block_M, hid, sid, :])
@T.macro
def combine(
glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype),
Output_partial: T.Buffer(part_shape, dtype),
Output: T.Buffer(shape_q, dtype),
):
with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz):
po_local = T.alloc_fragment([block_M, dim], dtype)
po_shared = T.alloc_shared([block_M, dim], dtype)
o_accum_local = T.alloc_fragment([block_M, dim], accum_dtype)
o_shared = T.alloc_shared([block_M, dim], dtype)
lse_local = T.alloc_fragment([num_split, block_M], dtype)
lse_local_split = T.alloc_fragment([block_M], accum_dtype)
lse_logsum_local = T.alloc_fragment([block_M], accum_dtype)
lse_max_local = T.alloc_fragment([block_M], accum_dtype)
scale_local = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({
o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i),
lse_local_split: T.Fragment(lse_local_split.shape, forward_thread_fn=lambda i: i),
o_shared: tilelang.layout.make_swizzled_layout(o_shared),
po_shared: tilelang.layout.make_swizzled_layout(po_shared),
})
T.clear(lse_logsum_local)
T.clear(o_accum_local)
T.copy(glse[
bz,
by,
:,
bx * block_M:(bx + 1) * block_M,
], lse_local)
T.reduce_max(lse_local, lse_max_local, dim=0, clear=False)
for k in T.Pipelined(num_split):
T.copy(lse_local[k, :], lse_local_split)
for i in T.Parallel(block_M):
lse_logsum_local[i] += T.exp2(lse_local_split[i] - lse_max_local[i])
for i in T.Parallel(block_M):
lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i]
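# In base 2 this computes lse_total = log2(sum_s 2^(lse_s - lse_max)) + lse_max; the
# loop below then weights each split's partial output by 2^(lse_s - lse_total), the
# usual numerically stable log-sum-exp combine.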
for k in T.Pipelined(num_split, num_stages=2):
T.copy(Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :], po_shared)
T.copy(po_shared, po_local)
T.copy(lse_local[k, :], lse_local_split)
for i in T.Parallel(block_M):
scale_local[i] = T.exp2(lse_local_split[i] - lse_logsum_local[i])
for i, j in T.Parallel(block_M, dim):
o_accum_local[i, j] += po_local[i, j] * scale_local[i]
T.copy(o_accum_local, o_shared)
T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
@T.prim_func
def main(
Q: T.Buffer(shape_q, dtype),
K: T.Buffer(shape_kv, dtype),
V: T.Buffer(shape_kv, dtype),
glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype),
Output_partial: T.Buffer(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim]
Output: T.Buffer(shape_q, dtype),
):
flash_attn_split(Q, K, V, glse, Output_partial)
combine(glse, Output_partial, Output)
return main
def ref_program(Q, K, V, glse, Output_partial, casual):
assert casual is False
dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
def reduce_ref(Q, K, V, glse, Output_partial, casual):
o = torch.empty_like(Output_partial[:, :, :, 0, :]).fill_(0)
lse_logsum = torch.empty_like(glse[:, :, 0, :]).fill_(0) # [batch, heads, seqlen_q]
lse_max = glse.max(dim=2, keepdim=False).values
for ks in range(num_split):
lse = glse[:, :, ks, :]
lse_logsum += torch.exp2(lse - lse_max)
lse_logsum = torch.log2(lse_logsum) + lse_max
for ks in range(num_split):
lse = glse[:, :, ks, :]
scale = torch.exp2(lse - lse_logsum) # [batch, heads, seqlen_q]
o += Output_partial[:, :, :, ks, :] * scale[:, :, :, None].transpose(1, 2)
return o.to(torch.float16)
def flash_split_ref(Q, K, V, casual):
# [batch, seqlen_q, heads, dim]
batch = Q.size(0)
block_M = Q.size(1)
nheads = Q.size(2)
dim = Q.size(3)
block_N = 128
seqlen_kv = K.size(1)
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float)
acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16)
acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float)
scores_max = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
scores_max_prev = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
scores_scale = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
scores_sum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
logsum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
gacc_o = torch.empty((num_split, batch, block_M, nheads, dim), device="cuda", dtype=torch.float)
glogsum = torch.empty((num_split, batch, nheads, block_M), device="cuda", dtype=torch.float)
Q_ = Q * scale
for ks in range(num_split):
acc_o.fill_(0)
logsum.fill_(0)
scores_max.fill_(float('-inf'))
scores_max_prev.fill_(float('-inf'))
for i in range(int((seqlen_kv // num_split) / block_N)):
acc_s.fill_(0)
acc_s = torch.einsum('bqhd,bkhd->bhqk', Q_,
K[:, (seqlen_kv // num_split) * ks +
i * block_N:(seqlen_kv // num_split) * ks +
(i + 1) * block_N, :, :]) # [batch, nheads, block_M, block_N]
scores_max_prev = scores_max
scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads, block_M]
scores_scale = torch.exp2(scores_max_prev - scores_max)
acc_o *= scores_scale[:, :, :, None].transpose(1, 2)
acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
acc_s_cast = acc_s.to(torch.float16)
acc_o += torch.einsum(
'bhqk,bkhd->bqhd', acc_s_cast,
V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
(i + 1) * block_N, :, :])
scores_sum = acc_s.sum(dim=-1, keepdim=False)
logsum = logsum * scores_scale + scores_sum
acc_o /= logsum[:, :, :, None].transpose(1, 2)
logsum = torch.log2(logsum) + scores_max
gacc_o[ks, :, :, :, :] = acc_o
glogsum[ks, :, :, :] = logsum
return (glogsum.to(torch.float16).permute(1, 2, 0, 3),
gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4))
if __name__ == "__main__":
BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128
casual = False
flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
total_flops = 2 * flops_per_matmul
if casual:
total_flops *= 0.5
BLOCK_M = 128
BLOCK_N = 64 # if D_HEAD <= 128 else 32
program = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, casual, BLOCK_M, BLOCK_N)
ref_program = partial(ref_program, casual=casual)
mod, params = tilelang.lower(program)
mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks passed!")
latency = mod.do_bench(ref_program, warmup=500)
print("{:.2f} ms".format(latency))
print("{:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = mod.do_bench(mod, n_warmup=10, n_repeat=10, profiler="tvm")
print("{:.4f} ms".format(latency))
print("{:.2f} TFlops".format(total_flops / latency * 1e-9))
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import tilelang.language as T
def convolution(N, C, H, W, F, K, S, D, P, in_dtype, out_dtype, dtypeAccum, block_M, block_N,
block_K, num_stages, threads):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
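# Standard convolution output-size formula; e.g. for the tests below with
# H = W = 64, K = 3, S = 1, D = 1, P = 1:
#   OH = OW = (64 + 2*1 - 1*(3 - 1) - 1) // 1 + 1 = 64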
@T.prim_func
def main(
data: T.Buffer((N, H, W, C), in_dtype),
kernel: T.Buffer((KH, KW, C, F), in_dtype),
out: T.Buffer((N, OH, OW, F), out_dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), in_dtype)
kernel_shared = T.alloc_shared((block_K, block_N), in_dtype)
out_local = T.alloc_fragment((block_M, block_N), dtypeAccum)
kernel_flat = T.Buffer((KH * KW * C, F), in_dtype, kernel.data)
out_flat = T.Buffer((N * OH * OW, F), out_dtype, out.data)
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
for i, j in T.Parallel(block_M, block_K):
k = k_iter * block_K + j
m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
(access_w < W))
data_shared[i, j] = T.if_then_else(
in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
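# The index math above is an implicit im2col: the flattened row m decomposes as
# (n, oh, ow) = (m // (OH*OW), m % (OH*OW) // OW, m % OW) and the flattened column k
# as (kh, kw, c) = (k // (KW*C), k // C % KW, k % C), so access_h = oh*S + kh*D - P
# and access_w = ow*S + kw*D - P, with out-of-bounds taps (the padding region) read as zero.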
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
T.copy(out_local, out_flat[by * block_M, bx * block_N])
return main
def run_conv(N,
C,
H,
W,
F,
K,
S,
D,
P,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=2,
threads=128):
program = convolution(N, C, H, W, F, K, S, D, P, in_dtype, out_dtype, dtypeAccum, block_M,
block_N, block_K, num_stages, threads)
mod, params = tl.lower(program)
mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
def ref_program(A, B):
import torch
A = A.permute(0, 3, 1, 2).to(torch.float) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1).to(torch.float) # H, W, C, F -> F, C, H, W
C = torch.conv2d(A, B, stride=S, padding=P, dilation=D)
C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C
return C.to(torch.__getattribute__(out_dtype))
mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_conv_f16f16f16_k3s1d1p1():
run_conv(
1,
128,
64,
64,
128,
3,
1,
1,
1,
"float16",
"float16",
"float16",
128,
128,
32,
2,
)
def test_conv_f16f16f16_k3s2d1p1():
run_conv(
1,
128,
64,
64,
128,
3,
2,
1,
1,
"float16",
"float16",
"float16",
128,
128,
32,
2,
)
def test_conv_f16f16f16_k1s1d1p0():
run_conv(
1,
128,
64,
64,
128,
1,
1,
1,
0,
"float16",
"float16",
"float16",
128,
128,
32,
2,
)
def test_conv_f16f16f16_k1s2d1p0():
run_conv(
1,
128,
64,
64,
128,
1,
2,
1,
0,
"float16",
"float16",
"float16",
128,
128,
32,
2,
)
def test_conv_f16f16f32_k3s1d1p1():
run_conv(
1,
128,
64,
64,
128,
3,
1,
1,
1,
"float16",
"float16",
"float32",
128,
128,
32,
2,
)
def test_conv_f16f16f32_k3s2d1p1():
run_conv(
1,
128,
64,
64,
128,
3,
2,
1,
1,
"float16",
"float16",
"float32",
128,
128,
32,
2,
)
def test_conv_f16f16f32_k1s1d1p0():
run_conv(
1,
128,
64,
64,
128,
1,
1,
1,
0,
"float16",
"float16",
"float32",
128,
128,
32,
2,
)
def test_conv_f16f16f32_k1s2d1p0():
run_conv(
1,
128,
64,
64,
128,
1,
2,
1,
0,
"float16",
"float16",
"float32",
128,
128,
32,
2,
)
def test_conv_bf16bf16f32_k3s1d1p1():
run_conv(
1,
128,
64,
64,
128,
3,
1,
1,
1,
"bfloat16",
"bfloat16",
"float32",
128,
128,
32,
2,
)
def test_conv_bf16bf16f32_k3s2d1p1():
run_conv(
1,
128,
64,
64,
128,
3,
2,
1,
1,
"bfloat16",
"bfloat16",
"float32",
128,
128,
32,
2,
)
def test_conv_bf16bf16f32_k1s1d1p0():
run_conv(
1,
128,
64,
64,
128,
1,
1,
1,
0,
"bfloat16",
"bfloat16",
"float32",
128,
128,
32,
2,
)
def test_conv_bf16bf16f32_k1s2d1p0():
run_conv(
1,
128,
64,
64,
128,
1,
2,
1,
0,
"bfloat16",
"bfloat16",
"float32",
128,
128,
32,
2,
)
if __name__ == "__main__":
tilelang.testing.main()