Commit 2add9fa3 authored by wangkx1

add tilelang

parent f5bc26c2
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = T.float16
accum_dtype = T.float32
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
# Fill -inf for out-of-bounds key positions so they do not contribute to the softmax
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in the FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
)
for k in T.Pipelined(
loop_range,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
return main
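# Added sketch (not part of the original example): a tiny numerical check of the exp2 rewrite
# used in the Softmax macro above, i.e. exp(x - m) == 2 ** (x * log2(e) - m * log2(e)).
def _check_exp2_identity(x: float = 1.5, m: float = 2.0) -> bool:
    import math
    log2e = 1.44269504
    return math.isclose(math.exp(x - m), 2.0 ** (x * log2e - m * log2e), rel_tol=1e-6)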
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output
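# Added sketch (assumption, not in the original file): the same reference can be expressed with
# PyTorch's fused SDPA. Tensors here are laid out as [batch, seq, heads, dim]; SDPA expects
# [batch, heads, seq, dim], hence the transposes. Default SDPA scaling (1/sqrt(dim)) matches
# ref_program above.
def ref_program_sdpa(Q, K, V, is_causal):
    out = F.scaled_dot_product_attention(
        Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2), is_causal=is_causal
    )
    return out.transpose(1, 2)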
def main(
batch: int = 8,
heads: int = 32,
seq_len: int = 4096,
dim: int = 128,
is_causal: bool = False,
tune: bool = False,
):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if not tune:
kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune)
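# Example invocation (added; script name assumed, the test suite imports this example as
# example_mha_fwd_bshd):
#   python example_mha_fwd_bshd.py --batch 8 --heads 32 --seq_len 4096 --dim 128 --is_causal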
# ruff: noqa
import torch
import tilelang
import tilelang.language as T
import tilelang.testing
import argparse
import torch
from einops import rearrange, repeat
from varlen_utils import generate_random_padding_mask, generate_qkv
def attention_ref(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
upcast=True,
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads_k, head_dim)
v: (batch_size, seqlen_k, nheads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
Output:
output: (batch_size, seqlen_q, nheads, head_dim)
attention: (batch_size, nheads, seqlen_q, seqlen_k), the softmax attention weights
"""
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
dim = q.shape[-1]
scale = (1.0 / dim) ** 0.5
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
scores = torch.einsum("bthd,bshd->bhts", q, k)
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
# scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0)
scores = scores * scale
attention = torch.softmax(scores, dim=-1).to(v.dtype)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if query_padding_mask is not None:
attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
output = torch.einsum("bhts,bshd->bthd", attention, v)
if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
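# Added sketch (not part of the original file): a tiny CPU smoke test of attention_ref without
# padding masks, exercising the GQA head-repeat path (4 query heads over 2 KV heads).
def _attention_ref_smoke_test():
    q = torch.randn(1, 8, 4, 16, dtype=torch.float32)
    k = torch.randn(1, 8, 2, 16, dtype=torch.float32)
    v = torch.randn(1, 8, 2, 16, dtype=torch.float32)
    out, attn = attention_ref(q, k, v)
    assert out.shape == q.shape
    assert attn.shape == (1, 4, 8, 8)
    return out, attn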
@tilelang.jit(
out_idx=[6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn(batch_size, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=32):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
q_shape = [UQ, heads, dim]
k_shape = [UKV, heads, dim]
v_shape = [UKV, heads, dim]
o_shape = [UQ, heads, dim]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def main(
Q_unpad: T.Tensor(q_shape, dtype),
K_unpad: T.Tensor(k_shape, dtype),
V_unpad: T.Tensor(v_shape, dtype),
cu_seqlens_q: T.Tensor([batch_size + 1], T.int32),
cu_seqlens_k: T.Tensor([batch_size + 1], T.int32),
max_seqlen_q: T.int32,
Output_unpad: T.Tensor(o_shape, dtype),
):
with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype, "shared")
K_shared = T.alloc_shared([block_N, dim], dtype, "shared")
V_shared = T.alloc_shared([block_N, dim], dtype, "shared")
O_shared = T.alloc_shared([block_M, dim], dtype, "shared")
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
batch_idx = bz
head_idx = by
q_start_idx = cu_seqlens_q[batch_idx]
k_start_idx = cu_seqlens_k[batch_idx]
v_start_idx = cu_seqlens_k[batch_idx]
q_end_idx = cu_seqlens_q[batch_idx + 1]
k_end_idx = cu_seqlens_k[batch_idx + 1]
v_end_idx = cu_seqlens_k[batch_idx + 1]
q_current_seqlen = q_end_idx - q_start_idx
k_current_seqlen = k_end_idx - k_start_idx
v_current_seqlen = v_end_idx - v_start_idx
for i, d in T.Parallel(block_M, dim):
if bx * block_M + i < q_current_seqlen:
Q_shared[i, d] = Q_unpad[q_start_idx + bx * block_M + i, head_idx, d]
else:
Q_shared[i, d] = 0
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(k_current_seqlen, block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages):
# Q * K
for i, d in T.Parallel(block_N, dim):
if k * block_N + i < k_current_seqlen:
K_shared[i, d] = K_unpad[k_start_idx + k * block_N + i, head_idx, d]
else:
K_shared[i, d] = 0
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
# Unmasked (0) only when the causal condition holds and both positions are in bounds;
# everything else gets -inf, mirroring the non-causal branch below.
(bx * block_M + i >= k * block_N + j)
and (bx * block_M + i < q_current_seqlen)
and (k * block_N + j < k_current_seqlen),
0,
-T.infinity(acc_s.dtype),
)
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
(bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), -T.infinity(acc_s.dtype), 0
)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
# Softmax
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in the FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
# Rescale
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
# V * softmax(Q * K)
for i, d in T.grid(block_N, dim):
if k * block_N + i < v_current_seqlen:
V_shared[i, d] = V_unpad[v_start_idx + k * block_N + i, head_idx, d]
else:
V_shared[i, d] = 0
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
for i, d in T.Parallel(block_M, dim):
if bx * block_M + i < q_current_seqlen:
Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d]
return main
def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
tilelang.testing.set_random_seed(0)
causal = False
if causal:
total_flops *= 0.5
dtype = torch.float16
device = torch.device("cuda")
window_size = (-1, -1)
q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
v = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
query_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random")
key_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random")
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
UQ = q_unpad.shape[0] # unpadded query length
UK = k_unpad.shape[0] # unpadded key length
UKV = k_unpad.shape[0] # unpadded key/value length
kernel = flashattn(batch, UQ, UKV, heads, dim, causal)
out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
out = output_pad_fn(out_unpad)
out_ref, _ = attention_ref(
q,
k,
v,
query_padding_mask,
key_padding_mask,
causal=causal,
)
torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
import flash_attn
fla_out_unpad = flash_attn.flash_attn_varlen_func(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
0.0,
causal=causal,
)
fla_out = output_pad_fn(fla_out_unpad)
torch.testing.assert_close(out, fla_out, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=64, help="heads")
parser.add_argument("--seq_len", type=int, default=2048, help="sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim)
import tilelang.testing
import example_gqa_bwd
import example_gqa_bwd_wgmma_pipelined
import example_mha_bwd_bshd
import example_mha_bwd_bhsd
import example_mha_fwd_bhsd_wgmma_pipelined
import example_gqa_fwd_bshd
import example_mha_fwd_bshd
import example_gqa_fwd_bshd_wgmma_pipelined
import example_mha_fwd_bshd_wgmma_pipelined
import example_mha_fwd_varlen
import example_mha_bwd_bshd_wgmma_pipelined
import example_mha_fwd_bhsd
import example_gqa_bwd_tma_reduce_varlen
@tilelang.testing.requires_cuda
def test_example_gqa_bwd_tma_reduce_varlen():
example_gqa_bwd_tma_reduce_varlen.main()
@tilelang.testing.requires_cuda
def test_example_gqa_bwd():
example_gqa_bwd.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_gqa_bwd_wgmma_pipelined():
example_gqa_bwd_wgmma_pipelined.main()
@tilelang.testing.requires_cuda
def test_example_mha_bwd():
example_mha_bwd_bshd.main(
BATCH=1,
H=16,
N_CTX=512,
D_HEAD=64,
causal=False,
)
@tilelang.testing.requires_cuda
def test_example_mha_bwd_bhsd():
example_mha_bwd_bhsd.main(
BATCH=1,
H=16,
N_CTX=512,
D_HEAD=64,
causal=False,
)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_bwd_wgmma_pipelined():
example_mha_bwd_bshd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_gqa_fwd_bshd_wgmma_pipelined():
example_gqa_fwd_bshd_wgmma_pipelined.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)
@tilelang.testing.requires_cuda
def test_example_gqa_fwd_bshd():
example_gqa_fwd_bshd.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_fwd_bhsd_wgmma_pipelined():
example_mha_fwd_bhsd_wgmma_pipelined.main()
@tilelang.testing.requires_cuda
def test_example_mha_fwd_bhsd():
example_mha_fwd_bhsd.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_fwd_bshd_wgmma_pipelined():
example_mha_fwd_bshd_wgmma_pipelined.main(batch=1, heads=32, seq_len=256)
@tilelang.testing.requires_cuda
def test_example_mha_fwd_bshd():
example_mha_fwd_bshd.main(batch=1, seq_len=256)
@tilelang.testing.requires_cuda
def test_example_mha_fwd_varlen():
example_mha_fwd_varlen.main(batch=4, heads=16, seq_len=512, dim=64)
if __name__ == "__main__":
tilelang.testing.main()
# ruff: noqa
import torch
from einops import rearrange, repeat
from bert_padding import pad_input, unpad_input
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
assert mode in ["full", "random", "third"]
if mode == "full":
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
elif mode == "random":
lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
elif mode == "third":
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
return padding_mask
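# Added helper sketch (not part of the original utils): per-sequence valid-token counts implied
# by a (batch, seqlen) boolean padding mask from generate_random_padding_mask.
def mask_to_seqlens(padding_mask):
    return padding_mask.sum(dim=-1, dtype=torch.int32)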
def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size)
if key_padding_mask is not None:
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device)
max_seqlen_k = seqlen_k
if qkvpacked:
assert (query_padding_mask == key_padding_mask).all()
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
if query_padding_mask is not None:
dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
else:
dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
return (
qkv_unpad.detach().requires_grad_(),
cu_seqlens_q,
max_seqlen_q,
qkv.detach().requires_grad_(),
output_pad_fn,
dqkv_pad_fn,
)
elif kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
else:
dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
return (
q_unpad.detach().requires_grad_(),
kv_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
kv.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dkv_pad_fn,
)
else:
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
else:
dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
return (
q_unpad.detach().requires_grad_(),
k_unpad.detach().requires_grad_(),
v_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
k.detach().requires_grad_(),
v.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
)
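# Added sketch (assumption, not part of the original utils): the cu_seqlens convention used by
# the varlen kernels is an exclusive prefix sum of per-sequence lengths, with a leading 0 and
# the total token count as the last entry, as produced by unpad_input / generate_qkv above.
def cu_seqlens_from_mask(padding_mask):
    lengths = padding_mask.sum(dim=-1, dtype=torch.int32)
    return torch.nn.functional.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))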
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse
import itertools
from functools import lru_cache
from typing import Tuple, Dict
torch.random.manual_seed(0)
def get_configs():
block_N = [64, 128]
block_H = [64]
num_split = [1, 2, 4, 8]
num_stages = [1, 2, 3]
threads = [128]
_configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs]
return configs
@lru_cache(maxsize=1)
def get_heuristic_config() -> Tuple[Dict, int]:
# Get CUDA device properties
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
device = torch.cuda.current_device()
sm_major, sm_minor = torch.cuda.get_device_capability(device)
sm_version = sm_major * 10 + sm_minor
print(f"CUDA device capability: {sm_version}")
if sm_version == 89:
cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=0, threads=128)
else:
cfg = dict(block_N=128, block_H=64, num_split=8, num_stages=2, threads=128)
return cfg, sm_version
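# Note (added): get_heuristic_config is cached via lru_cache, so the device is queried only once
# per process. sm_version 89 is compute capability 8.9 (Ada-class GPUs), which gets a
# single-split, non-pipelined configuration here; other architectures default to 8 splits with
# 2 pipeline stages.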
# TODO(lei): fix warp specialized and tma lower pass
def get_pass_configs():
return {tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[6], pass_configs=get_pass_configs())
def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, threads):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, dim]
shape_k = [batch, seqlen_kv, groups, dim]
shape_v = [batch, seqlen_kv, groups, dim]
shape_o = [batch, heads, dim]
dtype = T.float16
accum_dtype = T.float32
kv_group_num = heads // groups
part_shape = [batch, heads, num_split, dim]
valid_block_H = min(block_H, kv_group_num)
valid_block_N = min(block_N, seqlen_kv // num_split)
@T.macro
def flash_attn(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([valid_block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
mask_local = T.alloc_fragment([block_N], "uint8")
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
bid = bx
hid = by
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_shared)
T.copy(mask[bid, k * block_N : (k + 1) * block_N, cur_kv_head], mask_local)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], -T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.copy(V[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
@T.macro
def flash_attn_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([valid_block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
mask_local = T.alloc_fragment([block_N], "uint8")
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
bid = bx
hid = by
sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
K[
bid,
(seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head,
:,
],
K_shared,
)
T.copy(
mask[
bid,
(seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head,
],
mask_local,
)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j], -T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.copy(
V[
bid,
(seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head,
:,
],
V_shared,
)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
for i in T.Parallel(block_H):
if i < valid_block_H:
glse[bid, hid * valid_block_H + i, sid] = logsum[i]
T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output_partial[bid, hid * valid_block_H : (hid + 1) * valid_block_H, sid, :])
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim], dtype)
o_accum_local = T.alloc_fragment([dim], accum_dtype)
lse_local = T.alloc_fragment([num_split, 128], dtype)
lse_logsum_local = T.alloc_fragment([128], accum_dtype)
lse_max_local = T.alloc_fragment([128], accum_dtype)
scale_local = T.alloc_fragment([128], accum_dtype)
T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
# lse_local: (local_id, thread_id)
lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)),
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
for k, j in T.Parallel(num_split, 128):
lse_local[k, j] = glse[bz, by, k]
T.reduce_max(lse_local, lse_max_local, dim=0, clear=True)
for k in T.serial(num_split):
for j in T.Parallel(128):
lse_logsum_local[j] += T.exp2(lse_local[k, j] - lse_max_local[j])
for j in T.Parallel(128):
lse_logsum_local[j] = T.log2(lse_logsum_local[j]) + lse_max_local[j]
for k in T.serial(num_split):
for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, by, k, i]
for j in T.Parallel(128):
scale_local[j] = T.exp2(lse_local[k, j] - lse_logsum_local[j])
# Note: this relies on dim matching the 128 threads used in the Parallel loops above, so that
# scale_local[i] and po_local[i] are indexed consistently (dim == 128 in this example).
for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local[i]
for i in T.Parallel(dim):
Output[bz, by, i] = o_accum_local[i]
@T.prim_func
def flashattn_gqa_decode_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype),
):
flash_attn_split(Q, K, V, mask, glse, Output_partial)
combine(glse, Output_partial, Output)
@T.prim_func
def flashattn_gqa_decode_no_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype),
):
flash_attn(Q, K, V, mask, Output)
if num_split > 1:
return flashattn_gqa_decode_split
else:
return flashattn_gqa_decode_no_split
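# Added sketch (not part of the original kernel): the merge rule implemented by the `combine`
# macro above, written in plain torch for two splits. `lse` follows the same convention as the
# glse values written by flash_attn_split: log2 of the per-split softmax denominator plus the
# running row max, both in the scaled (log2(e)) domain.
def _merge_two_splits(o0, lse0, o1, lse1):
    # lse*: [batch, heads], o*: [batch, heads, dim]
    lse_max = torch.maximum(lse0, lse1)
    z = torch.exp2(lse0 - lse_max) + torch.exp2(lse1 - lse_max)
    lse = torch.log2(z) + lse_max
    return o0 * torch.exp2(lse0 - lse)[..., None] + o1 * torch.exp2(lse1 - lse)[..., None]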
def ref_program(query, key, value, mask, glse, Output_partial):
# """
# Inputs:
# - query (Tensor): [batch, heads, dim]
# - key (Tensor): [batch, seqlen_kv, groups, dim]
# - value (Tensor): [batch, seqlen_kv, groups, dim]
# - mask (Tensor): [batch, seqlen_kv, groups]
# Outputs:
# - output (Tensor): [batch, heads, dim]
# """
dim = query.shape[-1]
num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5
key = rearrange(key, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
value = rearrange(value, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
if mask is not None:
mask = rearrange(mask, "b s h -> b h s")
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
def flash_split_ref(Q, K, V, mask):
num_split = 16
batch = Q.size(0)
nheads = Q.size(1)
groups = K.size(2)
dim = Q.size(-1)
block_N = 32
seqlen_kv = K.size(1)
num_head_groups = nheads // groups
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
acc_s = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float)
acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float16)
acc_o = torch.empty((batch, num_head_groups, groups, dim), device="cuda", dtype=torch.float)
scores_max = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
scores_max_prev = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
scores_scale = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
scores_sum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
logsum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
gacc_o = torch.empty((num_split, batch, nheads, dim), device="cuda", dtype=torch.float)
glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float)
Q_ = Q * scale
Q_ = rearrange(Q_, "b (h g) d -> b g h d", g=num_head_groups)
for ks in range(num_split):
acc_o.fill_(0)
logsum.fill_(0)
scores_max.fill_(float("-inf"))
scores_max_prev.fill_(float("-inf"))
for i in range(int((seqlen_kv // num_split) / block_N)):
acc_s.fill_(0)
acc_s = torch.einsum(
"bghd,bkhd->bghk",
Q_,
K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
) # [batch, nheads, block_N]
if mask is not None:
mask_local = mask[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :]
mask_local = rearrange(mask_local, "b s h -> b h s")
mask_local = mask_local.unsqueeze(1)
acc_s = acc_s.masked_fill(mask_local == 0, float("-inf"))
scores_max_prev = scores_max
scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads]
scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads]
acc_o *= scores_scale[:, :, :, None]
acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N]
acc_o += torch.einsum(
"bghk,bkhd->bghd",
acc_s_cast,
V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
)
scores_sum = acc_s.sum(dim=-1, keepdim=False)
logsum = logsum * scores_scale + scores_sum
acc_o_out = rearrange(acc_o, "b g h d->b (h g) d")
logsum_out = rearrange(logsum, "b g h->b (h g)")
acc_o_out /= logsum_out[:, :, None]
logsum_out = torch.log2(logsum_out) + rearrange(scores_max, "b g h->b (h g)")
gacc_o[ks, :, :, :] = acc_o_out
glogsum[ks, :, :] = logsum_out
return glogsum.to(torch.float16).permute(1, 2, 0), gacc_o.to(torch.float16).permute(1, 2, 0, 3)
def reduce_ref(Q, K, V, mask, glse, Output_partial):
num_split = 16
o = torch.empty_like(Output_partial[:, :, 0, :]).fill_(0)
lse_logsum = torch.empty_like(glse[:, :, 0]).fill_(0) # [batch, heads]
lse_max = glse.max(dim=2, keepdim=False).values
for ks in range(num_split):
lse = glse[:, :, ks]
lse_logsum += torch.exp2(lse - lse_max)
lse_logsum = torch.log2(lse_logsum) + lse_max
for ks in range(num_split):
lse = glse[:, :, ks]
scale = torch.exp2(lse - lse_logsum) # [batch, heads]
o += Output_partial[:, :, ks, :] * scale[:, :, None]
return o.to(torch.float16)
def ref_split_program(Q, K, V, mask, glse=None, Output_partial=None):
glse_, Output_partial_ = flash_split_ref(Q, K, V, mask)
return reduce_ref(Q, K, V, mask, glse_, Output_partial_)
def print_red_warning(msg):
print(f"\033[91m{msg}\033[0m")
def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print_red_warning(f"{name} all zero")
return 1
sim = 2 * (x * y).sum() / denominator
return sim
def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True):
sim = calc_sim(x, y, name)
diff = 1.0 - sim
if not (0 <= diff <= eps):
print_red_warning(f"{name} Error: {diff}")
if assert_:
raise AssertionError(f"{name} Error: {diff}")
else:
if print_:
print(f"passed: {name} diff={diff}")
def main(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192, dim: int = 128, tune: bool = False):
qk_flops = 2 * batch * heads * kv_seqlen * dim
pv_flops = 2 * batch * heads * kv_seqlen * dim
total_flops = qk_flops + pv_flops
if not tune:
config, sm_version = get_heuristic_config()
kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
q = torch.randn(batch, heads, dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
mask = torch.randint(0, 2, (batch, kv_seqlen, groups), device="cuda", dtype=torch.uint8)
split = config["num_split"]
glse = torch.empty(batch, heads, split, device="cuda", dtype=torch.float16)
Output_partial = torch.empty(batch, heads, split, dim, device="cuda", dtype=torch.float16)
o = kernel(q, k, v, mask, glse, Output_partial)
o_ref = ref_program(q, k, v, mask, glse, Output_partial)
o_ref_split = ref_split_program(q, k, v, mask, glse, Output_partial)
print(o)
print(o_ref)
assert_similar(o, o_ref, name="o_ref")
assert_similar(o, o_ref_split, name="o_ref_split")
print("All checks pass.")
latency = profiler.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
kernel = flashattn(batch, heads, groups, kv_seqlen, dim)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--groups", type=int, default=8, help="groups")
parser.add_argument("--kv_seqlen", type=int, default=8192, help="kv sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune)
import torch
import triton
import triton.language as tl
import math
import argparse
import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune
torch.manual_seed(0)
tilelang.disable_cache()
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
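# Added sketch (not part of the original file): quick check of the docstring claim above that
# repeat_kv matches torch.repeat_interleave along the head dimension.
def _check_repeat_kv_matches_repeat_interleave():
    x = torch.randn(2, 4, 3, 8)
    return torch.equal(repeat_kv(x, 2), torch.repeat_interleave(x, repeats=2, dim=1))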
@triton.jit
def _fwd_inner(
q,
k_ptrs,
v_ptrs,
s_ptrs,
m_i,
l_i,
acc,
offs_h,
mask_h,
offs_n,
seqlen,
softmax_scale,
lo,
hi,
stride_kt,
stride_vt,
stride_sh,
stride_sn,
BLOCK_N: tl.constexpr,
):
"""Inner loop computation for attention"""
for blk_idx in tl.range(lo, hi):
start_n = blk_idx * BLOCK_N
k = tl.load(k_ptrs + start_n * stride_kt, mask=offs_n[None, :] + start_n < seqlen)
v = tl.load(v_ptrs + start_n * stride_vt, mask=offs_n[:, None] + start_n < seqlen)
qk = tl.dot(q, k)
qk *= softmax_scale
qk += tl.where(offs_n[None, :] + start_n < seqlen, 0, -1.0e9)
row_max = tl.max(qk, 1)
tl.store(s_ptrs + offs_h * stride_sh + blk_idx * stride_sn, row_max, mask=mask_h)
m_ij = tl.maximum(m_i, row_max)
qk -= m_ij[:, None]
p = tl.math.exp(qk)
l_ij = tl.sum(p, 1)
alpha = tl.math.exp(m_i - m_ij)
l_i = l_i * alpha + l_ij
m_i = m_ij
acc *= alpha[:, None]
p = p.to(v.type.element_ty)
acc += tl.dot(p, v)
return m_i, l_i, acc
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [4, 8] for num_stages in [2, 4]],
key=["gqa_group_size", "BLOCK_N", "BLOCK_D", "BLOCK_H"],
)
@triton.jit
def _fwd_kernel_varlen(
Q, # [token_q = b, h_q, dim]
K, # [token_k, h_kv, dim]
V,
O,
S,
s_aux,
softmax_scale,
cu_seqlens_k,
stride_qt,
stride_qh,
stride_qd,
stride_kt,
stride_kh,
stride_kd,
stride_vt,
stride_vh,
stride_vd,
stride_ot,
stride_oh,
stride_od,
stride_sb,
stride_sh,
stride_sn, # bmask shape [b, q_h, seq/BLOCK_N]
gqa_group_size: tl.constexpr,
BLOCK_H: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr,
):
off_z = tl.program_id(0)
off_h_for_kv = tl.program_id(1)
off_h_q = off_h_for_kv * gqa_group_size
cu_k_start = tl.load(cu_seqlens_k + off_z)
cu_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_k_end - cu_k_start
offs_h = tl.arange(0, BLOCK_H)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_D)
Q_ptrs = Q + off_z * stride_qt + off_h_q * stride_qh
K_ptrs = K + (cu_k_start) * stride_kt + off_h_for_kv * stride_kh
V_ptrs = V + (cu_k_start) * stride_vt + off_h_for_kv * stride_vh
O_ptrs = O + off_z * stride_ot + off_h_q * stride_oh
S_ptrs = S + off_z * stride_sb + off_h_q * stride_sh
mask_h = offs_h < gqa_group_size
q = tl.load(Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None])
if s_aux is not None:
sink = tl.load(s_aux + off_h_q + offs_h, mask=mask_h).to(tl.float32)
l_i = tl.zeros([BLOCK_H], dtype=tl.float32)
m_i = tl.zeros([BLOCK_H], dtype=tl.float32) + sink
else:
l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32)
m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32)
k_ptrs = K_ptrs + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
v_ptrs = V_ptrs + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
lo, hi = 0, tl.cdiv(seqlen_k, BLOCK_N)
m_i, l_i, acc = _fwd_inner(
q,
k_ptrs,
v_ptrs,
S_ptrs,
m_i,
l_i,
acc,
offs_h,
mask_h,
offs_n,
seqlen_k,
softmax_scale,
lo,
hi,
stride_kt,
stride_vt,
stride_sh,
stride_sn,
BLOCK_N,
)
if s_aux is not None:
sink = tl.math.exp(sink - m_i)
l_i = l_i + sink
acc = acc / l_i[:, None]
else:
l_recip = 1 / l_i[:, None]
acc = acc * l_recip
for blk_idx in tl.range(lo, hi):
s = tl.load(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, mask=mask_h)
s = tl.exp(s - m_i) / l_i
tl.store(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, s, mask=mask_h)
acc = acc.to(O.dtype.element_ty)
tl.store(O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od, acc, mask=mask_h[:, None])
def get_configs():
import itertools
block_N = [64, 128]
block_H = [64]
num_split = [1]
num_stages = [1, 2, 3]
threads = [128]
_configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs]
return configs
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding")
def flashattn(
batch, heads, k_heads, max_seqlen_kv, total_seqlen_k, dim, has_sink, block_N=128, block_H=64, num_split=1, num_stages=1, threads=128
):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, dim]
shape_k = [total_seqlen_k, k_heads, dim]
shape_v = [total_seqlen_k, k_heads, dim]
shape_o = [batch, heads, dim]
shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)]
dtype = T.float16
accum_dtype = T.float32
kv_group_num = heads // k_heads
valid_block_H = min(block_H, kv_group_num)
# TODO: check if max_seqlen_kv is correct for varlen case
@T.macro
def flash_attn(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
cu_seqlens_k: T.Tensor([batch + 1], T.int32),
s_aux: T.Tensor([heads], T.float32),
Output: T.Tensor([batch, heads, dim], dtype),
S: T.Tensor(shape_s, dtype),
):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([valid_block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype)
# S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype)
s_aux_shared = T.alloc_shared([block_H], T.float32)
T.annotate_layout(
{
# Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
# K_shared: tilelang.layout.make_swizzled_layout(K_shared),
# V_shared: tilelang.layout.make_swizzled_layout(V_shared),
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
# S_shared: tilelang.layout.make_swizzled_layout(S_shared),
}
)
bid = bx
hid = by
cur_kv_head = hid // (kv_group_num // valid_block_H)
cur_start_k = cu_seqlens_k[bid]
cur_end_k = cu_seqlens_k[bid + 1]
cur_seqlen_k = cur_end_k - cur_start_k
T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
# loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], K_shared)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
# acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j],
# -T.infinity(accum_dtype))
acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# scores_max_prev is m_i
# scores_max is row_max->m_ij in triton
T.copy(scores_max, S_shared[:, k])
# scores_scale is alpha in triton
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
# scores_sum is l_ij in triton
# logsum is l_i in triton
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.copy(V[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_sink:
T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared)
for i in T.Parallel(block_H):
logsum[i] += s_aux_shared[i]
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)):
S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h]
# T.copy(S_shared, S_fragment)
# for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)):
# S_fragment[h, k] = T.exp2((S_fragment[h, k] - scores_max[h]) * scale) / logsum[h]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
# T.copy(S_fragment, S_shared)
T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
@T.prim_func
def flashattn_gqa_decode_no_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
cu_seqlens_k: T.Tensor([batch + 1], T.int32),
s_aux: T.Tensor([heads], T.float32),
Output: T.Tensor(shape_o, dtype),
S: T.Tensor(shape_s, dtype),
):
flash_attn(Q, K, V, cu_seqlens_k, s_aux, Output, S)
# TODO: split version
return flashattn_gqa_decode_no_split
def flash_attn_with_attn_pool_decode_tilelang(
Q: torch.Tensor, ## [tq = b, q_h, q_dim]
K: torch.Tensor, ## [tk, k_h, k_dim]
V: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_k: int,
real_max_k_seqlen: int,
num_split: int,
softmax_scale: float,
s_aux: torch.Tensor = None,
block_size: int = 64,
use_per_kv_head_sparse_index: bool = False,
tl_kernel=None,
):
num_tokens, q_h, head_size = Q.shape
batch = cu_seqlens_k.size(0) - 1
k_h = K.size(1)
assert Q.dim() == K.dim() == 3
assert Q.size(2) == K.size(2)
assert cu_seqlens_k.dim() == 1
assert head_size in {64, 128, 256}
assert Q.is_contiguous()
assert K.is_contiguous()
assert V.is_contiguous()
gqa_group_size = q_h // k_h
O_tl = torch.zeros_like(Q)
S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device)
O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux)
if use_per_kv_head_sparse_index:
S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1))
else:
S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1))
return O_tl, S_tl
def flash_attn_with_attn_pool_decode(
Q: torch.Tensor, ## [tq = b, q_h, q_dim]
K: torch.Tensor, ## [tk, k_h, k_dim]
V: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_k: int,
real_max_k_seqlen: int,
num_split: int,
softmax_scale: float,
s_aux: torch.Tensor = None,
block_size: int = 64,
use_per_kv_head_sparse_index: bool = False,
):
num_tokens, q_h, head_size = Q.shape
batch = cu_seqlens_k.size(0) - 1
k_h = K.size(1)
assert Q.dim() == K.dim() == 3
assert Q.size(2) == K.size(2)
assert cu_seqlens_k.dim() == 1
assert head_size in {64, 128, 256}
assert Q.is_contiguous()
assert K.is_contiguous()
assert V.is_contiguous()
gqa_group_size = q_h // k_h
BLOCK_D = head_size
BLOCK_N = block_size
BLOCK_H = 64
O = torch.zeros_like(Q)
S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), dtype=Q.dtype, device=Q.device)
def grid(META):
return (batch, k_h)
with torch.cuda.device(Q.device.index):
_fwd_kernel_varlen[grid](
Q,
K,
V,
O,
S,
s_aux,
softmax_scale,
cu_seqlens_k,
*Q.stride(),
*K.stride(),
*V.stride(),
*O.stride(),
*S.stride(),
gqa_group_size,
BLOCK_H=BLOCK_H,
BLOCK_N=BLOCK_N,
BLOCK_D=BLOCK_D,
)
if use_per_kv_head_sparse_index:
S = torch.max_pool2d(S, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1))
else:
S = torch.max_pool2d(S, kernel_size=(q_h, 1), stride=(q_h, 1))
return O, S
def test_equal_seqlen_decode_main(args):
"""Test decode kernel with equal sequence lengths"""
print("Testing decode kernel with equal sequence lengths")
batch_size = args.batch_size
q_heads = args.q_heads
kv_heads = args.kv_heads
k_seqlen = args.k_seqlen
real_max_k_seqlen = args.k_seqlen
head_size = args.head_size
block_size = args.block_size
dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
# For decode, query is just 1 token per batch
q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)
# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}")
# Convert to varlen format for K, V
k_varlen = k.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size)
v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size)
# Generate cumulative sequence lengths
cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32)
max_seqlen_k = k_seqlen
print(f"q shape: {q.shape}")
print(f"k_varlen shape: {k_varlen.shape}")
print(f"v_varlen shape: {v_varlen.shape}")
num_tokens, q_h, head_size = q.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink)
# Test our decode kernel
O_triton, S_triton = flash_attn_with_attn_pool_decode(
q,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
real_max_k_seqlen,
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size,
)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
real_max_k_seqlen,
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size,
tl_kernel=tl_kernel,
)
for i in range(batch_size):
S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0
# Compute torch reference
q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size]
k_repeat = repeat_kv(k, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size]
v_repeat = repeat_kv(v, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size]
if sink is None:
# Standard scaled dot-product attention
logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
attn_weights = torch.softmax(logits, dim=-1)
O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size]
else:
# s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(logits_max, sink_expanded)
sinks = torch.exp(sink_expanded - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
attn_weights = unnormalized_scores / normalizer
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size]
# Compute attention score pooling
attn_score_pooled = torch.max_pool2d(
attn_weights.squeeze(2), # [b, q_heads, k_seqlen]
kernel_size=(q_heads, block_size),
stride=(q_heads, block_size),
ceil_mode=True,
).to(torch.float16)
print("S_tilelang", S_tilelang)
print("attn_score_pooled", attn_score_pooled)
max_diff_o = torch.max(torch.abs(O_triton - O_torch))
max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
max_diff_o_tilelang = torch.max(torch.abs(O_tilelang - O_torch))
max_diff_s_tilelang = torch.max(torch.abs(S_tilelang - attn_score_pooled))
print(f"Max difference in O: {max_diff_o.item()}")
print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}")
assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
print("✅ All tests passed!")
def test_varlen_decode_main(args):
"""Test decode kernel with variable sequence lengths"""
batch_size = args.batch_size
q_heads = args.q_heads
kv_heads = args.kv_heads
max_k_seqlen = args.k_seqlen # Use as max sequence length
real_max_k_seqlen = args.k_seqlen
head_size = args.head_size
block_size = args.block_size
dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})")
# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}")
# Generate variable length k sequences
k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,))
print(f"k_seqlens: {k_seqlens}")
# Generate cumulative sequence lengths for k
cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
total_k_tokens = 0
for i in range(batch_size):
cu_seqlens_k[i] = total_k_tokens
total_k_tokens += k_seqlens[i]
cu_seqlens_k[batch_size] = total_k_tokens
print(f"cu_seqlens_k: {cu_seqlens_k}")
# Generate tensors - Q is [batch_size, q_heads, head_size] for decode
q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)
max_seqlen_k = int(k_seqlens.max())
print(f"Actual max_seqlen_k: {max_seqlen_k}")
print(f"q_decode shape: {q_decode.shape}")
print(f"k_varlen shape: {k_varlen.shape}")
print(f"v_varlen shape: {v_varlen.shape}")
num_tokens, q_h, head_size = q_decode.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink)
# Test our decode kernel
O_triton, S_triton = flash_attn_with_attn_pool_decode(
q_decode,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
real_max_k_seqlen,
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size,
)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q_decode,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
real_max_k_seqlen,
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size,
tl_kernel=tl_kernel,
)
for i in range(batch_size):
S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0
# Create torch reference - pad tensors for comparison
k_padded_list = []
v_padded_list = []
for i in range(batch_size):
actual_k_len = k_seqlens[i]
# Extract and pad k, v for this batch
k_start = cu_seqlens_k[i]
k_end = cu_seqlens_k[i + 1]
# Pad to max_seqlen_k
k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
k_padded[:actual_k_len] = k_varlen[k_start:k_end]
v_padded[:actual_k_len] = v_varlen[k_start:k_end]
k_padded_list.append(k_padded)
v_padded_list.append(v_padded)
# Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
# Expand q to match kv heads: [b, q_heads, 1, head_size]
q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size]
print(f"q_expanded shape: {q_expanded.shape}")
print(f"k_padded_batched shape: {k_padded_batched.shape}")
print(f"v_padded_batched shape: {v_padded_batched.shape}")
# Compute torch reference
k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
if sink is None:
# Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for i in range(batch_size):
actual_k_len = k_seqlens[i]
attn_score[i, :, :, actual_k_len:] = float("-inf")
attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen]
# Mask out invalid positions
for i in range(batch_size):
actual_k_len = k_seqlens[i]
attn_weights[i, :, :, actual_k_len:] = 0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size]
else:
# s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for i in range(batch_size):
actual_k_len = k_seqlens[i]
logits[i, :, :, actual_k_len:] = float("-inf")
sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(logits_max, sink_expanded)
sinks = torch.exp(sink_expanded - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
attn_weights = unnormalized_scores / normalizer
# Mask out invalid positions
for i in range(batch_size):
actual_k_len = k_seqlens[i]
attn_weights[i, :, :, actual_k_len:] = 0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size]
O_torch = O_torch.squeeze(2) # [b, q_heads, head_size]
# Compute attention score pooling for S
attn_score_pooled = torch.max_pool2d(
attn_weights.squeeze(2), # [b, q_heads, max_seqlen]
kernel_size=(q_heads, block_size),
stride=(q_heads, block_size),
ceil_mode=True,
).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)]
print(f"O_triton shape: {O_triton.shape}")
print(f"O_tilelang shape: {O_tilelang.shape}")
print(f"O_torch shape: {O_torch.shape}")
print(f"S_triton shape: {S_triton.shape}")
print(f"S_tilelang shape: {S_tilelang.shape}")
print(f"attn_score_pooled shape: {attn_score_pooled.shape}")
# Compare results
max_diff_o = torch.max(torch.abs(O_triton - O_torch))
max_diff_o_tl = torch.max(torch.abs(O_tilelang - O_torch))
print(f"Max difference in O: {max_diff_o.item()}")
print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}")
max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}")
assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}"
assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), (
f"Score mismatch: {max_diff_s_tl.item()}"
)
print("✅ All tests passed!")
def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
"""
Do benchmark for a function.
"""
start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
for _ in range(warmup):
fn(*args, **kwargs)
torch.cuda.synchronize()
for i in range(rep):
start_event[i].record()
fn(*args, **kwargs)
end_event[i].record()
torch.cuda.synchronize()
# Record clocks
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)],
dtype=torch.float,
)
return times.mean().item()
def speed_benchmark_decode_comparison(args):
"""Speed benchmark for decode kernel"""
batch_size = args.batch_size
q_heads = args.q_heads
kv_heads = args.kv_heads
max_k_seqlen = args.k_seqlen
head_size = args.head_size
block_size = args.block_size
dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
print("\n=== Decode Speed Benchmark Comparison ===")
print("Configuration:")
print(f" Batch size: {batch_size}")
print(f" Q heads: {q_heads}, KV heads: {kv_heads}")
print(f" Max K sequence length: {max_k_seqlen}")
print(f" Head size: {head_size}")
print(f" Block size: {block_size}")
print(f" Data type: {dtype}")
print(f" Variable lengths: {args.test_varlen}")
print(f" s_aux attention: {args.test_sink}")
print()
# Generate input data
if args.test_varlen:
k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,))
else:
k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=torch.int64)
# Generate cumulative sequence lengths for k
cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
total_k_tokens = 0
for i in range(batch_size):
cu_seqlens_k[i] = total_k_tokens
total_k_tokens += k_seqlens[i]
cu_seqlens_k[batch_size] = total_k_tokens
# Generate tensors
q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)
max_seqlen_k = int(k_seqlens.max())
# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(" Using sink attention with sink values")
print("Setup complete:")
print(f" Total K tokens: {total_k_tokens}")
print(f" Actual max K seq len: {max_seqlen_k}")
if args.test_varlen:
print(f" K sequence lengths: {k_seqlens.tolist()}")
# Warmup
num_tokens, q_h, head_size = q_decode.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink)
# Benchmark
print("⚡ Benchmarking Tilelang kernel (100 iterations)...")
tilelang_time = do_bench(
flash_attn_with_attn_pool_decode_tilelang,
q_decode,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
args.k_seqlen,
1,
softmax_scale,
sink,
block_size,
False,
tl_kernel,
)
print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms")
# Benchmark
print("⚡ Benchmarking Triton kernel (100 iterations)...")
triton_time = do_bench(
flash_attn_with_attn_pool_decode,
q_decode,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
args.k_seqlen,
1,
softmax_scale,
sink,
block_size,
)
print(f"Average decode kernel time Triton: {triton_time:.3f} ms")
print(f"Speedup: {(triton_time / tilelang_time):.3f}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads")
parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads")
parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length")
parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension")
parser.add_argument("--block_size", type=int, default=64, help="Block size for computation")
parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type")
parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths")
parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism")
parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark")
parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits")
args = parser.parse_args()
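# NOTE: the assignments below override the parsed command-line flags for this example run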
args.test_sink = True
args.test_varlen = False
args.dtype = T.float16
args.num_split = 1
if args.benchmark:
speed_benchmark_decode_comparison(args)
elif args.test_varlen:
test_varlen_decode_main(args)
else:
test_equal_seqlen_decode_main(args)
import torch
import math
import argparse
import tilelang
import tilelang.language as T
from example_gqa_decode_varlen_logits import flash_attn_with_attn_pool_decode, repeat_kv, do_bench
torch.manual_seed(0)
def get_configs():
import itertools
block_N = [64, 128]
block_H = [64]
num_split = [1]
num_stages = [1, 2, 3]
threads = [128]
_configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs]
return configs
# @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding")
def flashattn(
batch,
heads,
k_heads,
max_seqlen_kv,
total_seqlen_k,
dim,
has_sink,
page_block_size,
block_N=128,
block_H=64,
num_split=1,
num_stages=1,
threads=128,
):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, dim]
shape_k = [total_seqlen_k, k_heads, dim]
shape_v = [total_seqlen_k, k_heads, dim]
shape_o = [batch, heads, dim]
shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)]
dtype = T.float16
accum_dtype = T.float32
kv_group_num = heads // k_heads
assert page_block_size >= block_N and page_block_size % block_N == 0, (
"page_block_size must be larger than block_N and a multiple of block_N"
)
valid_block_H = min(block_H, kv_group_num)
# TODO: check if max_seqlen_kv is correct for varlen case
@T.macro
def flash_attn(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
cu_seqlens_k: T.Tensor([batch + 1], T.int32),
s_aux: T.Tensor([heads], T.float32),
BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / block_N)], T.int32),
Output: T.Tensor([batch, heads, dim], dtype),
S: T.Tensor(shape_s, dtype),
):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
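# Grid: (batch, query-head groups, num_split); each block handles valid_block_H query heads that share one KV head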
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([valid_block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype)
s_aux_shared = T.alloc_shared([block_H], T.float32)
bid = bx
hid = by
cur_kv_head = hid // (kv_group_num // valid_block_H)
cur_start_k = cu_seqlens_k[bid]
cur_end_k = cu_seqlens_k[bid + 1]
cur_seqlen_k = cur_end_k - cur_start_k
T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
# loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages):
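# Translate the logical KV block index k into a physical offset via the page (block) table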
k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size
T.copy(K[cur_start_k + k_start : cur_start_k + k_start + block_N, cur_kv_head, :], K_shared)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# scores_max_prev is m_i
# scores_max is row_max->m_ij in triton
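# S_shared[:, k] stores this block's row max; it is turned into a normalized block score after the KV loop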
T.copy(scores_max, S_shared[:, k])
# scores_scale is alpha in triton
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
# scores_sum is l_ij in triton
# logsum is l_i in triton
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size
T.copy(V[cur_start_k + v_start : cur_start_k + v_start + block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_sink:
T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared)
for i in T.Parallel(block_H):
logsum[i] += s_aux_shared[i]
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
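# Convert the stored per-block maxima into normalized attention scores (softmax weight of each block's peak logit)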
for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)):
S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
@T.prim_func
def flashattn_gqa_decode_no_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
cu_seqlens_k: T.Tensor([batch + 1], T.int32),
s_aux: T.Tensor([heads], T.float32),
BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], T.int32),
Output: T.Tensor(shape_o, dtype),
S: T.Tensor(shape_s, dtype),
):
flash_attn(Q, K, V, cu_seqlens_k, s_aux, BLOCK_TABLE, Output, S)
# TODO: split version
return flashattn_gqa_decode_no_split
def flash_attn_with_attn_pool_decode_tilelang(
Q: torch.Tensor, ## [tq = b, q_h, q_dim]
K: torch.Tensor, ## [tk, k_h, k_dim]
V: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_k: int,
real_max_k_seqlen: int,
num_split: int,
softmax_scale: float,
s_aux: torch.Tensor = None,
block_size: int = 64,
use_per_kv_head_sparse_index: bool = False,
tl_kernel=None,
block_table: torch.Tensor = None,
):
num_tokens, q_h, head_size = Q.shape
batch = cu_seqlens_k.size(0) - 1
k_h = K.size(1)
assert Q.dim() == K.dim() == 3
assert Q.size(2) == K.size(2)
assert cu_seqlens_k.dim() == 1
assert head_size in {64, 128, 256}
assert Q.is_contiguous()
assert K.is_contiguous()
assert V.is_contiguous()
gqa_group_size = q_h // k_h
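# out_idx=[-2, -1] makes the JIT kernel allocate and return Output and S itself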
O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux, block_table)
if use_per_kv_head_sparse_index:
S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1))
else:
S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1))
return O_tl, S_tl
def test_equal_seqlen_decode_main(args):
"""Test decode kernel with equal sequence lengths"""
print("Testing decode kernel with equal sequence lengths")
batch_size = args.batch_size
q_heads = args.q_heads
kv_heads = args.kv_heads
k_seqlen = args.k_seqlen
real_max_k_seqlen = args.k_seqlen
head_size = args.head_size
block_size = args.block_size
page_block_size = args.page_block_size
dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
# For decode, query is just 1 token per batch
q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)
# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}")
# Convert to varlen format for K, V
k_varlen = k.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous()
v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous()
# Generate cumulative sequence lengths
cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32)
max_seqlen_k = k_seqlen
print(f"q shape: {q.shape}")
print(f"k_varlen shape: {k_varlen.shape}")
print(f"v_varlen shape: {v_varlen.shape}")
num_tokens, q_h, head_size = q.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
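# Build a block table mapping each sequence's logical pages to sequentially assigned page indices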
block_cnt = 0
for i in range(batch):
cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
for j in range(math.ceil(cur_seqlen / page_block_size)):
block_table[i, j] = block_cnt
block_cnt += 1
block_cnt = 0
# Test our decode kernel
O_triton, S_triton = flash_attn_with_attn_pool_decode(
q,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
real_max_k_seqlen,
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size,
)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
real_max_k_seqlen,
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size,
tl_kernel=tl_kernel,
block_table=block_table,
)
for i in range(batch_size):
S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0
# Compute torch reference
q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size]
k_repeat = repeat_kv(k, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size]
v_repeat = repeat_kv(v, q_heads // kv_heads) # [b, q_heads, k_seqlen, head_size]
if sink is None:
# Standard scaled dot-product attention
logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
attn_weights = torch.softmax(logits, dim=-1)
O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size]
else:
# s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(logits_max, sink_expanded)
sinks = torch.exp(sink_expanded - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
attn_weights = unnormalized_scores / normalizer
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size]
# Compute attention score pooling
attn_score_pooled = torch.max_pool2d(
attn_weights.squeeze(2), # [b, q_heads, k_seqlen]
kernel_size=(q_heads, block_size),
stride=(q_heads, block_size),
ceil_mode=True,
).to(torch.float16)
print("S_tilelang", S_tilelang)
print("attn_score_pooled", attn_score_pooled)
max_diff_o = torch.max(torch.abs(O_triton - O_torch))
max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
max_diff_o_tilelang = torch.max(torch.abs(O_tilelang - O_torch))
max_diff_s_tilelang = torch.max(torch.abs(S_tilelang - attn_score_pooled))
print(f"Max difference in O: {max_diff_o.item()}")
print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}")
assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
print("✅ All tests passed!")
def test_varlen_decode_main(args):
"""Test decode kernel with variable sequence lengths"""
batch_size = args.batch_size
q_heads = args.q_heads
kv_heads = args.kv_heads
max_k_seqlen = args.k_seqlen # Use as max sequence length
real_max_k_seqlen = args.k_seqlen
head_size = args.head_size
block_size = args.block_size
page_block_size = args.page_block_size
dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})")
# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}")
# Generate variable length k sequences
k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,))
print(f"k_seqlens: {k_seqlens}")
# Generate cumulative sequence lengths for k
cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
total_k_tokens = 0
for i in range(batch_size):
cu_seqlens_k[i] = total_k_tokens
total_k_tokens += k_seqlens[i]
cu_seqlens_k[batch_size] = total_k_tokens
print(f"cu_seqlens_k: {cu_seqlens_k}")
# Generate tensors - Q is [batch_size, q_heads, head_size] for decode
q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)
max_seqlen_k = int(k_seqlens.max())
print(f"Actual max_seqlen_k: {max_seqlen_k}")
print(f"q_decode shape: {q_decode.shape}")
print(f"k_varlen shape: {k_varlen.shape}")
print(f"v_varlen shape: {v_varlen.shape}")
num_tokens, q_h, head_size = q_decode.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
block_cnt = 0
for i in range(batch):
cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
for j in range(math.ceil(cur_seqlen / page_block_size)):
block_table[i, j] = block_cnt
block_cnt += 1
block_cnt = 0
# Test our decode kernel
O_triton, S_triton = flash_attn_with_attn_pool_decode(
q_decode,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
real_max_k_seqlen,
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size,
)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q_decode,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
real_max_k_seqlen,
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size,
tl_kernel=tl_kernel,
block_table=block_table,
)
for i in range(batch_size):
S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0
# Create torch reference - pad tensors for comparison
k_padded_list = []
v_padded_list = []
for i in range(batch_size):
actual_k_len = k_seqlens[i]
# Extract and pad k, v for this batch
k_start = cu_seqlens_k[i]
k_end = cu_seqlens_k[i + 1]
# Pad to max_seqlen_k
k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
k_padded[:actual_k_len] = k_varlen[k_start:k_end]
v_padded[:actual_k_len] = v_varlen[k_start:k_end]
k_padded_list.append(k_padded)
v_padded_list.append(v_padded)
# Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
# Expand q to match kv heads: [b, q_heads, 1, head_size]
q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size]
print(f"q_expanded shape: {q_expanded.shape}")
print(f"k_padded_batched shape: {k_padded_batched.shape}")
print(f"v_padded_batched shape: {v_padded_batched.shape}")
# Compute torch reference
k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
if sink is None:
# Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for i in range(batch_size):
actual_k_len = k_seqlens[i]
attn_score[i, :, :, actual_k_len:] = float("-inf")
attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen]
# Mask out invalid positions
for i in range(batch_size):
actual_k_len = k_seqlens[i]
attn_weights[i, :, :, actual_k_len:] = 0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size]
else:
# s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for i in range(batch_size):
actual_k_len = k_seqlens[i]
logits[i, :, :, actual_k_len:] = float("-inf")
sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(logits_max, sink_expanded)
sinks = torch.exp(sink_expanded - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
attn_weights = unnormalized_scores / normalizer
# Mask out invalid positions
for i in range(batch_size):
actual_k_len = k_seqlens[i]
attn_weights[i, :, :, actual_k_len:] = 0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size]
O_torch = O_torch.squeeze(2) # [b, q_heads, head_size]
# Compute attention score pooling for S
attn_score_pooled = torch.max_pool2d(
attn_weights.squeeze(2), # [b, q_heads, max_seqlen]
kernel_size=(q_heads, block_size),
stride=(q_heads, block_size),
ceil_mode=True,
).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)]
print(f"O_triton shape: {O_triton.shape}")
print(f"O_tilelang shape: {O_tilelang.shape}")
print(f"O_torch shape: {O_torch.shape}")
print(f"S_triton shape: {S_triton.shape}")
print(f"S_tilelang shape: {S_tilelang.shape}")
print(f"attn_score_pooled shape: {attn_score_pooled.shape}")
# Compare results
max_diff_o = torch.max(torch.abs(O_triton - O_torch))
max_diff_o_tl = torch.max(torch.abs(O_tilelang - O_torch))
print(f"Max difference in O: {max_diff_o.item()}")
print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}")
max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}")
assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}"
assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), (
f"Score mismatch: {max_diff_s_tl.item()}"
)
print("✅ All tests passed!")
def speed_benchmark_decode_comparison(args):
"""Speed benchmark for decode kernel"""
batch_size = args.batch_size
q_heads = args.q_heads
kv_heads = args.kv_heads
max_k_seqlen = args.k_seqlen
real_max_k_seqlen = args.k_seqlen
head_size = args.head_size
block_size = args.block_size
page_block_size = args.page_block_size
dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
print("\n=== Decode Speed Benchmark Comparison ===")
print("Configuration:")
print(f" Batch size: {batch_size}")
print(f" Q heads: {q_heads}, KV heads: {kv_heads}")
print(f" Max K sequence length: {max_k_seqlen}")
print(f" Head size: {head_size}")
print(f" Block size: {block_size}")
print(f" Data type: {dtype}")
print(f" Variable lengths: {args.test_varlen}")
print(f" s_aux attention: {args.test_sink}")
print()
# Generate input data
if args.test_varlen:
k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,))
else:
k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=torch.int64)
# Generate cumulative sequence lengths for k
cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
total_k_tokens = 0
for i in range(batch_size):
cu_seqlens_k[i] = total_k_tokens
total_k_tokens += k_seqlens[i]
cu_seqlens_k[batch_size] = total_k_tokens
# Generate tensors
q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)
max_seqlen_k = int(k_seqlens.max())
# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(" Using sink attention with sink values")
print("Setup complete:")
print(f" Total K tokens: {total_k_tokens}")
print(f" Actual max K seq len: {max_seqlen_k}")
if args.test_varlen:
print(f" K sequence lengths: {k_seqlens.tolist()}")
# Warmup
num_tokens, q_h, head_size = q_decode.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
block_cnt = 0
for i in range(batch):
cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
for j in range(math.ceil(cur_seqlen / page_block_size)):
block_table[i, j] = block_cnt
block_cnt += 1
block_cnt = 0
# Benchmark
print("⚡ Benchmarking Tilelang kernel (100 iterations)...")
tilelang_time = do_bench(
flash_attn_with_attn_pool_decode_tilelang,
q_decode,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
args.k_seqlen,
1,
softmax_scale,
sink,
block_size,
False,
tl_kernel,
block_table,
)
print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms")
# Benchmark
print("⚡ Benchmarking Triton kernel (100 iterations)...")
triton_time = do_bench(
flash_attn_with_attn_pool_decode,
q_decode,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
args.k_seqlen,
1,
softmax_scale,
sink,
block_size,
)
print(f"Average decode kernel time Triton: {triton_time:.3f} ms")
print(f"Speedup: {(triton_time / tilelang_time):.3f}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads")
parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads")
parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length")
parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension")
parser.add_argument("--block_size", type=int, default=128, help="Block size for computation")
parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type")
parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths")
parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism")
parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark")
parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits")
parser.add_argument("--page_block_size", type=int, default=128, help="Page block size")
args = parser.parse_args()
args.test_sink = True
args.test_varlen = True
args.dtype = T.float16
args.num_split = 1
if args.benchmark:
speed_benchmark_decode_comparison(args)
elif args.test_varlen:
test_varlen_decode_main(args)
else:
test_equal_seqlen_decode_main(args)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from functools import partial
num_split = 4
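# Split-KV decoding: flash_attn_split computes partial outputs over num_split KV chunks; combine() merges them using per-split LSEs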
@tilelang.jit(out_idx=[5])
def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape_q = [batch, seqlen_q, heads, dim]
shape_kv = [batch, seqlen_kv, heads, dim]
part_shape = [batch, seqlen_q, heads, num_split, dim]
dtype = T.float16
accum_dtype = T.float32
@T.macro
def MMA0(
K: T.Tensor(shape_kv, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
mid: T.int32,
hid: T.int32,
bid: T.int32,
sid: T.int32,
):
T.copy(K[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], K_shared)
# TODO: Handle causal split case
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(shape_kv, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
hid: T.int32,
bid: T.int32,
sid: T.int32,
):
T.copy(V[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in the FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.macro
def flash_attn_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_kv, dtype),
V: T.Tensor(shape_kv, dtype),
glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
Output_partial: T.Tensor(part_shape, dtype),
):
with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128) as (bx, by, bz):
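# Grid: (query tiles, batch * heads, KV splits); bz selects which 1/num_split slice of the KV sequence this block handles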
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
mid = bx
hid = by % heads
bid = by // heads
sid = bz
# NOTE(wt): the TMA barrier currently has issues with padded dimensions (seq_q here),
# so disable the relevant TMA copies and fall back to SIMT for now
T.copy(Q[bid, mid * block_M : (mid + 1) * block_M, hid, :], Q_shared, disable_tma=True)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
# TODO: Handle causal split case
loop_range = (
T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv((mid + 1) * block_M, block_N))
if is_causal
else T.ceildiv((seqlen_kv // num_split), block_N)
)
for k in T.Pipelined(loop_range, num_stages=2):
MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
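# Store the per-split log2-sum-exp so the combine kernel can rescale the partial outputs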
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bid, hid, sid, mid * block_M : (mid + 1) * block_M])
T.copy(acc_o, O_shared)
T.copy(O_shared, Output_partial[bid, mid * block_M : (mid + 1) * block_M, hid, sid, :], disable_tma=True)
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_q, dtype),
):
with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz):
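# Each block merges the num_split partial results for one (query tile, head, batch)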
po_local = T.alloc_fragment([block_M, dim], dtype)
po_shared = T.alloc_shared([block_M, dim], dtype)
o_accum_local = T.alloc_fragment([block_M, dim], accum_dtype)
o_shared = T.alloc_shared([block_M, dim], dtype)
lse_local = T.alloc_fragment([num_split, block_M], dtype)
lse_local_split = T.alloc_fragment([block_M], accum_dtype)
lse_logsum_local = T.alloc_fragment([block_M], accum_dtype)
lse_max_local = T.alloc_fragment([block_M], accum_dtype)
scale_local = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout(
{
o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i),
o_shared: tilelang.layout.make_swizzled_layout(o_shared),
po_shared: tilelang.layout.make_swizzled_layout(po_shared),
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
T.copy(
glse[
bz,
by,
:,
bx * block_M : (bx + 1) * block_M,
],
lse_local,
)
T.reduce_max(lse_local, lse_max_local, dim=0, clear=False)
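# lse_logsum = log2(sum_k 2^(lse_k - lse_max)) + lse_max, computed for numerical stability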
for k in T.Pipelined(num_split):
T.copy(lse_local[k, :], lse_local_split)
for i in T.Parallel(block_M):
lse_logsum_local[i] += T.exp2(lse_local_split[i] - lse_max_local[i])
for i in T.Parallel(block_M):
lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i]
for k in T.Pipelined(num_split, num_stages=2):
T.copy(Output_partial[bz, bx * block_M : (bx + 1) * block_M, by, k, :], po_shared, disable_tma=True)
T.copy(po_shared, po_local)
for i in T.Parallel(block_M):
lse_local_split[i] = lse_local[k, i]
for i in T.Parallel(block_M):
scale_local[i] = T.exp2(lse_local_split[i] - lse_logsum_local[i])
for i, j in T.Parallel(block_M, dim):
o_accum_local[i, j] += po_local[i, j] * scale_local[i]
T.copy(o_accum_local, o_shared)
T.copy(o_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :], disable_tma=True)
@T.prim_func
def flashattn_mha_inference(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_kv, dtype),
V: T.Tensor(shape_kv, dtype),
glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim]
Output: T.Tensor(shape_q, dtype),
):
flash_attn_split(Q, K, V, glse, Output_partial)
combine(glse, Output_partial, Output)
return flashattn_mha_inference
def ref_program(Q, K, V, glse, Output_partial, causal):
assert causal is False
dim = Q.size(-1)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output
def reduce_ref(Q, K, V, glse, Output_partial, causal):
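# Torch reference for the combine step: rescale each split's partial output by exp2(lse_k - lse_logsum) and sum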
o = torch.empty_like(Output_partial[:, :, :, 0, :]).fill_(0)
lse_logsum = torch.empty_like(glse[:, :, 0, :]).fill_(0) # [batch, heads, seqlen_q]
lse_max = glse.max(dim=2, keepdim=False).values
for ks in range(num_split):
lse = glse[:, :, ks, :]
lse_logsum += torch.exp2(lse - lse_max)
lse_logsum = torch.log2(lse_logsum) + lse_max
for ks in range(num_split):
lse = glse[:, :, ks, :]
scale = torch.exp2(lse - lse_logsum) # [batch, heads, seqlen_q]
o += Output_partial[:, :, :, ks, :] * scale[:, :, :, None].transpose(1, 2)
return o.to(torch.float16)
def flash_split_ref(Q, K, V, causal):
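# Torch reference for the split pass: returns per-split LSEs and partial outputs (both cast to float16)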
# [batch, seqlen_q, heads, dim]
batch = Q.size(0)
block_M = Q.size(1)
nheads = Q.size(2)
dim = Q.size(3)
block_N = 128
seqlen_kv = K.size(1)
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float)
acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16)
acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float)
scores_max = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
scores_max_prev = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
scores_scale = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
scores_sum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
logsum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
gacc_o = torch.empty((num_split, batch, block_M, nheads, dim), device="cuda", dtype=torch.float)
glogsum = torch.empty((num_split, batch, nheads, block_M), device="cuda", dtype=torch.float)
Q_ = Q * scale
for ks in range(num_split):
acc_o.fill_(0)
logsum.fill_(0)
scores_max.fill_(float("-inf"))
scores_max_prev.fill_(float("-inf"))
for i in range(int((seqlen_kv // num_split) / block_N)):
acc_s.fill_(0)
acc_s = torch.einsum(
"bqhd,bkhd->bhqk",
Q_,
K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
) # [batch, nheads, block_M, block_N]
scores_max_prev = scores_max
scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads, block_M]
scores_scale = torch.exp2(scores_max_prev - scores_max)
acc_o *= scores_scale[:, :, :, None].transpose(1, 2)
acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
acc_s_cast = acc_s.to(torch.float16)
acc_o += torch.einsum(
"bhqk,bkhd->bqhd",
acc_s_cast,
V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
)
scores_sum = acc_s.sum(dim=-1, keepdim=False)
logsum = logsum * scores_scale + scores_sum
acc_o /= logsum[:, :, :, None].transpose(1, 2)
logsum = torch.log2(logsum) + scores_max
gacc_o[ks, :, :, :, :] = acc_o
glogsum[ks, :, :, :] = logsum
return glogsum.to(torch.float16).permute(1, 2, 0, 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)
def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False):
flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
total_flops = 2 * flops_per_matmul
if causal:
total_flops *= 0.5
BLOCK_M = 128
BLOCK_N = 64 # if D_HEAD <= 128 else 32
kernel = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N)
ref_fn = partial(ref_program, causal=causal)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_fn, rtol=0.01, atol=0.01)
print("All checks passed!")
latency = profiler.do_bench(ref_fn, warmup=500)
print("{:.2f} ms".format(latency))
print("{:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(n_warmup=10, n_repeat=10)
print("{:.4f} ms".format(latency))
print("{:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
main()
import tilelang.testing
import example_gqa_decode
import example_mha_inference
# TODO(lei): fix the correctness of gqa decode on sm90
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_example_example_gqa_decode():
example_gqa_decode.main()
def test_example_example_mha_inference():
example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False)
if __name__ == "__main__":
tilelang.testing.main()
import math
import torch
import torch.nn as nn
from typing import Dict, Tuple, Optional
import tilelang
import tilelang.language as T
from tilelang.autotuner import *
from example_fusedmoe_torch import *
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def moe_forward_tilelang_shared(
d_hidden,
d_expert,
n_shared_experts,
dtype,
num_tokens,
block_token=128,
block_dhidden=128,
block_dexpert=128,
threads=256,
num_stages=1,
):
scale = 1.44269504 # log2(e)
# Parameters
dhidden = d_hidden
dexpert = d_expert * n_shared_experts
# Tensors: Note that the input is reshaped to (num_tokens, dhidden)
input_shape = (num_tokens, dhidden)
shared_W_gate_shape = (dexpert, dhidden)
shared_W_up_shape = (dexpert, dhidden)
shared_W_down_shape = (dhidden, dexpert)
accum_type = T.float32
@T.prim_func
def kernel_shared(
input: T.Tensor(input_shape, dtype), # type: ignore
shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore
shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore
shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore
up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore
output: T.Tensor(input_shape, dtype), # type: ignore
):
# Step 1: Compute gate and up logits
with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by):
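# Step 1 grid: (token tiles, expert-channel tiles); the gate/up GEMMs accumulate over dhidden in block_dhidden chunks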
# Split the block to shared experts and routed experts
input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype)
W_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
W_up_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
# Shared experts: no need to check expert_indices
gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_type)
up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_type)
T.use_swizzle(10)
T.clear(gate_logits_local)
T.clear(up_logits_local)
# Parallel for gate and up matmul
for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages):
T.copy(input[bx * block_token, k * block_dhidden], input_shared)
T.copy(shared_W_gate[by * block_dexpert, k * block_dhidden], W_gate_shared)
T.copy(shared_W_up[by * block_dexpert, k * block_dhidden], W_up_shared)
T.gemm(input_shared, W_gate_shared, gate_logits_local, transpose_B=True)
T.gemm(input_shared, W_up_shared, up_logits_local, transpose_B=True)
# Fuse with SiLU and element-wise product
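# SiLU(x) = x * sigmoid(x); sigmoid is evaluated as 1 / (1 + exp2(-x * log2(e)))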
for i, j in T.Parallel(block_token, block_dexpert):
gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]
T.copy(up_logits_local, up_logits[bx * block_token, by * block_dexpert])
# Step 2: Compute down logits
with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by):
up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype)
W_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype)
output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_type)
T.use_swizzle(10)
T.clear(output_local)
for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages):
T.copy(up_logits[bx * block_token, k * block_dexpert], up_logits_shared)
T.copy(shared_W_down[by * block_dhidden, k * block_dexpert], W_down_shared)
T.gemm(up_logits_shared, W_down_shared, output_local, transpose_B=True)
T.copy(output_local, output[bx * block_token, by * block_dhidden])
return kernel_shared
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def moe_forward_tilelang_routed(
d_hidden,
d_expert,
n_routed_experts,
dtype,
group_sum,
group_count,
block_token=128,
block_dhidden=128,
block_dexpert=128,
threads=256,
num_stages=1,
k_pack=1,
coalesced_width=None,
):
scale = 1.44269504 # log2(e)
# Parameters
dhidden = d_hidden
dexpert = d_expert
# Group info
# group_sum = sum(group_sizes_list)
# group_count = len(group_sizes_list)
# M = sum([(group_size + block_token - 1) // block_token for group_size in group_sizes_list])
M = math.ceil(group_sum / block_token) + group_count
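# Upper bound on the number of token tiles: every group adds at most one extra, partially filled tile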
accum_dtype = T.float32
# Tensors: Note that the input is reshaped to (bs * seq_len * n_experts_per_token, dhidden) for the grouped GEMM
input_shape = (group_sum, dhidden)
intermediate_shape = (group_sum, dexpert)
routed_expert_gate_shape = (n_routed_experts, dexpert, dhidden)
routed_expert_up_shape = (n_routed_experts, dexpert, dhidden)
routed_expert_down_shape = (n_routed_experts, dhidden, dexpert)
routed_expert_weights_shape = group_sum
group_sizes_shape = n_routed_experts
@T.prim_func
def kernel(
input: T.Tensor(input_shape, dtype), # type: ignore
routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore
routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore
routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore
routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore
group_sizes: T.Tensor(group_sizes_shape, T.int32), # type: ignore
group_offsets: T.Tensor(group_sizes_shape, T.int32), # type: ignore
group_padded_offsets: T.Tensor(group_sizes_shape, T.int32), # type: ignore
group_idx_for_bx: T.Tensor((M,), T.int32), # type: ignore
up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore
output: T.Tensor(input_shape, dtype), # type: ignore
):
# Step 1: Compute gate and up logits
with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by):
input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype)
routed_expert_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
routed_expert_up_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype)
up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype)
cur_group_idx = T.alloc_local([1], T.int32)
cur_group_size = T.alloc_local([1], T.int32)
T.use_swizzle(10, enable=True)
m_start_padded = bx * block_token
cur_group_idx[0] = group_idx_for_bx[bx]
cur_group_size[0] = group_sizes[cur_group_idx[0]]
m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]]
actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]])))
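            # m_start maps the padded block start back to the group's real row offset;
            # actual_rows is the number of valid rows in this block (the rest are padding).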
T.clear(gate_logits_local)
T.clear(up_logits_local)
for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages):
T.copy(
input[m_start : m_start + block_token, k * block_dhidden : (k + 1) * block_dhidden],
input_shared,
coalesced_width=coalesced_width,
)
T.copy(
routed_expert_gate[
cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden
],
routed_expert_gate_shared,
coalesced_width=coalesced_width,
)
T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, k_pack=k_pack, transpose_B=True)
T.copy(
routed_expert_up[
cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden
],
routed_expert_up_shared,
coalesced_width=coalesced_width,
)
T.gemm(input_shared, routed_expert_up_shared, up_logits_local, k_pack=k_pack, transpose_B=True)
for i, j in T.Parallel(block_token, block_dexpert):
gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]
for i, j in T.Parallel(block_token, block_dexpert):
if i < actual_rows:
up_logits[m_start + i, by * block_dexpert + j] = up_logits_local[i, j]
# Step 2: Compute down logits
with T.Kernel(M, T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by):
up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype)
routed_expert_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype)
output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_dtype)
cur_group_idx = T.alloc_local([1], T.int32)
cur_group_size = T.alloc_local([1], T.int32)
T.use_swizzle(10, enable=True)
m_start_padded = bx * block_token
cur_group_idx[0] = group_idx_for_bx[bx]
cur_group_size[0] = group_sizes[cur_group_idx[0]]
m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]]
actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]])))
T.clear(output_local)
for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages):
T.copy(
up_logits[m_start : m_start + block_token, k * block_dexpert : (k + 1) * block_dexpert],
up_logits_shared,
coalesced_width=coalesced_width,
)
T.copy(
routed_expert_down[
cur_group_idx[0], by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert
],
routed_expert_down_shared,
coalesced_width=coalesced_width,
)
T.gemm(up_logits_shared, routed_expert_down_shared, output_local, k_pack=k_pack, transpose_B=True)
for i, j in T.Parallel(block_token, block_dhidden):
if i < actual_rows:
output[m_start + i, by * block_dhidden + j] = output_local[i, j] * routed_expert_weights[m_start + i]
return kernel
class Expert(nn.Module):
def __init__(self, config: Dict, gate: torch.Tensor, up: torch.Tensor, down: torch.Tensor, d_expert: Optional[int] = None):
super().__init__()
self.config = config
self.act_fn = nn.SiLU()
self.d_hidden: int = config["d_hidden"]
self.d_expert: int = config["d_expert"] if d_expert is None else d_expert
self.device = torch.device("cuda")
self.W_gate_weight = gate.t().contiguous().to(self.device)
self.W_up_weight = up.t().contiguous().to(self.device)
self.W_down_weight = down.t().contiguous().to(self.device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = self.act_fn(x @ self.W_gate_weight)
out = (gate * (x @ self.W_up_weight)) @ self.W_down_weight
return out
class MoEGate(nn.Module):
def __init__(self, config: Dict, weights: Dict):
super().__init__()
self.top_k: int = config["n_experts_per_token"]
self.num_experts: int = config["n_routed_experts"]
self.d_hidden: int = config["d_hidden"]
self.W_g_weight = weights["router.weight"].t()
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
logits = x @ self.W_g_weight
scores = logits.softmax(dim=-1)
topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
return topk_indices, topk_scores
class MoE(nn.Module):
def __init__(
self, config: Dict, shared_kernel: tilelang.JITKernel, routed_kernel: tilelang.JITKernel, weights: Dict, padding_M: int = 128
):
super().__init__()
self.config = config
self.shared_kernel = shared_kernel
self.routed_kernel = routed_kernel
self.padding_M = padding_M
self.experts = nn.ModuleList(
[
Expert(
config,
gate=weights[f"experts.{i}.0.weight"],
up=weights[f"experts.{i}.1.weight"],
down=weights[f"experts.{i}.2.weight"],
)
for i in range(config["n_routed_experts"])
]
)
self.device = torch.device("cuda")
self.gating_network = MoEGate(config, weights).to(self.device)
shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
self.shared_expert = Expert(
config=config,
gate=weights["shared_experts.0.weight"],
up=weights["shared_experts.1.weight"],
down=weights["shared_experts.2.weight"],
d_expert=shared_expert_dim,
).to(self.device)
self.expert_cache = torch.zeros(
(config["batch_size"] * config["seq_len"], config["d_hidden"]), dtype=torch.float16, device=self.device
)
self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], dim=0)
self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], dim=0)
self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts], dim=0)
self.stacked_expert_tokens = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]),
dtype=torch.float16,
device=self.device,
)
self.stacked_expert_weights = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.float16, device=self.device
)
self.stacked_expert_tokens_idxs = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.int64, device=self.device
)
self.up_logits_shared = torch.empty(
(config["batch_size"] * config["seq_len"], self.config["d_expert"]), dtype=torch.float16, device=self.device
)
self.expert_output_shared = torch.empty(
(config["batch_size"] * config["seq_len"], self.config["d_hidden"]), dtype=torch.float16, device=self.device
)
self.up_logits_routed = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_expert"]),
dtype=torch.float16,
device=self.device,
)
self.expert_output_routed = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]),
dtype=torch.float16,
device=self.device,
)
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_shape = x.shape
batch_size, seq_len, hidden_dim = x.shape
expert_indices, expert_scores = self.gating_network(x)
flat_expert_indices = expert_indices.view(-1)
flat_expert_weights = expert_scores.view(-1)
x_flat = x.view(-1, hidden_dim)
# Prepare for grouped GEMM
idxs = flat_expert_indices.argsort()
counts = flat_expert_indices.bincount().cpu().numpy()
# counts = flat_expert_indices.bincount()
tokens_per_expert = counts.cumsum()
# tokens_per_expert = torch.cumsum(counts, dim=0)
num_per_tok = self.config["n_experts_per_token"]
token_idxs = idxs // num_per_tok
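        # flat_expert_indices holds n_experts_per_token entries per token, so dividing a
        # flat assignment index by num_per_tok recovers the token it belongs to.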
# Get stacked expert tokens and expert weights
for expert_id, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1]
if start_idx == end_idx:
continue
exp_token_idxs = token_idxs[start_idx:end_idx]
expert_tokens = x_flat[exp_token_idxs]
self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens
self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs
self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[idxs[start_idx:end_idx]]
group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device)
group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=self.device)
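        # Pad each group's start offset up to a multiple of padding_M so that every
        # expert's rows begin at a block-aligned position for the grouped GEMM kernel.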
group_padded_offsets = [0 for _ in range(len(group_sizes))]
for i in range(1, len(group_sizes)):
group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / self.padding_M) * self.padding_M
block_token = 128
M = (
math.ceil(self.config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token)
+ self.config["n_routed_experts"]
)
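        # For every row block bx, precompute which expert group its padded range belongs
        # to, so the kernel can look it up instead of searching the offsets on device.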
group_idx_for_bx = [0 for _ in range(M)]
for bx in range(M):
m_start_padded = bx * block_token
for i in range(self.config["n_routed_experts"]):
if m_start_padded >= group_padded_offsets[i]:
group_idx_for_bx[bx] = i
group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=self.device)
group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device)
# Multi-stream execution
shared_stream = torch.cuda.Stream()
routed_stream = torch.cuda.default_stream()
torch.cuda.synchronize()
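        # Routed experts run on the default stream while the shared expert runs on a side
        # stream, so the two kernels can overlap; synchronize before summing the outputs.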
with torch.cuda.stream(routed_stream):
# Tilelang version: Grouped GEMM
self.routed_kernel(
self.stacked_expert_tokens,
self.stacked_expert_w_gate,
self.stacked_expert_w_up,
self.stacked_expert_w_down,
self.stacked_expert_weights,
group_sizes,
group_offset,
group_padded_offsets,
group_idx_for_bx,
self.up_logits_routed,
self.expert_output_routed,
)
# Scatter reduce
self.expert_cache = torch.scatter_reduce(
self.expert_cache,
0,
self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]),
self.expert_output_routed,
reduce="sum",
)
routed_output = self.expert_cache.view(*orig_shape)
with torch.cuda.stream(shared_stream):
self.shared_kernel(
x_flat,
self.shared_expert.W_gate_weight,
self.shared_expert.W_up_weight,
self.shared_expert.W_down_weight,
self.up_logits_shared,
self.expert_output_shared,
)
shared_output = self.expert_output_shared.view(*orig_shape)
torch.cuda.synchronize()
return shared_output + routed_output
def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
"""
DeepSeek-style Mixture of Experts using Tilelang.
Args:
data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
- input: Input tensor of shape [batch_size, seq_len, hidden_size]
- weights: Dictionary containing model weights
- config: Dictionary containing model configuration parameters
    Returns:
        output: Processed tensor of shape [batch_size, seq_len, d_hidden]
"""
input_tensor, weights, config = data
dtype_str = T.float16
shared_kernel = moe_forward_tilelang_shared(
config["d_hidden"],
config["d_expert"],
config["n_shared_experts"],
dtype=dtype_str,
num_tokens=config["batch_size"] * config["seq_len"],
)
routed_kernel = moe_forward_tilelang_routed(
config["d_hidden"],
config["d_expert"],
config["n_routed_experts"],
dtype=dtype_str,
group_sum=config["batch_size"] * config["seq_len"] * config["n_experts_per_token"],
group_count=config["n_routed_experts"],
block_token=128,
block_dhidden=128,
block_dexpert=128,
threads=256,
num_stages=1,
k_pack=1,
coalesced_width=2,
)
moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128)
output = moe(input_tensor)
return output
def main(d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=8192):
config = {
"dhidden": d_hidden,
"dexpert": d_expert,
"nroutedexperts": n_routed_experts,
"nsharedexperts": n_shared_experts,
"nexpertspertoken": n_experts_per_token,
"bs": batch_size,
"seqlen": seq_len,
"seed": 81394,
}
data = generate_input(**config)
torch.cuda.synchronize()
ref_output = ref_kernel(clone_data(data)).to(torch.float32)
torch.cuda.synchronize()
tilelang_output = custom_kernel(clone_data(data)).to(torch.float32)
torch.cuda.synchronize()
torch.testing.assert_close(ref_output, tilelang_output, atol=1e-2, rtol=1e-2)
print("✅ Tilelang and Torch match")
if __name__ == "__main__":
main()
import math
import torch
import torch.nn as nn
from typing import Dict, Tuple, Optional
# Reference code in PyTorch
class ExpertTorch(nn.Module):
def __init__(self, config: Dict, d_expert: Optional[int] = None):
super().__init__()
self.config = config
self.act_fn = nn.SiLU()
self.d_hidden: int = config["d_hidden"]
self.d_expert: int = config["d_expert"] if d_expert is None else d_expert
self.W_gate = nn.Linear(self.d_hidden, self.d_expert, bias=False)
self.W_up = nn.Linear(self.d_hidden, self.d_expert, bias=False)
self.W_down = nn.Linear(self.d_expert, self.d_hidden, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = self.act_fn(self.W_gate(x))
out = self.W_down(gate * self.W_up(x))
return out
class MoEGateTorch(nn.Module):
def __init__(self, config: Dict):
super().__init__()
self.top_k: int = config["n_experts_per_token"]
self.num_experts: int = config["n_routed_experts"]
self.d_hidden: int = config["d_hidden"]
self.W_g = nn.Linear(self.d_hidden, self.num_experts, bias=False)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
logits = self.W_g(x)
scores = logits.softmax(dim=-1)
topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
return topk_indices, topk_scores
class MoETorch(nn.Module):
def __init__(self, config: Dict):
super().__init__()
self.config = config
self.experts = nn.ModuleList([ExpertTorch(config) for _ in range(config["n_routed_experts"])])
self.gating_network = MoEGateTorch(config)
shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
self.shared_expert = ExpertTorch(config=config, d_expert=shared_expert_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
shared_output = self.shared_expert(x)
expert_indices, expert_scores = self.gating_network(x)
batch_size, seq_len, hidden_dim = x.shape
orig_shape = x.shape
x_flat = x.view(-1, hidden_dim)
flat_expert_indices = expert_indices.view(-1)
flat_expert_weights = expert_scores.view(-1, 1)
routed_output_flat = self.moe_infer(x_flat, flat_expert_indices, flat_expert_weights)
routed_output = routed_output_flat.view(*orig_shape)
return routed_output + shared_output
@torch.no_grad()
def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, flat_expert_weights: torch.Tensor) -> torch.Tensor:
expert_cache = torch.zeros_like(x)
# test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
# test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
# test_expert_ups = torch.zeros((self.config["n_routed_experts"], self.config["d_hidden"], self.config["d_expert"]))
# test_expert_tokens_num = torch.zeros((self.config["n_routed_experts"]))
idxs = flat_expert_indices.argsort()
counts = flat_expert_indices.bincount().cpu().numpy()
tokens_per_expert = counts.cumsum()
num_per_tok = self.config["n_experts_per_token"]
token_idxs = idxs // num_per_tok
for expert_id, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1]
if start_idx == end_idx:
continue
expert = self.experts[expert_id]
exp_token_idxs = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idxs]
expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_reduce_(0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum")
return expert_cache
def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
"""
Reference implementation of DeepSeek-style Mixture of Experts using PyTorch.
Args:
data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
- input: Input tensor of shape [batch_size, seq_len, hidden_dim]
- weights: Dictionary containing model weights
- config: Dictionary containing model configuration parameters
    Returns:
        output: Processed tensor of shape [batch_size, seq_len, d_hidden]
"""
input_tensor, weights, config = data
num_experts = config["n_routed_experts"]
moe = MoETorch(config)
# Fill in the given weights of the model
moe.gating_network.W_g.weight = nn.Parameter(weights["router.weight"])
for i in range(num_experts):
gate_proj_weight = weights[f"experts.{i}.0.weight"]
up_proj_weight = weights[f"experts.{i}.1.weight"]
down_proj_weight = weights[f"experts.{i}.2.weight"]
# Transpose weights to match expected shape for nn.Linear
moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t())
moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t())
moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t())
moe.shared_expert.W_gate.weight = nn.Parameter(weights["shared_experts.0.weight"].t())
moe.shared_expert.W_up.weight = nn.Parameter(weights["shared_experts.1.weight"].t())
moe.shared_expert.W_down.weight = nn.Parameter(weights["shared_experts.2.weight"].t())
output = moe(input_tensor)
return output
# Input generation for the reference code
def generate_input(
dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, nexpertspertoken: int, bs: int, seqlen: int, seed: int
) -> Tuple[torch.Tensor, Dict, Dict]:
    # Workaround: parameter names with underscores are not parsed correctly for now,
    # so remap the flattened names back to the snake_case config keys here.
d_hidden = dhidden
d_expert = dexpert
n_routed_experts = nroutedexperts
n_shared_experts = nsharedexperts
n_experts_per_token = nexpertspertoken
batch_size = bs
seq_len = seqlen
config = {
"d_hidden": d_hidden,
"d_expert": d_expert,
"n_routed_experts": n_routed_experts,
"n_shared_experts": n_shared_experts,
"n_experts_per_token": n_experts_per_token,
"batch_size": batch_size,
"seq_len": seq_len,
}
gen = torch.Generator(device="cuda")
gen.manual_seed(seed)
num_experts = n_routed_experts
expert_dim = d_expert
weights = {}
input_tensor = torch.randn((batch_size, seq_len, d_hidden), device="cuda", dtype=torch.float16, generator=gen).contiguous()
# Initialize router weights
weights["router.weight"] = torch.randn((num_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(d_hidden)
for i in range(num_experts):
weights[f"experts.{i}.0.weight"] = torch.randn(
(d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(expert_dim)
weights[f"experts.{i}.1.weight"] = torch.randn(
(d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(expert_dim)
weights[f"experts.{i}.2.weight"] = torch.randn(
(expert_dim, d_hidden), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(d_hidden)
weights["shared_experts.0.weight"] = torch.randn(
(d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(expert_dim * n_shared_experts)
weights["shared_experts.1.weight"] = torch.randn(
(d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(expert_dim * n_shared_experts)
weights["shared_experts.2.weight"] = torch.randn(
(expert_dim * n_shared_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(d_hidden)
return (input_tensor, weights, config)
def clone_data(data):
"""
Recursively goes through data and clones all tensors.
"""
if isinstance(data, tuple):
return tuple(clone_data(x) for x in data)
elif isinstance(data, list):
return [clone_data(x) for x in data]
elif isinstance(data, dict):
return {k: clone_data(v) for k, v in data.items()}
elif isinstance(data, torch.Tensor):
return data.clone()
else:
return data
import tilelang.testing
import example_fusedmoe_tilelang
def test_example_fusedmoe_tilelang():
example_fusedmoe_tilelang.main(
d_hidden=1024, d_expert=256, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=1024
)
if __name__ == "__main__":
tilelang.testing.main()
# Gated Delta Net (GDN) kernel implementation with TileLang
## Requirements
- TileLang: `0.1.5+17fafc1b3026d910a83eb8052fdf811ba56be0b1`
- Triton: `3.3.0` (used for comparison)
- FLA: commit `f03cb3ae` (used for comparison)
## Get started
The [chunk_delta_h](common/chunk_delta_h.py) file implements the most critical forward kernel of GDN and is a good starting point for understanding the GDN logic and the TileLang optimizations.
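To try it out on a CUDA-capable GPU with the versions pinned above (FLA provides the comparison baseline), a minimal invocation is `python common/chunk_delta_h.py`; the script checks the TileLang kernel against the FLA reference and prints the timing of both.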
## Acknowledgments
This kernel was developed by Yu Cheng and Zhengju Tang following in-depth discussions with Xiaomi's LLM-Core Team (MiMo).
# Reference: fla/ops/common/chunk_delta_h.py
import sys # noqa: F401
import tilelang
import tilelang.language as T
print(tilelang.__file__, flush=True)
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__, flush=True)
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
import torch.nn.functional as F
torch.random.manual_seed(0)
# torch.set_printoptions(profile="full")
from test_utils import assert_similar
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = F.normalize(K, dim=-1, p=2)
W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
    # Note: G should be in log space and accumulated with a chunk-wise cumulative sum
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
G = F.logsigmoid(G)
try:
from fla.ops.utils.cumsum import chunk_local_cumsum
G = chunk_local_cumsum(G, chunk_size)
except ImportError:
print("fla not found, skip cumsum")
h0 = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda()
dht = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda()
dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
dv = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
return Q, K, W, G, h0, dht, dO, dv
def prepare_input_fake(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
G = torch.ones(B, S, H, dtype=gate_dtype).cuda()
h0 = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda()
dht = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda()
dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
dv = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
return Q, K, W, G, h0, dht, dO, dv
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
gate_dtype,
state_dtype,
):
BS = S // chunk_size
dh = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda()
dh0 = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda()
dv2 = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
return dh, dh0, dv2
def torch_chunk_gated_delta_rule_bwd_dhu(
Q: torch.Tensor,
K: torch.Tensor,
W: torch.Tensor,
G: torch.Tensor,
h0: torch.Tensor,
dht: torch.Tensor,
dO: torch.Tensor,
dv: torch.Tensor,
scale: float,
use_g: bool,
use_initial_state: bool,
use_final_state_gradient: bool,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
B, S, H, DK = Q.shape
DV = dv.shape[-1]
block_S = 64
BS = S // block_S
dh, dh0, dv2 = (
torch.empty((B, BS, H, DK, DV), dtype=output_dtype),
torch.empty((B, H, DK, DV), dtype=state_dtype),
torch.empty((B, S, H, DV), dtype=output_dtype),
)
dh_tmp = torch.empty((B, H, DK, DV), dtype=accum_dtype)
dv_tmp = torch.empty((B, S, H, DV), dtype=accum_dtype)
Q_tmp = torch.empty((B, S, H, DK), dtype=accum_dtype)
if use_final_state_gradient:
dh_tmp = dht.clone().to(accum_dtype)
else:
dh_tmp = torch.zeros_like(dht).to(accum_dtype)
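    # Reverse-chunk recurrence (mirrors the TileLang kernel below; exp factors apply only when use_g):
    #   dv2 = dv + exp(g_last - g) * (K @ dh)                 (zeroed where g > g_last)
    #   dh  = exp(g_last) * dh + (Q * exp(g) * scale)^T @ dO - W^T @ dv2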
for i_s in range(BS - 1, -1, -1):
dh[:, i_s, :, :, :] = dh_tmp
dv_tmp = torch.matmul(K[:, i_s * block_S : (i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), dh_tmp.to(K.dtype)).permute(0, 2, 1, 3)
if use_g:
for i_bh in range(B * H):
i_b, i_h = i_bh // H, i_bh % H
for i_s2 in range(block_S):
if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h] <= 0:
dv_tmp[i_b, i_s2, i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h])
else:
dv_tmp[i_b, i_s2, i_h, :] = 0
dv_tmp += dv[:, i_s * block_S : (i_s + 1) * block_S, :, :]
dv2[:, i_s * block_S : (i_s + 1) * block_S, :, :] = dv_tmp
if use_g:
G_last = G[:, i_s * block_S + block_S - 1, :]
for i_bh in range(B * H):
i_b, i_h = i_bh // H, i_bh % H
dh_tmp[i_b, i_h, :, :] *= torch.exp(G_last[i_b, i_h])
        Q_tmp = Q[:, i_s * block_S : (i_s + 1) * block_S, :, :].clone()  # clone so the in-place scaling below does not mutate Q
for i_s2 in range(block_S):
for i_k in range(DK):
Q_tmp[:, i_s2, :, i_k] *= torch.exp(G[:, i_s * block_S + i_s2, :])
Q_tmp *= scale
W_tmp = W[:, i_s * block_S : (i_s + 1) * block_S, :, :]
dO_tmp = dO[:, i_s * block_S : (i_s + 1) * block_S, :, :]
torch.backends.cuda.matmul.allow_tf32 = True
dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3))
dh_tmp -= torch.matmul(W_tmp.permute(0, 2, 3, 1), dv_tmp.permute(0, 2, 1, 3))
torch.backends.cuda.matmul.allow_tf32 = False
if use_initial_state:
dh0 = dh_tmp[:, :, :, :]
else:
dh0 = torch.zeros_like(dh_tmp[:, :, :, :])
print(dh0.dtype)
return dh, dh0, dv2
@tilelang.jit(out_idx=[-3, -2, -1])
def tilelang_chunk_gated_delta_rule_bwd_dhu(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g=True,
use_initial_state=True,
use_final_state_gradient=True,
# kernel config
block_DV=64,
threads=256,
num_stages=0,
):
block_S = chunk_size
# Should support cu_seqlen
BS = S // block_S
Q_shape = (B, S, H, DK)
K_shape = (B, S, H, DK)
W_shape = (B, S, H, DK)
G_shape = (B, S, H)
h0_shape = (B, H, DK, DV)
dht_shape = (B, H, DK, DV)
dO_shape = (B, S, H, DV)
dv_shape = (B, S, H, DV)
dh_shape = (B, BS, H, DK, DV)
dh0_shape = (B, H, DK, DV)
dv2_shape = (B, S, H, DV)
@T.prim_func
def kernel(
# Input
Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
h0: T.Tensor(h0_shape, dtype=input_dtype),
dht: T.Tensor(dht_shape, dtype=input_dtype),
dO: T.Tensor(dO_shape, dtype=input_dtype),
dv: T.Tensor(dv_shape, dtype=input_dtype),
# Output
dh: T.Tensor(dh_shape, dtype=output_dtype),
dh0: T.Tensor(dh0_shape, dtype=state_dtype),
dv2: T.Tensor(dv2_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh):
bb, bh = bbh // H, bbh % H
b_dh_shared = T.alloc_shared((DK, block_DV), dtype=output_dtype)
b_dh_shared_fp32 = T.alloc_shared((DK, block_DV), dtype=state_dtype)
b_dh_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
b_dh_fragment_1 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
b_dh_fragment_2 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
dv_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dv_fragment_2 = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
dO_shared_t = T.alloc_shared((block_DV, block_S), dtype=T.float32)
dO_fragment = T.alloc_fragment((block_S, block_DV), dtype=T.float32)
dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype=T.float32)
K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
Q_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
Q_shared_fp32 = T.alloc_shared((block_S, DK), dtype=T.float32)
W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
G_last_local = T.alloc_local((1), dtype=gate_dtype)
G_last_local_exp = T.alloc_local((1), dtype=gate_dtype)
G_shared = T.alloc_shared((block_S), dtype=gate_dtype, scope="shared")
G_fragment = T.alloc_fragment((block_S), dtype=gate_dtype)
G_fragment_post = T.alloc_fragment((block_S), dtype=gate_dtype)
G_fragment_exp = T.alloc_fragment((block_S), dtype=gate_dtype)
Q_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
Q_fragment_t = T.alloc_fragment((DK, block_S), dtype=accum_dtype)
T.use_swizzle(10)
T.annotate_layout(
{
b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared),
b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dO_shared: tilelang.layout.make_swizzled_layout(dO_shared),
dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32),
W_shared: tilelang.layout.make_swizzled_layout(W_shared),
}
)
if use_final_state_gradient:
T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared)
T.copy(b_dh_shared, b_dh_fragment)
else:
T.clear(b_dh_fragment)
for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages):
# The gradient should be stored in the reverse order
i_s_inv = T.ceildiv(S, block_S) - i_s - 1
# Store the updated dh
T.copy(b_dh_fragment, b_dh_shared)
T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
# Update dv
T.copy(K[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], K_shared)
T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True)
if use_g:
T.copy(G[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh], G_shared, disable_tma=True)
T.copy(G_shared, G_fragment)
G_last_local[0] = G_shared[block_S - 1]
G_last_local_exp[0] = T.exp(G_last_local[0])
for i_s2 in T.Parallel(block_S):
G_fragment_post[i_s2] = T.exp(G_last_local[0] - G_fragment[i_s2])
for i_s2, i_v in T.Parallel(block_S, block_DV):
# with T.If(G_last_local[0] - G_shared[i_s2] <= 0):
with T.If(G_last_local[0] - G_fragment[i_s2] <= 0):
with T.Then():
dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2]
with T.Else():
dv_fragment[i_s2, i_v] = 0
T.copy(dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared)
T.copy(dv_shared, dv_fragment_2)
for i_s2, i_v in T.Parallel(block_S, block_DV):
dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v]
# Store the updated dv
T.copy(dv_fragment, dv_shared)
T.copy(dv_shared, dv2[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])
# Update dh
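                # dh <- exp(g_last) * dh + (Q * exp(g) * scale)^T @ dO - W^T @ dv2
                # (the exp factors are dropped when use_g is False)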
T.copy(Q[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], Q_shared)
T.copy(W[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], W_shared)
T.clear(Q_fragment)
if use_g:
for i_k, i_v in T.Parallel(DK, block_DV):
b_dh_fragment[i_k, i_v] *= G_last_local_exp[0]
T.copy(Q_shared, Q_fragment)
for i_s2 in T.Parallel(block_S):
G_fragment_exp[i_s2] = T.exp(G_shared[i_s2])
for i_s2, i_k in T.Parallel(block_S, DK):
# Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * T.exp(G_shared[i_s2]) * scale
Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * G_fragment_exp[i_s2] * scale
else:
T.copy(Q_shared, Q_fragment)
for i_s2, i_k in T.Parallel(block_S, DK):
Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * scale
# Get transpose of Q_fragment to meet tf32 gemm requirement
for i_s2, i_k in T.Parallel(block_S, DK):
Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k]
T.copy(dO[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dO_shared)
T.copy(dO_shared, dO_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV):
dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v]
T.copy(dO_fragment_t, dO_shared_t)
T.clear(b_dh_fragment_1)
T.gemm(Q_fragment_t, dO_shared_t, b_dh_fragment_1, transpose_B=True)
T.clear(b_dh_fragment_2)
T.gemm(W_shared, dv_shared, b_dh_fragment_2, transpose_A=True)
for i_k, i_v in T.Parallel(DK, block_DV):
b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v]
if use_initial_state:
T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
return kernel
def test_result(dh_0, dh0_0, dv2_0, dh_1, dh0_1, dv2_1, name):
    try:
        torch.testing.assert_close(dh_0, dh_1, rtol=1e-2, atol=1e-2, equal_nan=True)
        print(f"dh_0 and dh_1 passed for {name}")
    except Exception as e:
        print(f"dh_0 and dh_1 are not close for {name}")
        print(e, end="\n\n")
    try:
        torch.testing.assert_close(dh0_0, dh0_1, rtol=1e-2, atol=1e-2, equal_nan=True)
        print(f"dh0_0 and dh0_1 passed for {name}")
    except Exception as e:
        print(f"dh0_0 and dh0_1 are not close for {name}")
        print(e, end="\n\n")
    try:
        torch.testing.assert_close(dv2_0, dv2_1, rtol=1e-2, atol=1e-2, equal_nan=True)
        print(f"dv2_0 and dv2_1 passed for {name}")
    except Exception as e:
        print(f"dv2_0 and dv2_1 are not close for {name}")
        print(e, end="\n\n")
close = torch.isclose(dh_0, dh_1, rtol=1e-2, atol=1e-2)
mismatch_indices = torch.nonzero(~close, as_tuple=True)
error_num = 0
for indices in zip(*mismatch_indices):
if error_num < 100:
print(
f"{name} dh_0[{[idx.item() for idx in indices]}] = {dh_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item(), indices[4].item()]}, dh_1[{[idx.item() for idx in indices]}] = {dh_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item(), indices[4].item()]}"
)
error_num += 1
close = torch.isclose(dh0_0, dh0_1, rtol=1e-2, atol=1e-2)
mismatch_indices = torch.nonzero(~close, as_tuple=True)
error_num = 0
for indices in zip(*mismatch_indices):
if error_num < 100:
print(
f"{name} dh0_0[{[idx.item() for idx in indices]}] = {dh0_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}, dh0_1[{[idx.item() for idx in indices]}] = {dh0_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}"
)
error_num += 1
close = torch.isclose(dv2_0, dv2_1, rtol=1e-2, atol=1e-2)
mismatch_indices = torch.nonzero(~close, as_tuple=True)
error_num = 0
for indices in zip(*mismatch_indices):
if error_num < 100:
print(
f"{name} dv2_0[{[idx.item() for idx in indices]}] = {dv2_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}, dv2_1[{[idx.item() for idx in indices]}] = {dv2_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}"
)
error_num += 1
def run_test(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g=True,
use_initial_state=True,
use_final_state_gradient=True,
block_DV=64,
threads=256,
num_stages=0,
use_torch=False,
):
Q, K, W, G, h0, dht, dO, dv = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dh_ref, dh0_ref, dv2_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
# fla ref
print("fla running...", flush=True)
if use_g:
dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale)
else:
G = G.fill_(0)
dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale)
# tilelang
print("tilelang running...", flush=True)
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g,
use_initial_state,
use_final_state_gradient,
block_DV,
threads,
num_stages,
)
# kernel = tilelang.compile(program)
print(kernel.get_kernel_source())
dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv)
fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv)
print(f"fla time: {fla_time} ms")
print(f"tilelang time: {tilelang_time} ms")
assert_similar(dh_tilelang, dh_ref, 1e-5, "fla-tilelang", data="dh")
assert_similar(dh0_tilelang, dh0_ref, 1e-5, "fla-tilelang", data="dh0")
assert_similar(dv2_tilelang, dv2_ref, 1e-5, "fla-tilelang", data="dv2")
# torch ref
if use_torch:
print("torch running...", flush=True)
if use_g:
dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
Q,
K,
W,
G,
h0,
dht,
dO,
dv,
scale,
use_g,
use_initial_state,
use_final_state_gradient,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dh_ref_torch = dh_ref_torch.cuda()
dh0_ref_torch = dh0_ref_torch.cuda()
dv2_ref_torch = dv2_ref_torch.cuda()
else:
dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
Q,
K,
W,
None,
h0,
dht,
dO,
dv,
scale,
use_g,
use_initial_state,
use_final_state_gradient,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dh_ref_torch = dh_ref_torch.cuda()
dh0_ref_torch = dh0_ref_torch.cuda()
dv2_ref_torch = dv2_ref_torch.cuda()
assert_similar(dh_ref_torch, dh_ref, 1e-5, "torch-fla", data="dh")
assert_similar(dh0_ref_torch, dh0_ref, 1e-5, "torch-fla", data="dh0")
assert_similar(dv2_ref_torch, dv2_ref, 1e-5, "torch-fla", data="dv2")
assert_similar(dh_ref_torch, dh_tilelang, 1e-5, "torch-tilelang", data="dh")
assert_similar(dh0_ref_torch, dh0_tilelang, 1e-5, "torch-tilelang", data="dh0")
assert_similar(dv2_ref_torch, dv2_tilelang, 1e-5, "torch-tilelang", data="dv2")
def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
"""
    Benchmark a function with CUDA events and return the mean elapsed time in milliseconds.
"""
start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
for _ in range(warmup):
fn(*args, **kwargs)
torch.cuda.synchronize()
for i in range(rep):
start_event[i].record()
fn(*args, **kwargs)
end_event[i].record()
torch.cuda.synchronize()
# Record clocks
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)],
dtype=torch.float,
)
return times.mean().item()
def main():
DK = 128
run_test(
B=1,
S=32768,
H=8,
DK=DK,
DV=128,
input_dtype=T.bfloat16,
output_dtype=T.bfloat16,
accum_dtype=T.float32,
gate_dtype=T.float32,
state_dtype=T.float32,
chunk_size=64,
scale=DK**-0.5,
use_g=True,
use_initial_state=True,
use_final_state_gradient=True,
block_DV=32,
threads=128,
num_stages=1,
use_torch=False,
)
if __name__ == "__main__":
main()
# Reference: fla/ops/common/chunk_delta_h.py
import sys # noqa: F401
import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
import torch.nn.functional as F
from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401
from test_utils import assert_similar
# (zhengju) We can slightly modify the generated CUDA code from the TileLang lowering
# in the debug folder to improve performance. To enable this callback,
# uncomment the following function.
# @register_cuda_postproc_callback
# def tilelang_callback_cuda_postproc(code, _):
# cuda_code = open("../debug/chunk_delta_h_fuse.cu", "r").read()
# code = cuda_code
# return code
torch.random.manual_seed(0)
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
):
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = F.normalize(K, dim=-1, p=2)
W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
W = F.normalize(W, dim=-1, p=2)
U = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
U = F.normalize(U, dim=-1, p=2)
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
G = F.logsigmoid(G)
try:
from fla.ops.utils.cumsum import chunk_local_cumsum
G = chunk_local_cumsum(G, chunk_size)
except ImportError:
print("fla not found, skip cumsum")
initial_state = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda()
return K, W, U, G, initial_state
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
state_dtype,
):
BS = S // chunk_size
h = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda()
final_state = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda()
V_new = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
return h, final_state, V_new
def get_configs():
import itertools
block_DK = [32, 64, 128]
block_DV = [32, 64, 128]
threads = [128, 256]
num_stages = [1, 2, 3]
_configs = list(itertools.product(block_DK, block_DV, threads, num_stages))
configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs]
return configs
@autotune(configs=get_configs(), warmup=3, rep=5)
@tilelang.jit(out_idx=[-3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def tilelang_chunk_gated_delta_rule_fwd_h(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
use_g,
use_initial_state,
store_final_state,
save_new_value,
# kernel config
block_DK=64,
block_DV=32,
threads=128,
num_stages=1,
):
block_S = chunk_size
BS = S // block_S
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
W_shape = (B, S, H, DK)
U_shape = (B, S, H, DV)
G_shape = (B, S, H)
h_shape = (B, BS, H, DK, DV)
initial_state_shape = (B, H, DK, DV)
final_state_shape = (B, H, DK, DV)
@T.prim_func
def kernel(
K: T.Tensor(K_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype),
U: T.Tensor(U_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
initial_state: T.Tensor(initial_state_shape, dtype=input_dtype),
h: T.Tensor(h_shape, dtype=output_dtype),
final_state: T.Tensor(final_state_shape, dtype=state_dtype),
V_new: T.Tensor(V_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh):
bb, bh = bbh // H, bbh % H
b_h_shared = T.alloc_shared((DK, block_DV), dtype=input_dtype)
b_h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
U_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
U_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
V_new_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
V_new_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype)
K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
G_last_local = T.alloc_local((1), dtype=gate_dtype)
G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype)
G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype)
T.annotate_layout(
{
b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared),
U_shared: tilelang.layout.make_swizzled_layout(U_shared),
W_shared: tilelang.layout.make_swizzled_layout(W_shared),
V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
G_shared: tilelang.layout.make_swizzled_layout(G_shared),
}
)
T.use_swizzle(10)
if use_initial_state:
T.copy(initial_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_h_shared)
T.copy(b_h_shared, b_h_fragment)
else:
T.clear(b_h_fragment)
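            # Chunkwise forward recurrence: V_new = U - W @ h is written out, then gated
            # by exp(g_last - g) (when use_g) and used to update h <- exp(g_last) * h + K^T @ V_new.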
for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages):
                # Store the previous chunk's hidden state to the output tensor (epilogue of the previous iteration)
T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
# Recurrence
T.copy(W[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], W_shared)
T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True)
# U - W * S
T.copy(U[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], U_shared)
T.copy(U_shared, U_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV):
V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v]
# Save V_new
if save_new_value:
T.copy(V_new_fragment, dst=V_new_shared)
T.copy(V_new_shared, V_new[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])
T.copy(K[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], K_shared)
# use_g
if use_g:
G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh]
for i_s2, i_v in T.Parallel(block_S, block_DV):
G_shared[i_s2, i_v] = G[bb, i_s * block_S + i_s2, bh]
T.copy(G_shared, G_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV):
with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0):
with T.Then():
V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp2(
(G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695
)
with T.Else():
V_new_fragment[i_s2, i_v] = 0
G_last_local[0] = T.exp2(G_last_local[0] * 1.442695)
for i_k, i_v in T.Parallel(DK, block_DV):
b_h_fragment[i_k, i_v] *= G_last_local[0]
# Update intermediate results
T.copy(V_new_fragment, V_new_shared)
T.gemm(K_shared, V_new_shared, b_h_fragment, transpose_A=True)
T.copy(b_h_fragment, b_h_shared)
# Save final state
if store_final_state:
T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
return kernel
def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
"""
    Benchmark a function with CUDA events and return the mean elapsed time in milliseconds.
"""
start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
for _ in range(warmup):
fn(*args, **kwargs)
torch.cuda.synchronize()
for i in range(rep):
start_event[i].record()
fn(*args, **kwargs)
end_event[i].record()
torch.cuda.synchronize()
# Record clocks
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)],
dtype=torch.float,
)
return times.mean().item()
def run_test(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
use_g=True,
use_initial_state=True,
store_final_state=True,
save_new_value=True,
block_DK=64,
block_DV=32,
threads=128,
num_stages=0,
):
K, W, U, G, initial_state = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
)
h_ref, final_state_ref, V_new_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype)
)
h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype)
)
# fla ref
h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(
k=K,
w=W,
u=U,
g=G,
initial_state=initial_state,
output_final_state=store_final_state,
chunk_size=chunk_size,
save_new_value=save_new_value,
)
# tilelang
kernel = tilelang_chunk_gated_delta_rule_fwd_h(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
use_g,
use_initial_state,
store_final_state,
save_new_value,
)
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state)
# (zhengju) If you want to print the generated cuda code, you can uncomment the following line
# print("CUDA Code:\n", kernel.get_kernel_source())
fla_time = do_bench(
chunk_gated_delta_rule_fwd_h,
k=K,
w=W,
u=U,
g=G,
initial_state=initial_state,
output_final_state=store_final_state,
chunk_size=chunk_size,
save_new_value=save_new_value,
)
tilelang_time = do_bench(kernel, K, W, U, G, initial_state)
# check correctness
try:
h_ref_fp32 = h_ref.to(torch.float32)
h_tilelang_fp32 = h_tilelang.to(torch.float32)
assert_similar(h_ref_fp32, h_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd h", raise_assert=False)
print("tilelang chunk gated delta rule fwd h passed √")
except Exception as e:
print("tilelang chunk gated delta rule fwd h failed ✗")
print(e)
try:
final_state_ref_fp32 = final_state_ref.to(torch.float32)
final_state_tilelang_fp32 = final_state_tilelang.to(torch.float32)
assert_similar(
final_state_ref_fp32,
final_state_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd final_state",
raise_assert=False,
)
print("tilelang chunk gated delta rule fwd final_state passed √")
except Exception as e:
print("tilelang chunk gated delta rule fwd final_state failed ✗")
print(e)
try:
V_new_ref_fp32 = V_new_ref.to(torch.float32)
V_new_tilelang_fp32 = V_new_tilelang.to(torch.float32)
assert_similar(V_new_ref_fp32, V_new_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd V_new", raise_assert=False)
print("tilelang chunk gated delta rule fwd V_new passed √")
except Exception as e:
print("tilelang chunk gated delta rule fwd V_new failed ✗")
print(e)
print(f"tilelang time: {tilelang_time} ms")
print(f"fla time: {fla_time} ms")
def main():
run_test(
B=1,
S=32768,
H=32,
DK=128,
DV=128,
input_dtype=T.bfloat16,
output_dtype=T.bfloat16,
accum_dtype=T.float32,
gate_dtype=T.float32,
state_dtype=T.float32,
chunk_size=64,
use_g=True,
use_initial_state=False,
store_final_state=True,
save_new_value=True,
block_DK=32,
block_DV=32,
threads=128,
num_stages=2,
)
if __name__ == "__main__":
main()
# Reference: fla/ops/common/chunk_o.py
import tilelang
import tilelang.language as T
import sys # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_o import chunk_fwd_o
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
torch.random.manual_seed(1)
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
):
BS = chunk_size
Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
HIDDEN = torch.randn(B, S // BS, H, DK, DV, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
return Q, K, V, HIDDEN, G
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
):
O = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
return O
@tilelang.jit(out_idx=[-1])
def tilelang_chunk_fwd_o(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
chunk_size,
scale,
use_g,
# kernel config
block_S=64,
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
assert chunk_size == block_S, "chunk_size must be equal to block_S"
BS = chunk_size
Q_shape = (B, S, H, DK)
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
H_shape = (B, S // BS, H, DK, DV)
G_shape = (B, S, H)
O_shape = (B, S, H, DV)
@T.prim_func
def kernel(
Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
HIDDEN: T.Tensor(H_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
O: T.Tensor(O_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H, threads=threads) as (bv, bs, bbh):
bb, bh = bbh // H, bbh % H
Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
H_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype)
A_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype)
O_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype)
A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
O_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared")
G_diff_local = T.alloc_fragment((block_S, block_S), dtype=gate_dtype)
T.annotate_layout(
{
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
V_shared: tilelang.layout.make_swizzled_layout(V_shared),
H_shared: tilelang.layout.make_swizzled_layout(H_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
}
)
T.clear(A_fragment)
T.clear(O_fragment)
T.disable_warp_group_reg_alloc()
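            # Per-chunk output: O = scale * ((Q @ h) * exp(g) + tril(exp(g_i - g_j) * (Q @ K^T)) @ V)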
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], Q_shared)
T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
T.copy(HIDDEN[bb, bs, bh, i_k * block_DK : (i_k + 1) * block_DK, bv * block_DV : (bv + 1) * block_DV], H_shared)
T.gemm(Q_shared, H_shared, O_fragment)
T.gemm(Q_shared, K_shared, A_fragment, transpose_B=True)
if use_g:
for i_s in T.Parallel(block_S):
G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
# T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared)
for i_s, i_v in T.Parallel(block_S, block_DV):
O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * T.exp(G_shared[i_s])
for i_s1, i_s2 in T.Parallel(block_S, block_S):
G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G_diff_local[i_s1, i_s2] <= 0):
with T.Then():
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2])
with T.Else():
A_fragment[i_s1, i_s2] = 0
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 < i_s2): # noqa: SIM117
with T.Then():
A_fragment[i_s1, i_s2] = 0
T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], V_shared)
T.copy(A_fragment, A_shared)
T.gemm(A_shared, V_shared, O_fragment)
for i_s, i_v in T.Parallel(block_S, block_DV):
O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * scale
T.copy(O_fragment, O_shared)
T.copy(O_shared, O[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])
return kernel
def run_test(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
use_g,
block_DK,
block_DV,
threads,
num_stages,
):
input_dtype_torch = getattr(torch, input_dtype)
output_dtype_torch = getattr(torch, output_dtype)
accum_dtype_torch = getattr(torch, accum_dtype)
gate_dtype_torch = getattr(torch, gate_dtype)
Q, K, V, HIDDEN, G = prepare_input(
B, S, H, DK, DV, chunk_size, input_dtype_torch, output_dtype_torch, accum_dtype_torch, gate_dtype_torch
)
scale = 1.0 / DK**0.5
O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
O_ref = chunk_fwd_o(Q, K, V, HIDDEN, G, scale, chunk_size=chunk_size)
block_S = chunk_size
O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
kernel = tilelang_chunk_fwd_o(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
chunk_size,
scale,
use_g,
block_S,
block_DK,
block_DV,
threads,
num_stages,
)
O_tilelang = kernel(Q, K, V, HIDDEN, G)
try:
torch.testing.assert_close(O_tilelang, O_ref, rtol=1e-2, atol=1e-2)
print("tilelang chunk fwd o passed √")
except Exception as e:
print("tilelang chunk fwd o failed ✗")
print(e)
def main():
run_test(
B=1,
S=32768,
H=32,
DK=128,
DV=128,
chunk_size=64,
input_dtype=T.bfloat16,
output_dtype=T.bfloat16,
accum_dtype=T.float32,
gate_dtype=T.float32,
use_g=True,
block_DK=128,
block_DV=128,
threads=128,
num_stages=1,
)
if __name__ == "__main__":
main()
# Reference: fla/ops/common/chunk_o.py
import math
import sys # noqa: F401
import tilelang
import tilelang.language as T
from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_o import chunk_bwd_dqkwg
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
from test_utils import assert_similar
torch.random.manual_seed(0)
# torch.set_printoptions(profile="full")
def prepare_input_fake(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
BS = S // chunk_size
Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
h = torch.ones(B, BS, H, DK, DV, dtype=input_dtype).cuda()
G = torch.ones(B, S, H, dtype=gate_dtype).cuda()
dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
dh = torch.ones(B, BS, H, DK, DV, dtype=input_dtype).cuda()
dv = torch.ones(B, S, H, DV, dtype=output_dtype).cuda()
W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
return Q, K, V, h, G, dO, dh, dv, W
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
BS = S // chunk_size
Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
h = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
dh = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda()
dv = torch.randn(B, S, H, DV, dtype=output_dtype).cuda()
W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
return Q, K, V, h, G, dO, dh, dv, W
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
gate_dtype,
state_dtype,
block_DK,
):
assert (DK == 32 and block_DK == 32) or (DK > 32 and block_DK >= 64), "block_DK must be 32 when DK == 32 and >= 64 when DK > 32"
NK = math.ceil(DK / block_DK)
dq = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
dw = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
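# dg is produced per DK block; the NK partial results along dim 0 are summed after the kernel (see run_test).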
dg = torch.empty(NK, B, S, H, dtype=gate_dtype).cuda()
return dq, dk, dw, dg
# @register_cuda_postproc_callback
# def tilelang_callback_cuda_postproc(code, _):
# cuda_code = open("../debug/chunk_o_bwd3.log", "r").read()
# code = cuda_code
# return code
@tilelang.jit(
out_idx=[-4, -3, -2, -1],
pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True},
)
def tilelang_chunk_o_bwd_dqkwg(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g=True,
use_dw=True,
# kernel config
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
block_S = chunk_size
BS = S // block_S
NK = math.ceil(DK / block_DK)
Q_shape = (B, S, H, DK)
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
h_shape = (B, BS, H, DK, DV)
G_shape = (B, S, H)
dO_shape = (B, S, H, DV)
dh_shape = (B, BS, H, DK, DV)
dv_shape = (B, S, H, DV)
W_shape = (B, S, H, DK)
dq_shape = (B, S, H, DK)
dk_shape = (B, S, H, DK)
dw_shape = (B, S, H, DK)
dg_shape = (NK, B, S, H)
@T.prim_func
def kernel(
# input
Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
h: T.Tensor(h_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
dO: T.Tensor(dO_shape, dtype=input_dtype),
dh: T.Tensor(dh_shape, dtype=input_dtype),
dv: T.Tensor(dv_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype),
# output
dq: T.Tensor(dq_shape, dtype=output_dtype),
dk: T.Tensor(dk_shape, dtype=output_dtype),
dw: T.Tensor(dw_shape, dtype=output_dtype),
dg: T.Tensor(dg_shape, dtype=gate_dtype),
):
with T.Kernel(T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, threads=threads) as (bk, bs, bbh):
bb, bh = bbh // H, bbh % H
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
h_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype)
dh_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype)
dv_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
k_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
ds_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype)
dg_shared_1 = T.alloc_shared((block_S,), dtype=gate_dtype)
dg_shared_2 = T.alloc_shared((block_S,), dtype=gate_dtype)
dk_shared = T.alloc_shared((block_S, block_DK), dtype=accum_dtype)
ds_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
ds_fragment_positive = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
ds_fragment_positive_transpose = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
dq_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dk_fragment_2 = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dw_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
q_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype)
k_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype)
dg_fragment_reduce_tmp = T.alloc_fragment((block_S, block_DK), dtype=gate_dtype)
dg_fragment = T.alloc_fragment((block_S,), dtype=gate_dtype)
dg_fragment_2 = T.alloc_fragment((block_S,), dtype=gate_dtype)
dg_fragment_final = T.alloc_fragment((block_S,), dtype=gate_dtype)
dg_last_local = T.alloc_local((2,), dtype=gate_dtype)
dg_last_fragment = T.alloc_fragment((block_DV * block_DK), dtype=gate_dtype)
dg_last_fragment_scalar = T.alloc_fragment((1,), dtype=gate_dtype)
dg_last_fragment_2 = T.alloc_fragment((block_S * block_DK), dtype=gate_dtype)
dg_last_fragment_scalar_2 = T.alloc_fragment((1,), dtype=gate_dtype)
G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype, scope="shared")
G_last_local = T.alloc_local((1,), dtype=gate_dtype)
T.use_swizzle(10)
T.annotate_layout(
{
V_shared: tilelang.layout.make_swizzled_layout(V_shared),
dO_shared: tilelang.layout.make_swizzled_layout(dO_shared),
h_shared: tilelang.layout.make_swizzled_layout(h_shared),
dh_shared: tilelang.layout.make_swizzled_layout(dh_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
q_shared: tilelang.layout.make_swizzled_layout(q_shared),
k_shared: tilelang.layout.make_swizzled_layout(k_shared),
}
)
T.clear(dg_last_local)
T.clear(G_last_local)
T.clear(G_shared)
T.clear(q_fragment)
T.clear(k_fragment)
T.clear(dg_last_fragment)
T.clear(ds_fragment)
T.clear(dq_fragment)
T.clear(dk_fragment)
T.clear(dw_fragment)
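# Pipelined loop over value blocks: accumulates dS += dO @ V^T, dq += dO @ h^T, dk += V @ dh^T, optionally dw += dv @ h^T (negated after the loop), and, when gated, the sum of h * dh feeding the last-position gate gradient.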
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared)
T.copy(dO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dO_shared)
T.copy(h[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], h_shared)
T.copy(dh[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], dh_shared)
if use_g:
T.clear(dg_last_fragment_scalar)
# FIXME: Reducing a whole buffer directly to a scalar is not supported and gives incorrect results,
# so the element-wise product h * dh is first materialized in a 1-D fragment and reduced below.
for i_kv in T.Parallel(block_DK * block_DV):
dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
dg_last_local[0] += dg_last_fragment_scalar[0]
T.gemm(dO_shared, V_shared, ds_fragment, transpose_B=True)
T.gemm(dO_shared, h_shared, dq_fragment, transpose_B=True)
T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True)
if use_dw:
T.copy(dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dv_shared)
T.gemm(dv_shared, h_shared, dw_fragment, transpose_B=True)
if use_dw:
for i_s, i_k in T.Parallel(block_S, block_DK):
dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k]
T.copy(dw_fragment, dw[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], q_shared)
T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], k_shared)
T.copy(q_shared, q_fragment)
T.copy(k_shared, k_fragment)
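# Gated path: rescale dq by exp(g_i) * scale, decay dk by exp(g_last - g_i), accumulate the per-position gate gradient dg, and apply the gated causal decay to dS.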
if use_g:
T.clear(dg_fragment)
T.clear(dg_fragment_2)
for i_s, i_k in T.Parallel(block_S, block_DK):
G_shared[i_s, i_k] = G[bb, bs * block_S + i_s, bh]
G_last_local[0] = G[bb, bs * block_S + block_S - 1, bh]
# Use gmem directly instead of local register
dg_last_local[0] = dg_last_local[0] * T.exp(G[bb, bs * block_S + block_S - 1, bh])
for i_s, i_k in T.Parallel(block_S, block_DK):
dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, bh]) * scale
T.clear(dg_fragment_reduce_tmp)
for i_s, i_k in T.Parallel(block_S, block_DK):
dg_fragment_reduce_tmp[i_s, i_k] = dq_fragment[i_s, i_k] * q_shared[i_s, i_k]
# FIXME: reduce_sum with clear=True triggers an error in the warp-specialization pass, so clear=False is used here
T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=-1, clear=False)
for i_s, i_k in T.Parallel(block_S, block_DK):
with T.If(G_last_local[0] - G[bb, bs * block_S + i_s, bh] <= 0):
with T.Then():
dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp(G_last_local[0] - G[bb, bs * block_S + i_s, bh])
with T.Else():
dk_fragment[i_s, i_k] = 0
T.clear(dg_fragment_reduce_tmp)
for i_s, i_k in T.Parallel(block_S, block_DK):
dg_fragment_reduce_tmp[i_s, i_k] = dk_fragment[i_s, i_k] * (-k_shared[i_s, i_k])
# FIXME: reduce_sum with clear=True triggers an error in the warp-specialization pass, so clear=False is used here
T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=-1, clear=False)
# FIXME: Reducing a whole buffer directly to a scalar is not supported and gives incorrect results, so dk is staged through shared memory and the product dk * k is reduced via a 1-D fragment below.
T.copy(dk_fragment, dk_shared)
T.clear(dg_last_fragment_scalar_2)
for i_sk in T.Parallel(block_S * block_DK):
i_s, i_k = i_sk // block_DK, i_sk % block_DK
dg_last_fragment_2[i_sk] = dk_shared[i_s, i_k] * k_shared[i_s, i_k]
T.reduce_sum(dg_last_fragment_2, dg_last_fragment_scalar_2, dim=-1, clear=False)
dg_last_local[1] = dg_last_fragment_scalar_2[0]
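# Decay and mask dS: keep only causal entries (i >= j), scaled by exp(g_i - g_j) * scale; its element-wise product with Q @ K^T contributes to the per-position gate gradient.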
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 >= i_s2 and G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
with T.Then():
ds_fragment[i_s1, i_s2] = (
ds_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) * scale
)
with T.Else():
ds_fragment[i_s1, i_s2] = 0
T.clear(ds_fragment_positive)
T.clear(ds_fragment_positive_transpose)
T.gemm(q_shared, k_shared, ds_fragment_positive, transpose_B=True)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
ds_fragment_positive[i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2]
# FIXME: reduce_sum with clear=True triggers an error in the warp-specialization pass, so clear=False is used here
T.reduce_sum(ds_fragment_positive, dg_fragment, dim=1, clear=False)
T.copy(dg_fragment, dg_shared_1)
# Transpose the matrix first, since reduce_sum can only reduce along the last dimension
for i_s1, i_s2 in T.Parallel(block_S, block_S):
ds_fragment_positive_transpose[i_s2, i_s1] = ds_fragment_positive[i_s1, i_s2]
# FIXME: reduce_sum with clear=True triggers an error in the warp-specialization pass, so clear=False is used here
T.reduce_sum(ds_fragment_positive_transpose, dg_fragment_2, dim=1, clear=False)
T.copy(dg_fragment_2, dg_shared_2)
for i_s in T.Parallel(block_S):
dg_fragment_final[i_s] = dg_shared_1[i_s] - dg_shared_2[i_s]
T.copy(ds_fragment, ds_shared)
T.gemm(ds_shared, k_shared, dq_fragment)
T.gemm(ds_shared, q_shared, dk_fragment, transpose_A=True)
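# Add the accumulated last-position terms (sum of h * dh and dk * k) to the gate gradient of the final position in the chunk.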
for i_s in T.Parallel(block_S):
with T.If(i_s >= block_S - 1): # noqa: SIM117
with T.Then():
dg_fragment_final[i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1]
T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
for i_s in T.Parallel(block_S):
dg[bk, bb, bs * block_S + i_s, bh] = dg_fragment_final[i_s]
else:
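# Ungated path: apply a plain causal mask to dS, add dS @ K to dq and dS^T @ Q to dk, apply scaling, then write the tiles out.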
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 < i_s2): # noqa: SIM117
with T.Then():
ds_fragment[i_s1, i_s2] = 0
T.clear(dk_fragment_2)
T.copy(ds_fragment, ds_shared)
T.gemm(ds_shared, k_shared, dq_fragment)
T.gemm(ds_shared, q_shared, dk_fragment_2, transpose_A=True)
for i_s, i_k in T.Parallel(block_S, block_DK):
dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale
dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] + dk_fragment_2[i_s, i_k] * scale
T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
return kernel
def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
"""
Do benchmark for a function.
"""
start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
for _ in range(warmup):
fn(*args, **kwargs)
torch.cuda.synchronize()
for i in range(rep):
start_event[i].record()
fn(*args, **kwargs)
end_event[i].record()
torch.cuda.synchronize()
# Record clocks
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)],
dtype=torch.float,
)
return times.mean().item()
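# Example usage (hypothetical arguments): avg_ms = do_bench(kernel, Q, K, V, h, G, dO, dh, dv, W, warmup=10, rep=10)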
def run_test(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g=True,
use_dw=True,
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
Q, K, V, h, G, dO, dh, dv, W = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dq_ref, dk_ref, dw_ref, dg_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK
)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK
)
# ref
if use_g:
dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
else:
dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
# tilelang
kernel = tilelang_chunk_o_bwd_dqkwg(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g,
use_dw,
block_DK,
block_DV,
threads,
num_stages,
)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W)
if use_g:
dg_tilelang = dg_tilelang.sum(dim=0)
# check
try:
assert_similar(dq_ref, dq_tilelang, 1e-5, "tilelang chunk o bwd dq")
print("tilelang chunk o bwd dq passed √")
except Exception as e:
print("tilelang chunk o bwd dq failed ✗")
print(e)
try:
assert_similar(dk_ref, dk_tilelang, 1e-5, "tilelang chunk o bwd dk")
print("tilelang chunk o bwd dk passed √")
except Exception as e:
print("tilelang chunk o bwd dk failed ✗")
print(e)
if use_g:
try:
assert_similar(dg_ref, dg_tilelang, 1e-5, "tilelang chunk o bwd dg")
print("tilelang chunk o bwd dg passed √")
except Exception as e:
print("tilelang chunk o bwd dg failed ✗")
print(e)
if use_dw:
try:
assert_similar(dw_ref, dw_tilelang, 1e-5, "tilelang chunk o bwd dw")
print("tilelang chunk o bwd dw passed √")
except Exception as e:
print("tilelang chunk o bwd dw failed ✗")
print(e)
def main():
DK = 128
DV = 128
run_test(
B=1,
S=32768,
H=8,
DK=DK,
DV=DV,
input_dtype=T.bfloat16,
output_dtype=T.bfloat16,
accum_dtype=T.float32,
gate_dtype=T.float32,
state_dtype=T.float32,
chunk_size=64,
scale=DK**-0.5,
# scale=1,
use_g=True,
use_dw=True,
block_DK=64,
block_DV=64,
threads=128,
num_stages=0,
)
if __name__ == "__main__":
main()
# Reference: fla/ops/common/chunk_scaled_dot_kkt.py
import tilelang
import tilelang.language as T
import sys # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
torch.set_printoptions(profile="full")
torch.random.manual_seed(0)
def prepare_input(
B,
S,
H,
DK,
input_dtype,
output_dtype,
accum_dtype,
):
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
Beta = torch.randn(B, S, H, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=accum_dtype).cuda()
return K, Beta, G
def prepare_output(
B,
S,
H,
chunk_size,
dtype,
):
BS = chunk_size
A = torch.empty(B, S, H, BS, dtype=dtype).cuda()
return A
@tilelang.jit(out_idx=[-1])
def tilelang_chunk_scaled_dot_kkt_fwd(
# task config
B,
S,
H,
DK,
chunk_size=64,
input_dtype=T.bfloat16,
output_dtype=T.bfloat16,
accum_dtype=T.float32,
use_g=True,
# kernel config
block_S=64,
block_DK=64,
threads=256,
num_stages=0,
):
K_shape = (B, S, H, DK)
Beta_shape = (B, S, H)
G_shape = (B, S, H)
assert chunk_size == block_S, "chunk_size must be equal to block_S"
BS = chunk_size
output_shape = (B, S, H, BS)
@T.prim_func
def kernel(
K: T.Tensor(K_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=accum_dtype),
A: T.Tensor(output_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
# Note: pay attention to the shared-memory scope; a 1-D or very small buffer may hit a misaligned-address error unless scope="shared" is set explicitly
Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype, scope="shared")
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
A_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype)
Beta_K_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype)
A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
# Tensor used for gated:
G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared")
G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
T.annotate_layout(
{
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
}
)
T.fill(A_fragment, 0)
T.disable_warp_group_reg_alloc()
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
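# Accumulate A = (Beta * K) @ K^T over DK blocks: each K tile is scaled row-wise by Beta before the gemm.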
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s]
T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True)
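# Optional gating: keep only strictly lower-triangular entries of A, decayed by exp(g_i - g_j); entries with a positive gate difference are zeroed as well.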
if use_g:
for i_s in T.Parallel(block_S):
G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2):
with T.Then():
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2])
with T.Else():
A_fragment[i_s1, i_s2] = 0
else:
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 <= i_s2): # noqa: SIM117
with T.Then():
A_fragment[i_s1, i_s2] = 0
T.copy(A_fragment, A_shared)
T.copy(A_shared, A[bb, bs * block_S : (bs + 1) * block_S, bh, :])
return kernel
def run_test(
B,
S,
H,
DK,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
use_g,
block_DK,
threads,
num_stages,
):
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype))
A_ref = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
# reference
if use_g:
A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
else:
A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
# tilelang
block_S = chunk_size
kernel = tilelang_chunk_scaled_dot_kkt_fwd(
B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages
)
A_tilelang = kernel(K, Beta, G)
try:
torch.testing.assert_close(A_tilelang, A_ref, rtol=1e-2, atol=1e-2)
print("tilelang chunk scaled dot kkt fwd passed √")
except Exception as e:
print("tilelang chunk scaled dot kkt fwd failed ✗")
print(e)
print("reference cuda kernel:")
print(kernel.get_kernel_source())
def main():
run_test(
B=1,
S=32768,
H=32,
DK=128,
chunk_size=64,
input_dtype=T.bfloat16,
output_dtype=T.bfloat16,
accum_dtype=T.float32,
use_g=True,
block_DK=64,
threads=128,
num_stages=2,
)
if __name__ == "__main__":
main()
# Util functions for flash linear attention cumsum
# Reference: fla/ops/utils/cumsum.py
import tilelang
import tilelang.language as T
import sys # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.utils.cumsum import chunk_local_cumsum_scalar
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
@tilelang.jit(
out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}
)
def tilelang_chunk_local_cumsum_scalar(
# task config
B,
S,
H,
chunk_size=64,
is_varlen=False,
head_first=False,
reverse=False,
input_dtype=T.float16,
output_dtype=T.float32,
# kernel config
block_S=64,
threads=256,
use_fragment=False,
):
G_shape = (B, H, S) if head_first else (B, S, H)
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
assert chunk_size == block_S, "chunk_size must be equal to block_S"
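# Each block loads one chunk of gates into shared memory, runs an (optionally reversed) cumulative sum within the chunk, and writes it back in head-first or seq-first layout.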
@T.prim_func
def kernel(
G: T.Tensor(G_shape, dtype=input_dtype),
G_new: T.Tensor(G_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared")
if head_first:
T.copy(G[bb, bh, bs * block_S : (bs + 1) * block_S], G_shared)
else:
T.copy(G[bb, bs * block_S : (bs + 1) * block_S, bh], G_shared)
if use_fragment:
G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared")
T.copy(G_shared, G_fragment)
T.cumsum(G_fragment, dim=1, reverse=reverse)
if head_first:
T.copy(G_fragment, G_new[bb, bh, bs * block_S : (bs + 1) * block_S])
else:
T.copy(G_fragment, G_new[bb, bs * block_S : (bs + 1) * block_S, bh])
else:
T.cumsum(G_shared, dim=1, reverse=reverse)
if head_first:
T.copy(G_shared, G_new[bb, bh, bs * block_S : (bs + 1) * block_S])
else:
T.copy(G_shared, G_new[bb, bs * block_S : (bs + 1) * block_S, bh])
return kernel
def prepare_cumsum_input(
B,
S,
H,
dtype,
):
G = torch.randn(B, S, H, dtype=dtype).cuda()
return G
def prepare_cumsum_output(
B,
S,
H,
dtype,
):
G_new = torch.empty(B, S, H, dtype=dtype).cuda()
return G_new
def run_test(
B,
S,
H,
chunk_size,
reverse,
head_first,
input_dtype,
output_dtype,
threads,
use_fragment,
):
G = prepare_cumsum_input(B, S, H, getattr(torch, input_dtype))
G_new_ref = prepare_cumsum_output(B, S, H, getattr(torch, output_dtype))
G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, output_dtype))
# reference cumsum
G_new_ref = chunk_local_cumsum_scalar(
g=G, chunk_size=chunk_size, reverse=reverse, head_first=head_first, output_dtype=getattr(torch, output_dtype)
)
# tilelang cumsum
block_S = chunk_size
kernel = tilelang_chunk_local_cumsum_scalar(
B=B,
S=S,
H=H,
chunk_size=chunk_size,
reverse=reverse,
head_first=head_first,
input_dtype=input_dtype,
output_dtype=output_dtype,
block_S=block_S,
threads=threads,
use_fragment=use_fragment,
)
torch.cuda.profiler.start()
G_new_tilelang = kernel(G)
torch.cuda.profiler.stop()
try:
torch.testing.assert_close(G_new_tilelang, G_new_ref, rtol=1e-2, atol=1e-2)
print("tilelang cumsum passed √")
except Exception as e:
print("tilelang cumsum failed ✗")
print(e)
print("G:")
print(G.view(-1))
print("G_new_tilelang:")
print(G_new_tilelang.view(-1))
print("G_new_ref:")
print(G_new_ref.view(-1))
def main():
run_test(
B=1,
S=32768,
H=32,
chunk_size=64,
reverse=True,
head_first=False,
input_dtype=T.float32,
output_dtype=T.float32,
threads=256,
use_fragment=False,
)
if __name__ == "__main__":
main()