Commit bedab1a0 authored by Yu Cheng, committed by GitHub

[CI][Test] Add test cases for tilelang kernel FlashAttention (#54)

* [Dev] Add FlashDecoding example

* [CI][Test] Add test cases for tilelang kernel convolution

* [CI][Test] Add test cases for tilelang kernel FlashAttention

* Reduce the number of stages to ensure the shared memory allocation is valid (see the rough estimate below)

* Temporarily remove the dim128 case

* lint

* update einops in requirements-dev.txt

* update einops in requirements-test.txt

* remove einops in requirements-dev.txt
parent 38ba083b
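
A rough sanity check of the shared-memory point above, as a minimal Python sketch. The fp16 element size, the 128x128 tile with dim=128, and counting only the Q/K/V tiles are assumptions for illustration; the real allocation depends on the lowered kernel and the GPU's per-block limit.

block_M = block_N = dim = 128       # the dim128 configuration used below
elem = 2                            # bytes per fp16 element (assumed)
q_smem = block_M * dim * elem       # Q_shared, not pipelined: 32 KiB
kv_smem = 2 * block_N * dim * elem  # K_shared + V_shared, replicated per stage: 64 KiB
for num_stages in (2, 1):
    print(num_stages, (q_smem + num_stages * kv_smem) // 1024, "KiB")
# 2 stages: ~160 KiB, beyond typical per-block limits once other buffers are counted;
# 1 stage:  ~96 KiB, which fits.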
@@ -45,7 +45,7 @@ def flash_attention(
         T.fill(logsum, 0)
         T.fill(scores_max, -T.infinity(accum_dtype))
         loop_range = (
-            T.ceildiv((bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N)
+            T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
         )
         # Pipeline the loop to overlap copies/gemm stages
@@ -53,7 +53,7 @@ def flash_attention(
             # Copy K block into shared memory
             T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
-            if is_casual:
+            if is_causal:
                 for i, j in T.Parallel(block_M, block_N):
                     acc_s[i, j] = T.if_then_else(
                         bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)
...
@@ -28,7 +28,7 @@ def get_configs():
     return configs


-def flashattn(batch, heads, seq_len, dim, is_casual, tune=False):
+def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
     dtype = "float16"
@@ -48,7 +48,7 @@ def flashattn(batch, heads, seq_len, dim, is_casual, tune=False):
             bz: T.int32,
         ):
             T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
-            if is_casual:
+            if is_causal:
                 for i, j in T.Parallel(block_M, block_N):
                     acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                  -T.infinity(acc_s.dtype))
@@ -136,7 +136,7 @@ def flashattn(batch, heads, seq_len, dim, is_casual, tune=False):
             loop_range = (
                 T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
-                    (bx + 1) * block_M, block_N)) if is_casual else T.ceildiv(seq_len, block_N))
+                    (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
@@ -175,11 +175,11 @@ def flashattn(batch, heads, seq_len, dim, is_casual, tune=False):
     return kernel


-def ref_program(Q, K, V, is_casual):
+def ref_program(Q, K, V, is_causal):
     dim = Q.size(-1)
     scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
     scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
-    if is_casual:
+    if is_causal:
         seq_len = Q.size(1)
         mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
         mask = mask.unsqueeze(0).unsqueeze(0)
@@ -195,20 +195,20 @@ if __name__ == "__main__":
     parser.add_argument('--heads', type=int, default=32, help='heads')
     parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
     parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--is_casual', action='store_true', help='causal')
+    parser.add_argument('--is_causal', action='store_true', help='causal')
     parser.add_argument('--tune', action='store_true', help='tune configs')
     args = parser.parse_args()
-    batch, heads, seq_len, dim, is_casual = args.batch, args.heads, args.seq_len, args.dim, args.is_casual
+    batch, heads, seq_len, dim, is_causal = args.batch, args.heads, args.seq_len, args.dim, args.is_causal
     flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
     total_flops = 2 * flops_per_matmul
-    if is_casual:
+    if is_causal:
         total_flops *= 0.5
     if (not args.tune):
         program = flashattn(
-            batch, heads, seq_len, dim, is_casual, tune=args.tune)(
+            batch, heads, seq_len, dim, is_causal, tune=args.tune)(
                 block_M=128, block_N=128, num_stages=1, threads=128)
-        ref_program = partial(ref_program, is_casual=is_casual)
+        ref_program = partial(ref_program, is_causal=is_causal)
         mod, params = tilelang.lower(program)
         mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal)
         mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
@@ -221,7 +221,7 @@ if __name__ == "__main__":
         print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
     else:
         best_latency, best_config, _ = flashattn(
-            batch, heads, seq_len, dim, is_casual, tune=args.tune)
+            batch, heads, seq_len, dim, is_causal, tune=args.tune)
         print(f"Best latency: {best_latency}")
         print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
         print(f"Best config: {best_config}")
...
@@ -8,7 +8,7 @@ from functools import partial
 num_split = 4


-def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_casual, block_M, block_N):
+def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape_q = [batch, seqlen_q, heads, dim]
     shape_kv = [batch, seqlen_kv, heads, dim]
@@ -31,8 +31,8 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_casual, block_M, block_
             T.copy(
                 K[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid +
                   (k + 1) * block_N, hid, :], K_shared)
-            # TODO: Handle casual split case
-            if is_casual:
+            # TODO: Handle causal split case
+            if is_causal:
                 for i, j in T.Parallel(block_M, block_N):
                     acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0,
                                                  -T.infinity(acc_s.dtype))
@@ -129,10 +129,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_casual, block_M, block_
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
-            # TODO: Handle casual split case
+            # TODO: Handle causal split case
             loop_range = (
                 T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv(
-                    (mid + 1) * block_M, block_N)) if is_casual else T.ceildiv(
+                    (mid + 1) * block_M, block_N)) if is_causal else T.ceildiv(
                         (seqlen_kv // num_split), block_N))
             for k in T.Pipelined(loop_range, num_stages=2):
@@ -214,8 +214,8 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_casual, block_M, block_
     return main


-def ref_program(Q, K, V, glse, Output_partial, casual):
-    assert casual is False
+def ref_program(Q, K, V, glse, Output_partial, causal):
+    assert causal is False
     dim = Q.size(-1)
     scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
     scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
@@ -224,7 +224,7 @@ def ref_program(Q, K, V, glse, Output_partial, casual):
     return output


-def reduce_ref(Q, K, V, glse, Output_partial, casual):
+def reduce_ref(Q, K, V, glse, Output_partial, causal):
     o = torch.empty_like(Output_partial[:, :, :, 0, :]).fill_(0)
     lse_logsum = torch.empty_like(glse[:, :, 0, :]).fill_(0)  # [batch, seqlen_q, heads]
     lse_max = glse.max(dim=2, keepdim=False).values
@@ -239,7 +239,7 @@ def reduce_ref(Q, K, V, glse, Output_partial, casual):
     return o.to(torch.float16)


-def flash_split_ref(Q, K, V, casual):
+def flash_split_ref(Q, K, V, causal):
     # [batch, seqlen_q, heads, dim]
     batch = Q.size(0)
     block_M = Q.size(1)
@@ -296,15 +296,15 @@ def flash_split_ref(Q, K, V, casual):
 if __name__ == "__main__":
     BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128
-    casual = False
+    causal = False
     flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
     total_flops = 2 * flops_per_matmul
-    if casual:
+    if causal:
         total_flops *= 0.5
     BLOCK_M = 128
     BLOCK_N = 64  # if D_HEAD <= 128 else 32
-    program = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, casual, BLOCK_M, BLOCK_N)
-    ref_program = partial(ref_program, casual=casual)
+    program = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N)
+    ref_program = partial(ref_program, causal=causal)
     mod, params = tilelang.lower(program)
     mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal)
     mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
...
@@ -8,7 +8,7 @@ from functools import partial
 num_split = 4


-def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_casual, block_M, block_N):
+def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape_q = [batch, seqlen_q, heads, dim]
     shape_kv = [batch, seqlen_kv, heads, dim]
@@ -31,8 +31,8 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_casual, block_M, block_
             T.copy(
                 K[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid +
                   (k + 1) * block_N, hid, :], K_shared)
-            # TODO: Handle casual split case
-            if is_casual:
+            # TODO: Handle causal split case
+            if is_causal:
                 for i, j in T.Parallel(block_M, block_N):
                     acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0,
                                                  -T.infinity(acc_s.dtype))
@@ -128,10 +128,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_casual, block_M, block_
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
-            # TODO: Handle casual split case
+            # TODO: Handle causal split case
             loop_range = (
                 T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv(
-                    (mid + 1) * block_M, block_N)) if is_casual else T.ceildiv(
+                    (mid + 1) * block_M, block_N)) if is_causal else T.ceildiv(
                         (seqlen_kv // num_split), block_N))
             for k in T.Pipelined(loop_range, num_stages=2):
@@ -213,8 +213,8 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_casual, block_M, block_
     return main


-def ref_program(Q, K, V, glse, Output_partial, casual):
-    assert casual is False
+def ref_program(Q, K, V, glse, Output_partial, causal):
+    assert causal is False
     dim = Q.size(-1)
     scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
     scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
@@ -223,7 +223,7 @@ def ref_program(Q, K, V, glse, Output_partial, casual):
     return output


-def reduce_ref(Q, K, V, glse, Output_partial, casual):
+def reduce_ref(Q, K, V, glse, Output_partial, causal):
     o = torch.empty_like(Output_partial[:, :, :, 0, :]).fill_(0)
     lse_logsum = torch.empty_like(glse[:, :, 0, :]).fill_(0)  # [batch, seqlen_q, heads]
     lse_max = glse.max(dim=2, keepdim=False).values
@@ -238,7 +238,7 @@ def reduce_ref(Q, K, V, glse, Output_partial, casual):
     return o.to(torch.float16)


-def flash_split_ref(Q, K, V, casual):
+def flash_split_ref(Q, K, V, causal):
     # [batch, seqlen_q, heads, dim]
     batch = Q.size(0)
     block_M = Q.size(1)
@@ -295,15 +295,15 @@ def flash_split_ref(Q, K, V, casual):
 if __name__ == "__main__":
     BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128
-    casual = False
+    causal = False
     flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
     total_flops = 2 * flops_per_matmul
-    if casual:
+    if causal:
         total_flops *= 0.5
     BLOCK_M = 128
     BLOCK_N = 64  # if D_HEAD <= 128 else 32
-    program = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, casual, BLOCK_M, BLOCK_N)
-    ref_program = partial(ref_program, casual=casual)
+    program = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N)
+    ref_program = partial(ref_program, causal=causal)
     mod, params = tilelang.lower(program)
     mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal)
     mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
...
@@ -34,3 +34,4 @@ thefuzz
 tabulate
 wheel
 setuptools
+einops
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import tilelang.language as T
def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M, block_N,
block_K, block_Dstate, num_stages, threads):
dtype = "float16"
accum_dtype = "float"
nchunks = T.ceildiv(seqlen, chunk_size)
p = 1.44269504  # log2(e); exponentials below are computed via exp2
@T.prim_func
def main(cb: T.Buffer((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), x: T.Buffer(
(batch, seqlen, nheads, headdim), dtype), dt: T.Buffer(
(batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Buffer(
(batch, nheads, nchunks, chunk_size), dtype),
C: T.Buffer((batch, seqlen, ngroups, dstate), dtype), prev_states: T.Buffer(
(batch, nchunks, nheads, headdim, dstate), dtype), D: T.Buffer(
(nheads), dtype), Output: T.Buffer((batch, seqlen, nheads, headdim), dtype)):
with T.Kernel(
nheads,
T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
batch * nchunks,
threads=threads) as (bz, bx, by):
acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")
cb_local = T.alloc_fragment((block_M, block_K), dtype)
dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared")
dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype)
dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype)
dt_shared = T.alloc_shared((block_K), dtype, scope="shared")
dt_local = T.alloc_fragment((block_K), accum_dtype)
x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn")
dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared")
scale_m_local = T.alloc_fragment((block_M), accum_dtype)
C_shared = T.alloc_shared((block_M, block_Dstate), dtype)
prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype)
D_local = T.alloc_fragment((1), accum_dtype)
x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn")
x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype)
batch_idx = by % batch
chunk_idx = by // batch
# m: chunk_size
# n: headdim
m_idx = bx // T.ceildiv(headdim, block_N)
n_idx = bx % T.ceildiv(headdim, block_N)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M],
dA_cs_m_shared)
T.copy(dA_cs_m_shared, dA_cs_m_local)
T.clear(acc_o)
for i in T.Parallel(block_M):
scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)
T.copy(
C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared)
T.copy(
prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N,
0:block_Dstate], prev_state_shared)
T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] *= scale_m_local[i]
loop_range = T.ceildiv((m_idx + 1) * block_M, block_K)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
cb[batch_idx, chunk_idx, bz // (nheads // ngroups),
m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K],
cb_shared)
T.copy(cb_shared, cb_local)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
dA_cs_k_shared)
T.copy(dA_cs_k_shared, dA_cs_k_local)
for i, j in T.Parallel(block_M, block_K):
# apply the segment decay exp2((A_m - A_k) * p), i.e. exp(A_m - A_k)
cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
T.copy(dt_shared, dt_local)
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] *= dt_local[j]
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j,
cb_local[i, j], 0)
T.copy(
x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared)
T.gemm(cb_local, x_shared, acc_o)
D_local[0] = D[bz]
T.copy(
x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N],
x_residual_shared)
T.copy(x_residual_shared, x_residual_local)
for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] += x_residual_local[i, j] * D_local[0]
T.copy(
acc_o,
Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N])
return main
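# In equation form, per head and per chunk, `main` above computes (a restatement
# of the ref_program below; A denotes dA_cumsum, and exp2(... * p) is exp in base 2):
#
#   out[l, p] = sum_{s <= l} cb[l, s] * exp(A[l] - A[s]) * dt[s] * x[s, p]
#             + exp(A[l]) * sum_n C[l, n] * prev_states[p, n]
#             + D * x[l, p]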
def run_chunk_scan(batch,
seqlen,
chunk_size,
ngroups,
nheads,
headdim,
dstate,
block_M,
block_N,
block_K,
block_Dstate,
num_stages=2,
threads=128):
program = chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M,
block_N, block_K, block_Dstate, num_stages, threads)
mod, params = tl.lower(program)
mod = tl.Profiler(mod, params, [7], tl.TensorSupplyType.Integer)
def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
import torch
from einops import rearrange, repeat
"""
Argument:
cb: (batch, nchunks, ngroups, chunk_size, chunk_size)
x: (batch, seqlen, nheads, headdim)
dt: (batch, nheads, nchunks, chunk_size)
dA_cumsum: (batch, nheads, nchunks, chunk_size)
C: (batch, seqlen, ngroups, dstate)
prev_states: (batch, nchunks, nheads, headdim, dstate)
D: (nheads, headdim) or (nheads,)
z: (batch, seqlen, nheads, headdim)
Return:
out: (batch, seqlen, nheads, headdim)
"""
_, _, ngroups, _, _ = cb.shape
batch, seqlen, nheads, headdim = x.shape
# _, _, ngroups, dstate = B.shape
# assert B.shape == (batch, seqlen, ngroups, dstate)
_, _, nchunks, chunk_size = dt.shape
assert seqlen == nchunks * chunk_size
# assert C.shape == B.shape
# B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups)
cb = repeat(cb, "b c g l s -> b c (g h) l s", h=nheads // ngroups)
# CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
# rearrange(B, "b (c s) h n -> b c s h n", c=nchunks))
# (batch, nheads, nchunks, chunksize, chunksize)
dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
decay = torch.exp(dt_segment_sum)
scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s")
causal_mask = torch.tril(
torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
scores_decay = scores_decay.masked_fill(~causal_mask, 0)
out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
out_prev = torch.einsum('bclhn,bchpn->bclhp',
rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
prev_states.to(C.dtype)) * state_decay_out
out = out + out_prev
out = rearrange(out, "b c l h p -> b (c l) h p")
if D is not None:
if D.dim() == 1:
D = rearrange(D, "h -> h 1")
out = out + x * D
return out
mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def chunk_state_fwd(batch,
seqlen,
chunk_size,
ngroups,
nheads,
headdim,
dstate,
block_M,
block_N,
block_K,
num_stages=2,
threads=128):
dtype = "float16"
accum_dtype = "float"
nchunks = T.ceildiv(seqlen, chunk_size)
p = 1.44269504  # log2(e)
@T.prim_func
def main(B: T.Buffer((batch, seqlen, ngroups, dstate), dtype), x: T.Buffer(
(batch, seqlen, nheads, headdim), dtype), dt: T.Buffer(
(batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Buffer(
(batch, nheads, nchunks, chunk_size), dtype), Output: T.Buffer(
(batch, nchunks, nheads, headdim, dstate), dtype)):
with T.Kernel(
nheads,
T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N),
batch * nchunks,
threads=threads) as (bz, bx, by):
x_shared = T.alloc_shared((block_K, block_M), dtype)
x_local = T.alloc_fragment((block_K, block_M), dtype)
xt_local = T.alloc_fragment((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
dt_shared = T.alloc_shared((block_K), dtype)
dA_cumsum_shared = T.alloc_shared((block_K), dtype)
acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
scale = T.alloc_fragment((block_K), accum_dtype)
dA_cs_last = T.alloc_fragment((1), accum_dtype)
dA_cumsum_local = T.alloc_fragment((block_K), accum_dtype)
dt_local = T.alloc_fragment((block_K), accum_dtype)
loop_range = T.ceildiv(chunk_size, block_K)
batch_idx = by % batch
chunk_idx = by // batch
m_idx = bx // T.ceildiv(dstate, block_N)
n_idx = bx % T.ceildiv(dstate, block_N)
dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1]
T.clear(acc_o)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz, m_idx * block_M:(m_idx + 1) * block_M], x_shared)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
dA_cumsum_shared)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
T.copy(dA_cumsum_shared, dA_cumsum_local)
T.copy(dt_shared, dt_local)
for i in T.Parallel(block_K):
scale[i] = T.exp2(dA_cs_last[0] * p - dA_cumsum_local[i] * p) * dt_local[i]
T.copy(x_shared, x_local)
for i, j in T.Parallel(block_M, block_K):
xt_local[i, j] = x_local[j, i] * scale[j]
T.copy(
B[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz // (nheads // ngroups),
n_idx * block_N:(n_idx + 1) * block_N], B_shared)
T.gemm(xt_local, B_shared, acc_o)
T.copy(
acc_o, Output[batch_idx, chunk_idx, bz, m_idx * block_M:(m_idx + 1) * block_M,
n_idx * block_N:(n_idx + 1) * block_N])
return main
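# Equivalently, `main` above accumulates the per-chunk state update (a restatement
# of the ref_program below; A_last is dA_cumsum at the last position of the chunk):
#
#   states[p, n] = sum_l exp(A_last - A[l]) * dt[l] * x[l, p] * B[l, n]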
def run_chunk_state(batch,
seqlen,
chunk_size,
ngroups,
nheads,
headdim,
dstate,
block_M,
block_N,
block_K,
num_stages=2,
threads=128):
program = chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M,
block_N, block_K, num_stages, threads)
mod, params = tl.lower(program)
mod = tl.Profiler(mod, params, [4], tl.TensorSupplyType.Integer)
def ref_program(B, x, dt, dA_cumsum):
"""
Argument:
B: (batch, seqlen, ngroups, dstate)
x: (batch, seqlen, nheads, headdim)
dt: (batch, nheads, nchunks, chunk_size)
dA_cumsum: (batch, nheads, nchunks, chunk_size)
Return:
states: (batch, nchunks, nheads, headdim, dstate)
"""
# Check constraints.
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
batch, seqlen, nheads, headdim = x.shape
dstate = B.shape[-1]
_, _, nchunks, chunk_size = dt.shape
assert seqlen <= nchunks * chunk_size
assert x.shape == (batch, seqlen, nheads, headdim)
assert dt.shape == (batch, nheads, nchunks, chunk_size)
ngroups = B.shape[2]
assert nheads % ngroups == 0
assert B.shape == (batch, seqlen, ngroups, dstate)
B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
if seqlen < nchunks * chunk_size:
x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype),
dt.to(x.dtype), x)
mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_chunk_scan():
run_chunk_scan(
batch=8,
seqlen=2048,
chunk_size=256,
ngroups=1,
nheads=8,
headdim=64,
dstate=128,
block_M=64,
block_N=64,
block_K=64,
block_Dstate=128,
num_stages=2,
threads=128)
def test_chunk_state():
run_chunk_state(
batch=8,
seqlen=2048,
chunk_size=256,
ngroups=1,
nheads=8,
headdim=64,
dstate=128,
block_M=64,
block_N=64,
block_K=64,
num_stages=2,
threads=128)
if __name__ == "__main__":
tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import tilelang.language as T
def flashattn(batch, heads, seq_len, dim, is_causal, block_M, block_N, num_stages, threads):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@T.macro
def MMA0(
K: T.Buffer(shape, dtype),
Q_shared: T.Buffer([block_M, dim], dtype),
K_shared: T.Buffer([block_N, dim], dtype),
acc_s: T.Buffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Buffer(shape, dtype),
V_shared: T.Buffer([block_M, dim], dtype),
acc_s_cast: T.Buffer([block_M, block_N], dtype),
acc_o: T.Buffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.Buffer([block_M, block_N], accum_dtype),
acc_s_cast: T.Buffer([block_M, block_N], dtype),
scores_max: T.Buffer([block_M], accum_dtype),
scores_max_prev: T.Buffer([block_M], accum_dtype),
scores_scale: T.Buffer([block_M], accum_dtype),
scores_sum: T.Buffer([block_M], accum_dtype),
logsum: T.Buffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# For causal softmax, we need to set scores_max to 0 if it is -inf.
# This step is called Check_inf in the FlashAttention-3 code, and it only needs
# to be done in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
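# Identity used: exp(x - m) = exp2(x * log2(e) - m * log2(e)); the log2(e)
# factor is already folded into `scale` (scale = (1.0 / dim)**0.5 * 1.44269504).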
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.Buffer([block_M, dim], accum_dtype),
scores_scale: T.Buffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Buffer(shape, dtype),
K: T.Buffer(shape, dtype),
V: T.Buffer(shape, dtype),
Output: T.Buffer(shape, dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return main
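# For reference, the online-softmax bookkeeping implemented by the Softmax and
# Rescale macros can be sketched in plain NumPy. This is an illustrative,
# hypothetical helper (not used by the tests); the kernel computes the same
# recurrence in base 2 with log2(e) folded into `scale`.
import numpy as np

def _online_softmax_step(acc_o, logsum, m_prev, s_block, v_block):
    """One K/V-block step of the streaming softmax.

    s_block: attention scores for this block, shape [block_M, block_N]
    v_block: V tile for this block, shape [block_N, dim]
    acc_o: running unnormalized output, shape [block_M, dim]
    logsum: running softmax denominator, shape [block_M]
    m_prev: running row maximum, shape [block_M]
    """
    m_new = np.maximum(m_prev, s_block.max(axis=1))  # reduce_max with clear=False
    scale = np.exp(m_prev - m_new)                   # scores_scale
    p = np.exp(s_block - m_new[:, None])             # re-exponentiated scores
    logsum = logsum * scale + p.sum(axis=1)          # Softmax macro
    acc_o = acc_o * scale[:, None] + p @ v_block     # Rescale + MMA1
    return acc_o, logsum, m_new

# After the final block, the normalized output is acc_o / logsum[:, None].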
def run_mha(batch, heads, seq_len, dim, is_causal, block_M, block_N, num_stages=2, threads=128):
program = flashattn(batch, heads, seq_len, dim, is_causal, block_M, block_N, num_stages,
threads)
mod, params = tl.lower(program)
mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Integer)
def ref_program(Q, K, V):
import torch
import torch.nn.functional as F
dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_mha_causal_dim64():
run_mha(
batch=4,
heads=8,
seq_len=8192,
dim=64,
is_causal=True,
block_M=64,
block_N=64,
num_stages=2,
threads=128)
def test_mha_no_causal_dim64():
run_mha(
batch=4,
heads=8,
seq_len=8192,
dim=64,
is_causal=False,
block_M=64,
block_N=64,
num_stages=2,
threads=128)
# def test_mha_causal_dim128():
# run_mha(
# batch=4,
# heads=8,
# seq_len=8192,
# dim=128,
# is_causal=True,
# block_M=64,
# block_N=64,
# num_stages=1,
# threads=128)
# def test_mha_no_causal_dim128():
# run_mha(
# batch=4,
# heads=8,
# seq_len=8192,
# dim=128,
# is_causal=False,
# block_M=64,
# block_N=64,
# num_stages=1,
# threads=128)
def test_mha_causal_dim256():
run_mha(
batch=4,
heads=8,
seq_len=8192,
dim=256,
is_causal=True,
block_M=64,
block_N=64,
num_stages=1,
threads=128)
def test_mha_no_causal_dim256():
run_mha(
batch=4,
heads=8,
seq_len=8192,
dim=256,
is_causal=False,
block_M=64,
block_N=64,
num_stages=1,
threads=128)
if __name__ == "__main__":
tilelang.testing.main()