"docs/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "a1f926661689b6db3bba72f7cd381aac118e5c0a"
Commit 0fd3a3e8 authored by Tong WU, committed by LeiWang1999

[Dev] Update linear attention examples to enhance performance on Hopper GPUs (#621)

* Tune linear attention examples on H100

* Add retnet fwd kernel

* fix lint
parent 67b81609
@@ -7,7 +7,12 @@ import argparse
 from fla.ops.linear_attn import fused_chunk_linear_attn  # We compare with FLA
 
 
-@tl.jit(out_idx=[4, 5, 6])
+@tl.jit(
+    out_idx=[4, 5, 6],
+    pass_configs={
+        "tl.disable_tma_lower": True,
+        "tl.disable_warp_specialized": True
+    })
 def chunk_linear_attn_bwd_kernel(
     B,
     S,
@@ -23,21 +28,21 @@ def chunk_linear_attn_bwd_kernel(
     accum_dtype = 'float'
     chunk_size = 64
-    BK = BV = 64
+    BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA
     assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
     NK = tl.cdiv(DK, BK)
     NV = tl.cdiv(DV, BV)
     NT = tl.cdiv(S, chunk_size)
 
     @T.prim_func
-    def main(
-            Q: T.Tensor([B, S, H, DK], dtype),
-            K: T.Tensor([B, S, H, DK], dtype),
-            V: T.Tensor([B, S, H, DV], dtype),
-            dO: T.Tensor([B, S, H, DV], dtype),
-            dQ: T.Tensor([NV, B, S, H, DK], dtype),
-            dK: T.Tensor([NV, B, S, H, DK], dtype),
-            dV: T.Tensor([NK, B, S, H, DV], dtype),
+    def chunk_linear_attn_bwd(
+            Q: T.Tensor([B, S, H, DK], dtype),  # type: ignore
+            K: T.Tensor([B, S, H, DK], dtype),  # type: ignore
+            V: T.Tensor([B, S, H, DV], dtype),  # type: ignore
+            dO: T.Tensor([B, S, H, DV], dtype),  # type: ignore
+            dQ: T.Tensor([NV, B, S, H, DK], dtype),  # type: ignore
+            dK: T.Tensor([NV, B, S, H, DK], dtype),  # type: ignore
+            dV: T.Tensor([NK, B, S, H, DV], dtype),  # type: ignore
     ):
         with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
             i_b = i_bh // H
@@ -68,6 +73,7 @@ def chunk_linear_attn_bwd_kernel(
                 h_shared: tl.layout.make_swizzled_layout(h_shared),
                 dh_shared: tl.layout.make_swizzled_layout(dh_shared)
             })
+            T.use_swizzle(10)
 
             # Calculate dQ
             for i in T.Pipelined(0, NT, num_stages=1):
@@ -104,7 +110,6 @@ def chunk_linear_attn_bwd_kernel(
                 T.copy(
                     dO[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
                        i_v * BV:(i_v + 1) * BV], do)
-                T.copy(dh, dh_shared)
 
                 # Calculate dk
                 T.gemm(
@@ -113,6 +118,7 @@ def chunk_linear_attn_bwd_kernel(
                 for row, col in T.Parallel(chunk_size, chunk_size):
                     ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0)
                 T.gemm(ds_shared, q, dk, clear_accum=True)
+                T.copy(dh, dh_shared)
                 T.gemm(v, dh_shared, dk, transpose_B=True)
 
                 # Calculate dv
@@ -132,7 +138,7 @@ def chunk_linear_attn_bwd_kernel(
                     dv, dV[i_k, i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
                            i_v * BV:(i_v + 1) * BV])
 
-    return main
+    return chunk_linear_attn_bwd
 
 
 def postprocess(dQ, dK, dV):
@@ -145,8 +151,8 @@ def postprocess(dQ, dK, dV):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--B', type=int, default=8, help='Batch size')
-    parser.add_argument('--S', type=int, default=2048, help='Seq len')
-    parser.add_argument('--H', type=int, default=64, help='Num heads')
+    parser.add_argument('--S', type=int, default=4096, help='Seq len')
+    parser.add_argument('--H', type=int, default=32, help='Num heads')
     parser.add_argument('--D', type=int, default=256, help='Head dim')
     args = parser.parse_args()
     B, S, H, D = args.B, args.S, args.H, args.D
@@ -158,7 +164,7 @@ def main():
     kernel = chunk_linear_attn_bwd_kernel(B, S, H, D, D)
     dq, dk, dv = postprocess(*kernel(q, k, v, do))
 
-    o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
+    o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
     o_ref.backward(do, retain_graph=True)
     if torch.allclose(dq, q.grad) and torch.allclose(dk, k.grad) and torch.allclose(dv, v.grad):
         print('Passed all tests!✅')
@@ -166,7 +172,7 @@ def main():
         print('Failed some tests!❌')
     t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), warmup=25, rep=100)
     q.grad = k.grad = v.grad = None
-    o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
+    o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
     t2 = do_bench(lambda: postprocess(*kernel(q, k, v, do)), warmup=25, rep=100)
     print(f'Triton latency: {t1:.3f} ms')
     print(f'TileLang latency: {t2:.3f} ms')
...
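The hunks above elide the body of postprocess(dQ, dK, dV). Since the kernel writes split partial gradients (dQ and dK carry a leading NV dimension, dV a leading NK dimension), a minimal sketch of such a reduction, shown here as an assumption rather than the file's actual code, would be:

# Hypothetical reduction of the split partial gradients (illustration, not the committed body).
# dQ, dK: [NV, B, S, H, DK]; dV: [NK, B, S, H, DV] -> sum over the leading split dimension.
def postprocess_sketch(dQ, dK, dV):
    return dQ.sum(0), dK.sum(0), dV.sum(0)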
@@ -7,7 +7,12 @@ import argparse
 from fla.ops.linear_attn import fused_chunk_linear_attn  # We compare with FLA
 
 
-@tl.jit(out_idx=[3, 4])
+@tl.jit(
+    out_idx=[3, 4],
+    pass_configs={
+        "tl.disable_tma_lower": True,
+        "tl.disable_warp_specialized": True
+    })
 def chunk_linear_attn_fwd_kernel(
     B,
     S,
@@ -23,16 +28,19 @@ def chunk_linear_attn_fwd_kernel(
     accum_dtype = 'float'
     chunk_size = 64
-    BK = BV = 64
+    BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA
     assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
     NK = tl.cdiv(DK, BK)
     NV = tl.cdiv(DV, BV)
     NT = tl.cdiv(S, chunk_size)
 
     @T.prim_func
-    def main(Q: T.Tensor([B, S, H, DK], dtype), K: T.Tensor([B, S, H, DK], dtype),
-             V: T.Tensor([B, S, H, DV], dtype), O: T.Tensor([NK, B, S, H, DV], dtype),
-             final_state: T.Tensor([B, H, DK, DV], accum_dtype)):
+    def chunk_linear_attn_fwd(
+            Q: T.Tensor([B, S, H, DK], dtype),  # type: ignore
+            K: T.Tensor([B, S, H, DK], dtype),  # type: ignore
+            V: T.Tensor([B, S, H, DV], dtype),  # type: ignore
+            O: T.Tensor([NK, B, S, H, DV], dtype),  # type: ignore
+            final_state: T.Tensor([B, H, DK, DV], accum_dtype)):  # type: ignore
         with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
             i_b = i_bh // H
             i_h = i_bh % H
@@ -54,9 +62,9 @@ def chunk_linear_attn_fwd_kernel(
                 h_shared: tl.layout.make_swizzled_layout(h_shared),
                 s_shared: tl.layout.make_swizzled_layout(s_shared),
             })
-            T.use_swizzle(8)
+            T.use_swizzle(10)
 
-            for i in T.Pipelined(0, NT, num_stages=1):
+            for i in T.Pipelined(0, NT, num_stages=2):
                 for row, col in T.Parallel(chunk_size, BK):
                     q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale
                 T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
@@ -68,8 +76,8 @@ def chunk_linear_attn_fwd_kernel(
                 T.gemm(s_shared, v, o, clear_accum=True)
                 T.copy(h, h_shared)
-                T.gemm(q, h_shared, o)
                 T.gemm(k, v, h, transpose_A=True)
+                T.gemm(q, h_shared, o)
                 T.copy(
                     o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h,
                          i_v * BV:(i_v + 1) * BV])
@@ -77,7 +85,7 @@ def chunk_linear_attn_fwd_kernel(
             # Output final state
             T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV])
 
-    return main
+    return chunk_linear_attn_fwd
 
 
 def postprocess(o, h):
@@ -88,8 +96,8 @@ def postprocess(o, h):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--B', type=int, default=8, help='Batch size')
-    parser.add_argument('--S', type=int, default=2048, help='Seq len')
-    parser.add_argument('--H', type=int, default=64, help='Num heads')
+    parser.add_argument('--S', type=int, default=4096, help='Seq len')
+    parser.add_argument('--H', type=int, default=32, help='Num heads')
     parser.add_argument('--D', type=int, default=256, help='Head dim')
     args = parser.parse_args()
     B, S, H, D = args.B, args.S, args.H, args.D
@@ -111,7 +119,7 @@ def main():
         lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)[0],
         warmup=25,
         rep=100)
-    t2 = do_bench(lambda: kernel(q, k, v)[0].sum(0), warmup=25, rep=100)
+    t2 = do_bench(lambda: postprocess(*kernel(q, k, v)), warmup=25, rep=100)
     print(f'Triton latency: {t1:.3f} ms')
     print(f'TileLang latency: {t2:.3f} ms')
     print(f'Speedup: {t1/t2:.3f}x')
...
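The forward example's correctness check falls outside the hunks shown above. The sketch below (an assumption mirroring the backward example, not the file's exact code) shows how the split-K partial outputs returned by the kernel could be reduced and compared against FLA's fused_chunk_linear_attn:

# Illustrative correctness check for the forward kernel (assumed, not the committed code).
import torch
from fla.ops.linear_attn import fused_chunk_linear_attn

# Benchmark shapes taken from the argparse defaults above.
B, S, H, D = 8, 4096, 32, 256
q, k, v = (torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) for _ in range(3))

kernel = chunk_linear_attn_fwd_kernel(B, S, H, D, D)
o_partial, final_state = kernel(q, k, v)   # O: [NK, B, S, H, DV], final_state: [B, H, DK, DV]
o = o_partial.sum(0)                       # reduce the split-K partial outputs
o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
print('Forward outputs match:', torch.allclose(o, o_ref, rtol=1e-2, atol=1e-2))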
import torch
import tilelang as tl
import tilelang.language as T
from tilelang.profiler import do_bench
import argparse
@tl.jit(out_idx=3, pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def chunk_retention_fwd_kernel(
B,
S,
H,
DK,
DV,
dtype: str = 'float16',
scale: float = None,
) -> torch.Tensor:
if scale is None:
scale = DK**-0.5
accum_dtype = 'float'
chunk_size = 64
BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA
assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
NK = tl.cdiv(DK, BK)
NV = tl.cdiv(DV, BV)
NT = tl.cdiv(S, chunk_size)
@T.prim_func
def chunk_retention_fwd(
Q: T.Tensor([B, S, H, DK], dtype), # type: ignore
K: T.Tensor([B, S, H, DK], dtype), # type: ignore
V: T.Tensor([B, S, H, DV], dtype), # type: ignore
O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore
):
with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
i_b = i_bh // H
i_h = i_bh % H
log_decay = T.alloc_var('float32')
log_decay = T.log2(1 - T.exp2(-5. - 1. * i_h)) # Head-specific log decay
q = T.alloc_shared([chunk_size, BK], dtype)
k = T.alloc_shared([chunk_size, BK], dtype)
v = T.alloc_shared([chunk_size, BV], dtype)
h = T.alloc_fragment([BK, BV], accum_dtype)
h_shared = T.alloc_shared([BK, BV], dtype)
s = T.alloc_fragment([chunk_size, chunk_size], accum_dtype)
s_shared = T.alloc_shared([chunk_size, chunk_size], dtype)
o = T.alloc_fragment([chunk_size, BV], accum_dtype)
T.clear(h)
T.annotate_layout({
q: tl.layout.make_swizzled_layout(q),
k: tl.layout.make_swizzled_layout(k),
v: tl.layout.make_swizzled_layout(v),
h_shared: tl.layout.make_swizzled_layout(h_shared),
s_shared: tl.layout.make_swizzled_layout(s_shared),
})
T.use_swizzle(10)
for i in T.Pipelined(0, NT):
for row, col in T.Parallel(chunk_size, BK):
q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale
T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)
T.gemm(q, k, s, clear_accum=True, transpose_B=True)
for row, col in T.Parallel(chunk_size, chunk_size):
s_shared[row,
col] = T.if_then_else(row >= col, s[row, col] * T.exp2(
(row - col) * log_decay), 0)
T.copy(h, h_shared)
T.gemm(q, h_shared, o, clear_accum=True)
for row, col in T.Parallel(chunk_size, BV):
o[row, col] = T.exp2((row + 1) * log_decay) * o[row, col]
T.gemm(s_shared, v, o)
for row, col in T.Parallel(chunk_size, BV):
v[row, col] = v[row, col] * T.exp2((chunk_size - row - 1) * log_decay)
for row, col in T.Parallel(BK, BV):
h[row, col] = T.exp2(chunk_size * log_decay) * h[row, col]
T.copy(
o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h,
i_v * BV:(i_v + 1) * BV])
T.gemm(k, v, h, transpose_A=True)
return chunk_retention_fwd
def postprocess(o):
return o if o.size(0) == 1 else o.sum(0)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--B', type=int, default=8, help='Batch size')
parser.add_argument('--S', type=int, default=4096, help='Seq len')
parser.add_argument('--H', type=int, default=32, help='Num heads')
parser.add_argument('--D', type=int, default=128, help='Head dim')
args = parser.parse_args()
B, S, H, D = args.B, args.S, args.H, args.D
total_flops = 2.0 * B * S * S * H * D # causal
q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
kernel = chunk_retention_fwd_kernel(B, S, H, D, D)
t = do_bench(lambda: postprocess(kernel(q, k, v)), warmup=25, rep=100)
print(f'Tilelang latency: {t:.3f} ms')
print(f'Tilelang TFLOPs: {total_flops/t * 1e-9}')
if __name__ == '__main__':
main()
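The head-specific decay in the kernel above, log_decay = log2(1 - 2^(-5 - i_h)), corresponds to RetNet's retention rate gamma_h = 1 - 2^(-5 - h). A naive PyTorch reference for the same forward computation, useful for checking small shapes (an editorial sketch, not part of this commit):

import torch

def naive_retention_fwd(q, k, v):
    # q, k: [B, S, H, DK]; v: [B, S, H, DV]. Computes, per head h,
    # O[t] = sum_{s <= t} gamma_h**(t - s) * (q_t . k_s) * v_s, with q scaled by DK**-0.5.
    B, S, H, DK = q.shape
    scale = DK ** -0.5
    gamma = 1 - torch.exp2(-5.0 - torch.arange(H, dtype=torch.float32, device=q.device))
    t = torch.arange(S, dtype=torch.float32, device=q.device)
    rel = t[:, None] - t[None, :]                                    # t - s
    decay = (gamma.view(H, 1, 1) ** rel.clamp(min=0)) * (rel >= 0)   # [H, S, S], causal decay
    scores = torch.einsum('bqhd,bkhd->bhqk', q.float() * scale, k.float())
    return torch.einsum('bhqk,bkhd->bqhd', scores * decay, v.float())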
import argparse
import torch
import tilelang
import tilelang.language as T
@tilelang.jit(out_idx=[4])
def retnet(batch, heads, seq_len, dim_qk, dim_v, block_M, block_N):
qk_shape = [batch, seq_len, heads, dim_qk]
v_shape = [batch, seq_len, heads, dim_v]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def main(
Q: T.Tensor(qk_shape, dtype),
K: T.Tensor(qk_shape, dtype),
V: T.Tensor(v_shape, dtype),
mask: T.Tensor([heads, seq_len, seq_len], dtype),
Output: T.Tensor(v_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128 * 2) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
K_shared = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
mask_shared = T.alloc_shared([block_M, block_N], dtype)
acc_o_shared = T.alloc_shared([block_M, dim_v], dtype)
mask_local = T.alloc_fragment([block_M, block_N], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_1 = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_shared = T.alloc_shared([block_M, block_N], dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
abs_sum = T.alloc_fragment([block_M], accum_dtype)
r_wo_clamp = T.alloc_fragment([block_M], accum_dtype)
r = T.alloc_fragment([block_M], accum_dtype)
r_new = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
mask_shared: tilelang.layout.make_swizzled_layout(mask_shared),
acc_s_shared: tilelang.layout.make_swizzled_layout(acc_s_shared),
acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared)
})
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(r, 0)
T.fill(r_new, 0)
T.fill(r_wo_clamp, 0)
T.fill(acc_o, 0)
loop_range = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.copy(mask[by, bx * block_M:(bx + 1) * block_M, k * block_N:(k + 1) * block_N],
mask_shared)
T.copy(mask_shared, mask_local)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = acc_s[i, j] * mask_local[i, j]
T.copy(acc_s, acc_s_shared)
T.copy(acc_s_shared, acc_s_1)
T.reduce_abssum(acc_s_1, abs_sum, dim=1)
for i in T.Parallel(block_M):
r_wo_clamp[i] = r_wo_clamp[i] + abs_sum[i]
for i in T.Parallel(block_M):
r_new[i] = T.max(r_wo_clamp[i], 1)
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] = T.if_then_else(k > 0, acc_o[i, j] * r[i] / r_new[i], acc_o[i, j])
T.copy(r_new, r)
for i, j in T.Parallel(block_M, block_N):
acc_s_1[i, j] = acc_s_1[i, j] / r_new[i]
T.copy(acc_s_1, acc_s_cast)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
T.copy(acc_o, acc_o_shared)
T.copy(acc_o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return main
def ref_program(Q, K, V, mask):
qk = torch.einsum('bqhd,bkhd->bhqk', Q, K)
qkm = qk * mask
r = qkm.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1.0)
o = torch.einsum('bhqk,bkhd->bqhd', qkm / r, V)
return o.to(dtype=torch.float16)
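# Editorial note (not part of the commit): in a full RetNet layer the `mask` passed to
# ref_program and to the kernel would typically be the causal decay matrix
# mask[h, q, k] = gamma_h ** (q - k) for q >= k (else 0), with gamma_h = 1 - 2**(-5 - h);
# the benchmark below simply feeds random values for it. A hypothetical helper to build it:
def build_decay_mask(heads, seq_len, device='cuda', dtype=torch.float16):
    gamma = 1 - torch.exp2(-5.0 - torch.arange(heads, dtype=torch.float32, device=device))
    idx = torch.arange(seq_len, dtype=torch.float32, device=device)
    rel = idx[:, None] - idx[None, :]                                # q - k
    return ((gamma.view(heads, 1, 1) ** rel.clamp(min=0)) * (rel >= 0)).to(dtype)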
def ref_inference(Q, K, V, prev_kv, prev_scale, decay):
# Q : batch, seqlen, num_heads, head_dimqk
# K : batch, seqlen, num_heads, head_dimqk
# V : batch, seqlen, num_heads, head_dimv
# prev_kv : batch, num_heads, head_dimv, head_dimqk
# prev_scale : num_heads, 1, 1
# decay : num_heads, 1, 1
seqlen = V.size(1)
num_heads = V.size(2)
assert seqlen == 1, "Only support seqlen == 1"
qr = Q.transpose(1, 2).contiguous() # batch, num_heads, 1, head_dimqk
kr = K.transpose(1, 2).contiguous() # batch, num_heads, 1, head_dimqk
v = V.transpose(1, 2).transpose(2, 3).contiguous() # batch, num_heads, head_dimv, 1
kv = kr * v # batch, num_heads, head_dimv, head_dimqk
scale = prev_scale * decay + 1 # num_heads, 1, 1
kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(
num_heads, 1, 1) + kv / scale.sqrt().view(num_heads, 1, 1)
output = torch.sum(qr * kv, dim=3)
return output
def retnet_inference(batch, heads, dim_qk, dim_v, block_M):
qk_shape = [batch, 1, heads, dim_qk]
v_shape = [batch, 1, heads, dim_v]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def main(
Q: T.Tensor(qk_shape, dtype),
K: T.Tensor(qk_shape, dtype),
V: T.Tensor(v_shape, dtype),
prev_kv: T.Tensor([batch, heads, dim_v, dim_qk], dtype),
prev_scale: T.Tensor([heads], dtype),
decay: T.Tensor([heads], dtype),
Output: T.Tensor([batch, heads, dim_v], dtype),
):
with T.Kernel(T.ceildiv(dim_v, block_M), heads, batch, threads=128) as (bx, by, bz):
Q_local = T.alloc_fragment([1, dim_qk], dtype)
K_local = T.alloc_fragment([dim_qk], dtype)
V_local = T.alloc_fragment([block_M], dtype)
kv_local = T.alloc_fragment([block_M, dim_qk], accum_dtype)
prev_kv_local = T.alloc_fragment([block_M, dim_qk], dtype)
prev_scale_local = T.alloc_fragment([1], dtype)
decay_local = T.alloc_fragment([1], accum_dtype)
# scale_local = T.alloc_fragment([1], accum_dtype)
qkv_local = T.alloc_fragment([block_M, dim_qk], accum_dtype)
o_local = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({
prev_scale_local: T.Layout(prev_scale_local.shape, lambda i: i),
decay_local: T.Layout(decay_local.shape, lambda i: i),
# scale_local: T.Layout(scale_local.shape, lambda i : i),
kv_local: T.Fragment(kv_local.shape, lambda i, j: j // 8),
})
T.copy(Q[bz, 0, by, :], Q_local)
T.copy(K[bz, 0, by, :], K_local)
T.copy(V[bz, 0, by, bx * block_M:(bx + 1) * block_M], V_local)
T.copy(prev_kv[bz, by, bx * block_M:(bx + 1) * block_M, :], prev_kv_local)
prev_scale_local[0] = prev_scale[by]
decay_local[0] = decay[by]
for i, j in T.Parallel(block_M, dim_qk):
kv_local[i, j] = K_local[j] * V_local[i]
for i, j in T.Parallel(block_M, dim_qk):
kv_local[i, j] += kv_local[i, j]
for i, j in T.Parallel(block_M, dim_qk):
kv_local[i, j] += prev_kv_local[i, j] * T.sqrt(prev_scale[by]) * decay[by]
for i, j in T.Parallel(block_M, dim_qk):
kv_local[i, j] = kv_local[i, j] / T.sqrt(prev_scale[by] * decay[by] + 1)
for i, j in T.Parallel(block_M, dim_qk):
qkv_local[i, j] = Q_local[0, j] * kv_local[i, j]
T.reduce_sum(qkv_local, o_local, dim=1)
T.copy(o_local, Output[bz, by, bx * block_M:(bx + 1) * block_M])
return main
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='Batch size')
parser.add_argument('--h', type=int, default=10, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=4096, help='Context size')
parser.add_argument('--dim_qk', type=int, default=256, help='Head dimension')
parser.add_argument('--dim_v', type=int, default=448, help='Head dimension')
args = parser.parse_args()
BATCH, H, N_CTX, dim_qk, dim_v = args.batch, args.h, args.n_ctx, args.dim_qk, args.dim_v
total_flops = 2.0 * BATCH * H * N_CTX * N_CTX * (dim_qk + dim_v)
BLOCK_M = 64
BLOCK_N = 64
kernel = retnet(BATCH, H, N_CTX, dim_qk, dim_v, BLOCK_M, BLOCK_N)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
ins = profiler._get_inputs()
ref_outs = ref_program(*ins)
lib_outs = kernel(*ins)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
latency = profiler.do_bench(n_warmup=10, n_repeat=10, profiler="torch")
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))