import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
import argparse
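# TileLang FlashAttention example with grouped-query attention (GQA).
# It contains: a forward kernel that also emits the base-2 log-sum-exp (lse),
# a backward preprocess kernel (Delta = rowsum(O * dO)), the main backward kernel
# in two flavors (atomic-add vs. per-group "split" accumulation of dK/dV), and a
# postprocess kernel that converts the fp32 dQ accumulator back to fp16.
# All softmax math is done in base 2: scores are scaled by log2(e)/sqrt(dim_qk)
# so exp2/log2 can be used in place of exp/log.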
@tilelang.jit(
out_idx=[3, 4], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
K_shared = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
return flash_fwd
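# Reference semantics of flash_fwd per query row i, written as a hypothetical
# PyTorch-style sketch for clarity (not used by the kernels themselves):
#   s      = (q @ k.T) / sqrt(dim_qk)       # raw attention scores
#   out[i] = softmax(s[i]) @ v
#   lse[i] = logsumexp(s[i]) * log2(e)      # stored in base 2, matching exp2()
# The backward kernels rebuild the probabilities as P = exp2(qkT * scale - lse).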
@tilelang.jit(
out_idx=[2], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
dtype = "float16"
accum_dtype = "float"
shape = [batch, seq_len, heads, dim_v]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim_v, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
return flash_bwd_prep
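# flash_bwd_prep computes Delta[b, h, s] = sum_d O[b, s, h, d] * dO[b, s, h, d],
# the row-wise dot product of the output with its gradient. A minimal PyTorch
# equivalent (hypothetical helper, shown only for reference):
#   def delta_reference(O, dO):                  # O, dO: [B, S, H, Dv]
#       return (O.float() * dO.float()).sum(-1).permute(0, 2, 1)  # [B, H, S]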
def make_dq_layout(dQ):
    # atomicAdd cannot be vectorized, so we reorder dQ to match the 8x8 GEMM fragment layout
return T.Layout(dQ.shape,
lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
@tilelang.jit(
out_idx=[1], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
dtype = "float16"
accum_dtype = "float"
shape = [batch, seq_len, heads, dim_qk]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(
dQ[bz, bx * blk:(bx + 1) * blk, by, :],
dQ_out[bz, bx * blk:(bx + 1) * blk, by, :],
)
return flash_bwd_post
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_atomic_add(batch,
heads,
seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk):
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared)
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)
return flash_bwd
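# Gradient identities implemented above (standard attention backward, with the
# scores kept transposed so the KV block index is the leading dimension):
#   P  = softmax(Q K^T / sqrt(d))          rebuilt as exp2(qkT * scale - lse)
#   dV = P^T @ dO
#   dP = dO @ V^T                          computed transposed as dsT = V @ dO^T
#   dS = P * (dP - Delta) / sqrt(d)        with Delta = rowsum(O * dO)
#   dK = dS^T @ Q,   dQ = dS @ K
# dQ is committed with atomic_add because every KV block (index by) contributes
# to the same query rows; dK/dV are accumulated locally and atomically added once
# per block at the end.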
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split(batch,
heads,
seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dk_shape = [groups, batch, seq_len, head_kv, dim_qk] # sum after kernel
dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(dk_shape, dtype), # type: ignore
dV: T.Tensor(dv_shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk):
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
T.copy(dk, dk_shared)
            T.copy(dk_shared, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
return flash_bwd
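# flashattn_bwd_split avoids atomics on dK/dV: each of the `groups` query heads that
# share a KV head writes into its own slot of a [groups, batch, seq_len, head_kv, dim]
# buffer, and the host reduces with dK.sum(0) / dV.sum(0) afterwards (see
# _attention.backward below). dQ still needs atomic_add, since every KV block
# (index by) updates the same query rows.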
@torch.compile
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
BATCH, N_CTX, H, D_HEAD_QK = q.shape
D_HEAD_V = v.shape[-1]
block_M = 128
block_N = 64
mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
o, lse = mod(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
ctx.use_atomic = use_atomic
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD_QK = q.shape
        HEAD_KV, D_HEAD_V = v.shape[-2], v.shape[-1]
groups = H // HEAD_KV
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 128
block_N = 32
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK)
delta = mod_prep(o, do)
if ctx.use_atomic:
kernel = flashattn_bwd_atomic_add(
BATCH,
H,
N_CTX,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = mod_post(dq)
dk = dk.to(torch.float16)
dv = dv.to(torch.float16)
else:
kernel = flashattn_bwd_split(
BATCH,
H,
N_CTX,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel
shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = mod_post(dq)
dk, dv = dk.sum(0), dv.sum(0)
return dq, dk, dv, None, None, None
attention = _attention.apply
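# Typical usage, mirroring main() below (all tensors fp16 on CUDA, laid out as
# [batch, seq_len, heads, dim], with heads = head_kv * groups for Q and head_kv
# heads for K/V):
#   O = attention(Q, K, V, causal, groups, use_atomic)
#   O.backward(dO)    # runs the TileLang backward kernels above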
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D_QK]
# K: [B, T, HK, D_QK]
# V: [B, T, HV, D_V]
# HQ = HKV * groups
assert Q.size(2) == K.size(
2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(
2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim_qk = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
def main(BATCH: int = 1,
H: int = 32,
N_CTX: int = 256,
D_HEAD_QK: int = 192,
D_HEAD_V: int = 128,
groups: int = 16,
causal: bool = False,
use_atomic: bool = True):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v
if causal:
total_flops *= 0.5
Q = (
torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half,
device="cuda").normal_().requires_grad_())
head_kv = H // groups
K = (
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half,
device="cuda").normal_().requires_grad_())
V = (
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
dO = (
torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
O = attention(Q, K, V, causal, groups, use_atomic)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal, groups)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print('All checks passed.✅')
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
from tilelang.profiler import do_bench
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size')
parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
parser.add_argument('--causal', action='store_true', help='Causal flag')
parser.add_argument('--groups', type=int, default=16, help='groups')
parser.add_argument(
'--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
parser.add_argument(
'--use_split', action='store_true', default=False, help='Use split for dK/dV')
args = parser.parse_args()
# Handle backward compatibility and logic
if args.use_split:
use_atomic = False
elif args.use_atomic:
use_atomic = True
else:
# Default: use atomic
use_atomic = True
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
use_atomic)
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.contrib import nvcc
import argparse
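# --- Second script (the repeated imports above suggest a separate example file). ---
# Same GQA forward/backward as the previous example, but targeted at Hopper-class
# GPUs (the entry point below asserts compute capability >= 9.0): the dQ/dK/dV
# accumulators use a [b, h, l, d] layout (make_dq_layout) so tile-wide atomic adds
# can go through the TMA reduction path (T.atomic_add(..., use_tma=True)), and one
# postprocess kernel converts all three fp32 accumulators back to fp16.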
@tilelang.jit(
out_idx=[3, 4], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
K_shared = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
return flash_fwd
@tilelang.jit(
out_idx=[2], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
dtype = "float16"
accum_dtype = "float"
shape = [batch, seq_len, heads, dim_v]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim_v, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
return flash_bwd_prep
def make_dq_layout(dQ):
# bshd -> bhld to use tma reduction instruction
return T.Layout(dQ.shape, lambda b, l, h, d: [b, h, l, d])
@tilelang.jit(
out_idx=[3, 4, 5], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v):
dtype = "float16"
accum_dtype = "float"
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(q_shape, dtype), # type: ignore
dK_out: T.Tensor(k_shape, dtype), # type: ignore
dV_out: T.Tensor(v_shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(dQ[bz, bx * blk:(bx + 1) * blk, by, :], dQ_out[bz, bx * blk:(bx + 1) * blk,
by, :])
with T.Kernel(T.ceildiv(seq_len, blk), head_kv, batch, threads=128) as (bx, by, bz):
T.annotate_layout({
dK: make_dq_layout(dK),
dV: make_dq_layout(dV),
})
T.copy(dK[bz, bx * blk:(bx + 1) * blk, by, :], dK_out[bz, bx * blk:(bx + 1) * blk,
by, :])
T.copy(dV[bz, bx * blk:(bx + 1) * blk, by, :], dV_out[bz, bx * blk:(bx + 1) * blk,
by, :])
return flash_bwd_post
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_atomic_add(batch,
heads,
seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
dK: make_dq_layout(dK),
dV: make_dq_layout(dV),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
T.copy(dq, dq_shared)
T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared, use_tma=True)
T.copy(dv, dv_shared)
T.atomic_add(
dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True)
T.copy(dk, dk_shared)
T.atomic_add(
dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True)
return flash_bwd
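# Unlike the per-element atomic adds in the first script, this variant stages each
# tile's gradients through shared memory (dq_shared / dk_shared / dv_shared) and
# commits them with a single T.atomic_add(..., use_tma=True) bulk reduction per tile.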
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split(batch,
heads,
seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dk_shape = [groups, batch, seq_len, head_kv, dim_qk] # sum after kernel
dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(dk_shape, dtype), # type: ignore
dV: T.Tensor(dv_shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk):
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
T.copy(dk, dk_shared)
            T.copy(dk_shared, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
return flash_bwd
@torch.compile
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
BATCH, N_CTX, H, D_HEAD_QK = q.shape
D_HEAD_V = v.shape[-1]
block_M = 128
block_N = 64
mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
o, lse = mod(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
ctx.use_atomic = use_atomic
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD_QK = q.shape
        HEAD_KV, D_HEAD_V = v.shape[-2], v.shape[-1]
groups = H // HEAD_KV
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 128
block_N = 32
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
mod_post = flashattn_bwd_postprocess(BATCH, H, HEAD_KV, N_CTX, D_HEAD_QK, D_HEAD_V)
delta = mod_prep(o, do)
if ctx.use_atomic:
kernel = flashattn_bwd_atomic_add(
BATCH,
H,
N_CTX,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq, dk, dv = mod_post(dq, dk, dv)
else:
kernel = flashattn_bwd_split(
BATCH,
H,
N_CTX,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel
shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32),
torch.zeros_like(v, dtype=torch.float32))
dk, dv = dk.sum(0), dv.sum(0)
return dq, dk, dv, None, None, None
attention = _attention.apply
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D_QK]
# K: [B, T, HK, D_QK]
# V: [B, T, HV, D_V]
# HQ = HKV * groups
assert Q.size(2) == K.size(
2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(
2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim_qk = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
def main(BATCH: int = 1,
H: int = 32,
N_CTX: int = 256,
D_HEAD_QK: int = 192,
D_HEAD_V: int = 128,
groups: int = 16,
causal: bool = False,
use_atomic: bool = True):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v
if causal:
total_flops *= 0.5
Q = (
torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half,
device="cuda").normal_().requires_grad_())
head_kv = H // groups
K = (
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half,
device="cuda").normal_().requires_grad_())
V = (
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
dO = (
torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
O = attention(Q, K, V, causal, groups, use_atomic)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal, groups)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print('All checks passed.✅')
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
from tilelang.profiler import do_bench
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
arch = nvcc.get_target_compute_version()
print(f"Detected GPU compute capability: {arch}")
    assert float(arch) >= 9.0, "This example only supports GPUs with compute capability >= 9.0"
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size')
parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
parser.add_argument('--causal', action='store_true', help='Causal flag')
parser.add_argument('--groups', type=int, default=16, help='groups')
parser.add_argument(
'--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
parser.add_argument(
'--use_split', action='store_true', default=False, help='Use split for dK/dV')
args = parser.parse_args()
# Handle backward compatibility and logic
if args.use_split:
use_atomic = False
elif args.use_atomic:
use_atomic = True
else:
# Default: use atomic
use_atomic = True
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
use_atomic)
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.contrib import nvcc
import argparse
from einops import rearrange, repeat
from bert_padding import pad_input, unpad_input
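# --- Third script (again, judging by the repeated imports): variable-length (varlen)
# GQA attention. Sequences are unpadded and packed as [total_tokens, heads, dim];
# per-sequence boundaries come from the cu_seqlens_q / cu_seqlens_k prefix sums, and
# flash-attn's bert_padding helpers (pad_input / unpad_input) convert between the
# padded [batch, seq_len, ...] layout and the packed one.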
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
assert mode in ["full", "random", "third"]
if mode == "full":
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
elif mode == "random":
lengths = torch.randint(
max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
elif mode == "third":
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
padding_mask = (
repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths)
return padding_mask
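# The kernels below take cumulative sequence lengths. One common way to build them
# from a padding mask (a sketch assuming the usual flash-attn convention; the driver
# code that actually constructs them is not shown in this excerpt):
#   seqlens    = padding_mask.sum(dim=-1, dtype=torch.int32)                 # [batch]
#   cu_seqlens = F.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))  # [batch+1]
# so cu_seqlens[b] is the offset of sequence b inside the packed tensor.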
@tilelang.jit(
out_idx=[5, 6], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_fwd(batch,
total_q,
total_kv,
heads,
max_seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
groups=1):
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [total_q, heads, dim_qk]
k_shape = [total_kv, head_kv, dim_qk]
v_shape = [total_kv, head_kv, dim_v]
o_shape = [total_q, heads, dim_v]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore
cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore
Output: T.Tensor(o_shape, dtype), # type: ignore
lse: T.Tensor([total_q, heads], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(max_seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
K_shared = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
q_start_idx = cu_seqlens_q[bz]
k_start_idx = cu_seqlens_k[bz]
q_end_idx = cu_seqlens_q[bz + 1]
k_end_idx = cu_seqlens_k[bz + 1]
q_current_seqlen = q_end_idx - q_start_idx
k_current_seqlen = k_end_idx - k_start_idx
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
for i, d in T.Parallel(block_M, dim_qk):
if bx * block_M + i < q_current_seqlen:
Q_shared[i, d] = Q[q_start_idx + bx * block_M + i, by, d]
else:
Q_shared[i, d] = 0.0
T.fill(acc_o, 0.0)
T.fill(logsum, 0.0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(k_current_seqlen, block_N)
for k in T.Pipelined(loop_range, num_stages=1):
for i, d in T.Parallel(block_N, dim_qk):
if k * block_N + i < k_current_seqlen:
K_shared[i, d] = K[k_start_idx + k * block_N + i, by // groups, d]
else:
K_shared[i, d] = 0.0
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
(bx * block_M + i < q_current_seqlen and
k * block_N + j < k_current_seqlen), 0,
-T.infinity(acc_s.dtype))
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
bx * block_M + i < q_current_seqlen and
k * block_N + j < k_current_seqlen, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, d in T.Parallel(block_N, dim_v):
if k * block_N + i < k_current_seqlen:
V_shared[i, d] = V[k_start_idx + k * block_N + i, by // groups, d]
else:
V_shared[i, d] = 0.0
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] /= logsum[i]
for i, d in T.Parallel(block_M, dim_v):
if bx * block_M + i < q_current_seqlen:
Output[q_start_idx + bx * block_M + i, by, d] = acc_o[i, d]
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
if bx * block_M + i < q_current_seqlen:
lse[q_start_idx + bx * block_M + i, by] = logsum[i]
return flash_fwd
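# Compared with the fixed-length forward, each block first reads its sequence bounds
# from cu_seqlens (q_start_idx, q_current_seqlen, ...), loads out-of-range rows as
# zeros, and masks out-of-range score entries to -inf so padding tokens contribute
# nothing to the softmax; outputs and lse are written back only for rows that lie
# inside the current sequence.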
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_preprocess(batch, heads, total_q, max_seq_len, dim_v):
dtype = "float16"
accum_dtype = "float"
shape = [total_q, heads, dim_v]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore
Delta: T.Tensor([total_q, heads], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
q_start_idx = cu_seqlens_q[bz]
q_end_idx = cu_seqlens_q[bz + 1]
q_current_seqlen = q_end_idx - q_start_idx
T.clear(acc)
for k in range(T.ceildiv(dim_v, blk)):
for i, j in T.Parallel(blk, blk):
if by * blk + i < q_current_seqlen and k * blk + j < dim_v:
o[i, j] = O[q_start_idx + by * blk + i, bx, k * blk + j]
do[i, j] = dO[q_start_idx + by * blk + i, bx, k * blk + j]
else:
o[i, j] = 0.0
do[i, j] = 0.0
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
for i in T.Parallel(blk):
if by * blk + i < q_current_seqlen:
Delta[q_start_idx + by * blk + i, bx] = delta[i]
return flash_bwd_prep
def make_dq_layout(dQ):
# bshd -> bhld to use tma reduction instruction
return T.Layout(dQ.shape, lambda b, l, h, d: [b, h, l, d])
@tilelang.jit(
out_idx=[3, 4, 5], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v):
dtype = "float16"
accum_dtype = "float"
q_shape = [total_q, heads, dim_qk]
k_shape = [total_kv, head_kv, dim_qk]
v_shape = [total_kv, head_kv, dim_v]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(q_shape, dtype), # type: ignore
dK_out: T.Tensor(k_shape, dtype), # type: ignore
dV_out: T.Tensor(v_shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(total_q, blk), heads, threads=128) as (bx, by):
# T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(dQ[bx * blk:(bx + 1) * blk, by, :], dQ_out[bx * blk:(bx + 1) * blk, by, :])
with T.Kernel(T.ceildiv(total_kv, blk), head_kv, threads=128) as (bx, by):
# T.annotate_layout({
# dK: make_dq_layout(dK),
# dV: make_dq_layout(dV),
# })
T.copy(dK[bx * blk:(bx + 1) * blk, by, :], dK_out[bx * blk:(bx + 1) * blk, by, :])
T.copy(dV[bx * blk:(bx + 1) * blk, by, :], dV_out[bx * blk:(bx + 1) * blk, by, :])
return flash_bwd_post
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_atomic_add(batch,
total_q,
total_kv,
heads,
max_seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [total_q, heads, dim_qk]
k_shape = [total_kv, head_kv, dim_qk]
v_shape = [total_kv, head_kv, dim_v]
do_shape = [total_q, heads, dim_v]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor(do_shape, dtype), # type: ignore
lse: T.Tensor([total_q, heads], accum_dtype), # type: ignore
Delta: T.Tensor([total_q, heads], accum_dtype), # type: ignore
cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore
cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
):
with T.Kernel(
heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
q_start_idx = cu_seqlens_q[bz]
k_start_idx = cu_seqlens_k[bz]
q_end_idx = cu_seqlens_q[bz + 1]
k_end_idx = cu_seqlens_k[bz + 1]
q_current_seqlen = q_end_idx - q_start_idx
k_current_seqlen = k_end_idx - k_start_idx
T.annotate_layout({
# dQ: make_dq_layout(dQ),
# dK: make_dq_layout(dK),
# dV: make_dq_layout(dV),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
})
for i, d in T.Parallel(block_M, dim_qk):
if by * block_M + i < k_current_seqlen:
K_shared[i, d] = K[k_start_idx + by * block_M + i, bx // groups, d]
V_shared[i, d] = V[k_start_idx + by * block_M + i, bx // groups, d]
else:
K_shared[i, d] = 0.0
V_shared[i, d] = 0.0
T.clear(dv)
T.clear(dk)
loop_st = (T.floordiv(by * block_M, block_N) if is_causal else 0)
loop_ed = T.ceildiv(q_current_seqlen, block_N)
for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
for i, d in T.Parallel(block_N, dim_qk):
if k_base * block_N + i < q_current_seqlen:
q[i, d] = Q[q_start_idx + k_base * block_N + i, bx, d]
else:
q[i, d] = 0.0
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i in T.Parallel(block_N):
if k_base * block_N + i < q_current_seqlen:
lse_shared[i] = lse[q_start_idx + k_base * block_N + i, bx]
else:
lse_shared[i] = 0.0
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else((by * block_M + i <= k_base * block_N + j) and
(by * block_M + i < k_current_seqlen and
k_base * block_N + j < q_current_seqlen),
qkT[i, j], 0)
else:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(
by * block_M + i < k_current_seqlen and
k_base * block_N + j < q_current_seqlen, qkT[i, j], 0)
for i, d in T.Parallel(block_N, dim_v):
if k_base * block_N + i < q_current_seqlen:
do[i, d] = dO[q_start_idx + k_base * block_N + i, bx, d]
else:
do[i, d] = 0.0
T.clear(dsT)
# dsT: (block_kv, block_q)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
for i in T.Parallel(block_N):
if k_base * block_N + i < q_current_seqlen:
delta[i] = Delta[q_start_idx + k_base * block_N + i, bx]
else:
delta[i] = 0.0
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
T.atomic_add(
dQ[q_start_idx + k_base * block_N:q_start_idx + k_base * block_N + block_N,
bx, :],
dq,
memory_order="release")
T.atomic_add(
dV[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
bx // groups, :],
dv,
memory_order="release")
T.atomic_add(
dK[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
bx // groups, :],
dk,
memory_order="release")
return flash_bwd
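# As in the fixed-length atomic variant, dQ/dK/dV are fp32 accumulators updated with
# tile-wide T.atomic_add calls (here with memory_order="release"); indices are offset
# by q_start_idx / k_start_idx so each sequence only touches its own slice of the
# packed gradient buffers, and the masking above zeroes contributions from rows beyond
# the current sequence length.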
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split(batch,
total_q,
total_kv,
heads,
max_seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [total_q, heads, dim_qk]
k_shape = [total_kv, head_kv, dim_qk]
v_shape = [total_kv, head_kv, dim_v]
do_shape = [total_q, heads, dim_v]
dk_shape = [groups, total_kv, head_kv, dim_qk] # sum after kernel
dv_shape = [groups, total_kv, head_kv, dim_v] # sum after kernel
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor(do_shape, dtype), # type: ignore
lse: T.Tensor([total_q, heads], accum_dtype), # type: ignore
Delta: T.Tensor([total_q, heads], accum_dtype), # type: ignore
cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore
cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(dk_shape, dtype), # type: ignore
dV: T.Tensor(dv_shape, dtype), # type: ignore
):
with T.Kernel(
heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], dtype)
q_start_idx = cu_seqlens_q[bz]
k_start_idx = cu_seqlens_k[bz]
q_end_idx = cu_seqlens_q[bz + 1]
k_end_idx = cu_seqlens_k[bz + 1]
q_current_seqlen = q_end_idx - q_start_idx
k_current_seqlen = k_end_idx - k_start_idx
T.annotate_layout({
# dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
})
for i, d in T.Parallel(block_M, dim_qk):
if by * block_M + i < k_current_seqlen:
K_shared[i, d] = K[k_start_idx + by * block_M + i, bx // groups, d]
V_shared[i, d] = V[k_start_idx + by * block_M + i, bx // groups, d]
else:
K_shared[i, d] = 0.0
V_shared[i, d] = 0.0
T.clear(dv)
T.clear(dk)
loop_st = (T.floordiv(by * block_M, block_N) if is_causal else 0)
loop_ed = T.ceildiv(q_current_seqlen, block_N)
for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
for i, d in T.Parallel(block_N, dim_qk):
if k_base * block_N + i < q_current_seqlen:
q[i, d] = Q[q_start_idx + k_base * block_N + i, bx, d]
else:
q[i, d] = 0.0
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, d in T.Parallel(block_N, dim_v):
if k_base * block_N + i < q_current_seqlen:
do[i, d] = dO[q_start_idx + k_base * block_N + i, bx, d]
else:
do[i, d] = 0.0
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i in T.Parallel(block_N):
if k_base * block_N + i < q_current_seqlen:
lse_shared[i] = lse[q_start_idx + k_base * block_N + i, bx]
else:
lse_shared[i] = 0.0
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else((by * block_M + i <= k_base * block_N + j) and
(by * block_M + i < k_current_seqlen and
k_base * block_N + j < q_current_seqlen),
qkT[i, j], 0)
else:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(
by * block_M + i < k_current_seqlen and
k_base * block_N + j < q_current_seqlen, qkT[i, j], 0)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
for i in T.Parallel(block_N):
if k_base * block_N + i < q_current_seqlen:
delta[i] = Delta[q_start_idx + k_base * block_N + i, bx]
else:
delta[i] = 0.0
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk):
if k_base * block_N + i < q_current_seqlen:
T.atomic_add(
dQ[q_start_idx + k_base * block_N + i, bx, j],
dq[i, j],
memory_order="release")
T.copy(dv, dv_shared)
for i, d in T.Parallel(block_M, dim_v):
if by * block_M + i < k_current_seqlen:
dV[bx % groups, k_start_idx + by * block_M + i, bx // groups, d] = dv[i, d]
T.copy(dk, dk_shared)
for i, d in T.Parallel(block_M, dim_qk):
if by * block_M + i < k_current_seqlen:
dK[bx % groups, k_start_idx + by * block_M + i, bx // groups, d] = dk[i, d]
return flash_bwd
@torch.compile
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx,
q,
k,
v,
seqlens_q,
seqlens_k,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
causal,
groups=1,
use_atomic=True):
BATCH, N_CTX, H, D_HEAD_QK = q.shape
D_HEAD_V = v.shape[-1]
block_M = 128
block_N = 64
q_unpad, indices_q, _, _ = unpad_input(
q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1)))
k_unpad, indices_k, _, _ = unpad_input(
k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1)))
v_unpad, _, _, _ = unpad_input(
v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1)))
total_q = q_unpad.shape[0]
total_kv = k_unpad.shape[0]
mod = flashattn_fwd(BATCH, total_q, total_kv, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal,
block_M, block_N, groups)
o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k)
o = pad_input(o_unpad, indices_q, BATCH, N_CTX)
ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k,
cu_seqlens_q, cu_seqlens_k)
ctx.causal = causal
ctx.use_atomic = use_atomic
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.indices_q = indices_q
ctx.indices_k = indices_k
return o
@staticmethod
def backward(ctx, do):
N_CTX = do.shape[1]
q, k, v, o, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
do_unpad, _, _, _ = unpad_input(
do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1)))
total_q, H, D_HEAD_QK = q.shape
total_kv, HEAD_KV, D_HEAD_V = v.shape
groups = H // HEAD_KV
BATCH = len(cu_seqlens_q) - 1
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do_unpad, q, k, v, o)]
block_M = 128
block_N = 32
mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, ctx.max_seqlen_q, D_HEAD_V)
mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V)
delta = mod_prep(o, do, cu_seqlens_q)
if ctx.use_atomic:
kernel = flashattn_bwd_atomic_add(
BATCH,
total_q,
total_kv,
H,
ctx.max_seqlen_q,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.zeros_like(k, dtype=torch.float32)
dv = torch.zeros_like(v, dtype=torch.float32)
kernel(q, k, v, do, lse, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
dq, dk, dv = mod_post(dq, dk, dv)
else:
kernel = flashattn_bwd_split(
BATCH,
total_q,
total_kv,
H,
ctx.max_seqlen_q,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device)
dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32),
torch.zeros_like(v, dtype=torch.float32))
dk, dv = dk.sum(0), dv.sum(0)
dq = pad_input(dq, ctx.indices_q, BATCH, N_CTX)
dk = pad_input(dk, ctx.indices_k, BATCH, N_CTX)
dv = pad_input(dv, ctx.indices_k, BATCH, N_CTX)
return dq, dk, dv, None, None, None, None, None, None, None, None, None
attention = _attention.apply
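# Minimal usage sketch for the variable-length wrapper above (values are illustrative only):
# q is [B, T, HQ, Dqk], k/v are [B, T, HKV, Dqk]/[B, T, HKV, Dv] with HQ = HKV * groups;
# seqlens_* hold each sequence's valid length and cu_seqlens_* their exclusive prefix sums,
# mirroring how main() below builds them, e.g.
#   seqlens_q = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
#   cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, 0, dtype=torch.int32), (1, 0))  # [0, 3, 8]
#   out = attention(q, k, v, seqlens_q, seqlens_q, cu_seqlens_q, cu_seqlens_q,
#                   int(seqlens_q.max()), int(seqlens_q.max()), causal, groups)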
def ref_program(Q, K, V, padding_mask, is_causal, groups=1):
# Q: [B, T, HQ, D_QK]
# K: [B, T, HK, D_QK]
# V: [B, T, HV, D_V]
# HQ = HKV * groups
    # Cast to float32 to avoid precision issues in the reference computation
Q, K, V = Q.float(), K.float(), V.float()
assert Q.size(2) == K.size(
2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(
2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim_qk = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
if padding_mask is not None:
scores.masked_fill_(rearrange(~padding_mask, "b s -> b 1 1 s"), float("-inf"))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
if padding_mask is not None:
output.masked_fill_(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0)
return output
def main(BATCH: int = 1,
H: int = 32,
N_CTX: int = 256,
D_HEAD_QK: int = 192,
D_HEAD_V: int = 128,
groups: int = 16,
causal: bool = False,
use_atomic: bool = True):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v
if causal:
total_flops *= 0.5
Q = (
torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half,
device="cuda").normal_().requires_grad_())
head_kv = H // groups
K = (
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half,
device="cuda").normal_().requires_grad_())
V = (
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
dO = (
torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
padding_mask = generate_random_padding_mask(N_CTX, BATCH, "cuda", mode="random")
seqlens_q = padding_mask.sum(dim=-1, dtype=torch.int32)
cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
max_seqlen_q = seqlens_q.max().item()
# In training backward pass, seqlens_k should be the same as seqlens_q
seqlens_k, cu_seqlens_k, max_seqlen_k = seqlens_q, cu_seqlens_q, max_seqlen_q
O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
max_seqlen_k, causal, groups, use_atomic)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, padding_mask, causal, groups)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print('All checks passed.✅')
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
from tilelang.profiler import do_bench
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
arch = nvcc.get_target_compute_version()
print(f"Detected GPU compute capability: {arch}")
assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0"
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size')
parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
parser.add_argument('--causal', action='store_true', help='Causal flag')
parser.add_argument('--groups', type=int, default=16, help='groups')
parser.add_argument(
'--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
parser.add_argument(
'--use_split', action='store_true', default=False, help='Use split for dK/dV')
args = parser.parse_args()
# Handle backward compatibility and logic
if args.use_split:
use_atomic = False
elif args.use_atomic:
use_atomic = True
else:
# Default: use split
use_atomic = False
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
use_atomic)
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
import argparse
@tilelang.jit(
out_idx=[3, 4], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
K_shared = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
return flash_fwd
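# The forward kernel above applies the standard online-softmax recurrence. For each new K/V
# tile, with scale = log2(e) / sqrt(dim_qk):
#   m_new  = max(m_old, rowmax(S))
#   acc_o  = acc_o * exp2((m_old - m_new) * scale) + exp2(S * scale - m_new * scale) @ V
#   l_new  = l_old * exp2((m_old - m_new) * scale) + rowsum(exp2(S * scale - m_new * scale))
# exp() is expressed through exp2 so the subtraction folds into a fused multiply-add. The lse
# output is stored in base 2 as log2(l) + m * scale and is consumed as-is by the backward pass.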
@tilelang.jit(
out_idx=[2], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
dtype = "float16"
accum_dtype = "float"
shape = [batch, seq_len, heads, dim_v]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim_v, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
return flash_bwd_prep
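# flash_bwd_prep computes Delta[b, h, t] = sum_d O[b, t, h, d] * dO[b, t, h, d], the per-row
# correction term used by the backward kernel in dS = P * (dP - Delta).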
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(batch,
heads,
seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)
T.annotate_layout({
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(
K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(
V_shared,
do,
dsT,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
wg_wait=-1)
T.wait_wgmma(1)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.wait_wgmma(0)
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
T.wait_wgmma(0)
T.copy(dq, dq_shared)
T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared)
T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared)
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)
return flash_bwd
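# In this backward kernel every (head, kv-block) program recomputes P for its tile, keeps
# dK/dV in registers, and folds all three gradients into float32 global buffers with
# T.atomic_add: every kv-block touches the same dQ rows, and with GQA several query heads map
# onto the same dK/dV head (bx // groups). The wg_wait arguments and T.wait_wgmma calls let
# the warp-group GEMMs overlap with the surrounding element-wise work instead of serializing.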
@torch.compile
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
BATCH, N_CTX, H, D_HEAD_QK = q.shape
D_HEAD_V = v.shape[-1]
block_M = 128
block_N = 64
mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
o, lse = mod(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
ctx.use_atomic = use_atomic
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD_QK = q.shape
        HEAD_KV, D_HEAD_V = v.shape[-2], v.shape[-1]
groups = H // HEAD_KV
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 128
block_N = 32
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
delta = mod_prep(o, do)
kernel = flashattn_bwd(
BATCH,
H,
N_CTX,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = dq.to(torch.float16)
dk = dk.to(torch.float16)
dv = dv.to(torch.float16)
return dq, dk, dv, None, None, None
attention = _attention.apply
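# Minimal usage sketch for the fixed-length GQA wrapper above (shapes as in main() below):
#   q = torch.randn(B, T, HQ, Dqk, dtype=torch.half, device="cuda", requires_grad=True)
#   k = torch.randn(B, T, HQ // groups, Dqk, dtype=torch.half, device="cuda", requires_grad=True)
#   v = torch.randn(B, T, HQ // groups, Dv, dtype=torch.half, device="cuda", requires_grad=True)
#   out = attention(q, k, v, causal, groups)   # use_atomic defaults to True
#   out.backward(torch.randn_like(out))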
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D_QK]
# K: [B, T, HK, D_QK]
# V: [B, T, HV, D_V]
# HQ = HKV * groups
assert Q.size(2) == K.size(
2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(
2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim_qk = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
def main(BATCH: int = 1,
H: int = 32,
N_CTX: int = 256,
D_HEAD_QK: int = 192,
D_HEAD_V: int = 128,
groups: int = 16,
causal: bool = False):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v
if causal:
total_flops *= 0.5
Q = (
torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half,
device="cuda").normal_().requires_grad_())
head_kv = H // groups
K = (
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half,
device="cuda").normal_().requires_grad_())
V = (
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
dO = (
torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
O = attention(Q, K, V, causal, groups)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal, groups)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print('All checks passed.✅')
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
from tilelang.profiler import do_bench
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size')
parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K')
parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V')
parser.add_argument('--causal', action='store_true', help='Causal flag')
parser.add_argument('--groups', type=int, default=16, help='groups')
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
class FlashAttentionTuneSpace:
def __init__(
self,
block_sizes=(64, 128, 256),
thread_options=(128, 256, 512),
num_stages_range=(2, 3),
max_shared_mem=100 * 1024,
warp_alignment=16,
dim=128,
dtype_bytes=2,
):
self.block_sizes = block_sizes
self.thread_options = thread_options
self.num_stages_range = num_stages_range
self.max_shared_mem = max_shared_mem
self.warp_alignment = warp_alignment
self.dim = dim
self.dtype_bytes = dtype_bytes
def get_configs(user_config=None):
config = user_config or FlashAttentionTuneSpace()
valid_configs = []
for block_M, block_N in itertools.product(config.block_sizes, repeat=2):
for threads in config.thread_options:
assert threads % 32 == 0
warp_count = threads // 32
warp_M = block_M // warp_count
warp_N = block_N // warp_count
if (warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0):
continue
shared_mem = 2 * config.dtype_bytes * config.dim * (block_M + block_N)
if shared_mem > config.max_shared_mem:
continue
for num_stages in config.num_stages_range:
valid_configs.append({
"block_M": block_M,
"block_N": block_N,
"num_stages": num_stages,
"threads": threads,
})
return valid_configs
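# The shared-memory filter above is a coarse per-block estimate,
# 2 * dtype_bytes * dim * (block_M + block_N) bytes. For example, with dim=128 and
# block_M = block_N = 128 it gives 2*2*128*256 = 128 KiB, which exceeds the 100 KiB budget
# and is pruned, while block_M=64, block_N=128 gives 96 KiB and is kept.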
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch,
heads,
seq_len,
dim,
is_causal,
groups=1,
block_M=64,
block_N=64,
num_stages=0,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
        # To do a causal softmax we need to set scores_max to 0 when it is still -inf.
        # This step is called Check_inf in the FlashAttention-3 code, and it only needs to
        # be done in the first ceil_div(kBlockM, kBlockN) iterations.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
            # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - max * log_2(e)).
            # This allows the compiler to use an ffma instruction instead of separate fadd
            # and fmul instructions.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return main
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D]
# K: [B, T, HK, D]
# V: [B, T, HV, D]
# HQ = HKV * groups
assert Q.size(2) == K.size(
2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(
2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
def main(batch: int = 1,
heads: int = 64,
seq_len: int = 4096,
dim: int = 128,
is_causal: bool = False,
groups: int = 16,
tune: bool = False):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if (not tune):
kernel = flashattn(
batch,
heads,
seq_len,
dim,
is_causal,
groups=groups,
block_M=64,
block_N=64,
num_stages=2,
threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=64, help='heads')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument('--groups', type=int, default=16, help='groups')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
def get_configs():
iter_params = dict(
block_M=[128],
block_N=[128],
num_stages=[2],
threads=[256],
)
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(
configs=get_configs(),
warmup=10,
rep=10,
)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(
batch,
heads,
seq_len,
dim,
is_causal,
groups=1,
block_M=64,
block_N=64,
num_stages=0,
threads=128,
):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
        # To do a causal softmax we need to set scores_max to 0 when it is still -inf.
        # This step is called Check_inf in the FlashAttention-3 code, and it only needs to
        # be done in the first ceil_div(kBlockM, kBlockN) iterations.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
            # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - max * log_2(e)).
            # This allows the compiler to use an ffma instruction instead of separate fadd
            # and fmul instructions.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(
loop_range,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return main
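# Compared with the previous example, T.Pipelined here carries an explicit hand-written
# schedule: `group` partitions the statements of the loop body (the copies and compute emitted
# by MMA0/Softmax/Rescale/MMA1) into units, and `order`/`stage` assign each unit an issue order
# and software-pipeline stage so that the K/V copies for the next tile can be prefetched while
# the current tile's softmax and GEMMs are still in flight. The exact indices are tuned for
# this particular loop body and block configuration.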
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D]
# K: [B, T, HK, D]
# V: [B, T, HV, D]
# HQ = HKV * groups
assert Q.size(2) == K.size(
2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(
2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
def main(
batch: int = 1,
heads: int = 64,
seq_len: int = 4096,
dim: int = 128,
is_causal: bool = False,
groups: int = 16,
tune: bool = False,
):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if (not tune):
kernel = flashattn(
batch,
heads,
seq_len,
dim,
is_causal,
groups=groups,
block_M=128,
block_N=128,
num_stages=2,
threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=64, help='heads')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument('--groups', type=int, default=16, help='groups')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune)
# ruff: noqa
import argparse
import torch
import tilelang
import tilelang.language as T
import tilelang.testing
from einops import rearrange, repeat
from tilelang.profiler import do_bench
from varlen_utils import generate_random_padding_mask, generate_qkv
def attention_ref(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
causal=False,
window_size=(-1, -1),
upcast=True,
):
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
dim = q.shape[-1]
scale = (1.0 / dim)**0.5
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
scores = torch.einsum("bthd,bshd->bhts", q, k)
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
scores = scores * scale
attention = torch.softmax(scores, dim=-1).to(v.dtype)
if query_padding_mask is not None:
attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
output = torch.einsum("bhts,bshd->bthd", attention, v)
if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
@tilelang.jit(
out_idx=[6], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch_size,
groups,
UQ,
UKV,
heads,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [UQ, heads, dim]
kv_shape = [UKV, head_kv, dim]
o_shape = [UQ, heads, dim]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def main(
Q_unpad: T.Tensor(q_shape, dtype),
K_unpad: T.Tensor(kv_shape, dtype),
V_unpad: T.Tensor(kv_shape, dtype),
cu_seqlens_q: T.Tensor([batch_size + 1], "int32"),
cu_seqlens_k: T.Tensor([batch_size + 1], "int32"),
max_seqlen_q: T.int32,
Output_unpad: T.Tensor(o_shape, dtype),
):
with T.Kernel(
T.ceildiv(max_seqlen_q, block_M), heads, batch_size,
threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
batch_idx = bz
head_idx = by
kv_head_idx = head_idx // groups
q_start_idx = cu_seqlens_q[batch_idx]
k_start_idx = cu_seqlens_k[batch_idx]
v_start_idx = cu_seqlens_k[batch_idx]
q_end_idx = cu_seqlens_q[batch_idx + 1]
k_end_idx = cu_seqlens_k[batch_idx + 1]
v_end_idx = cu_seqlens_k[batch_idx + 1]
q_current_seqlen = q_end_idx - q_start_idx
k_current_seqlen = k_end_idx - k_start_idx
v_current_seqlen = v_end_idx - v_start_idx
T.copy(
Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :],
Q_shared)
for i, d in T.Parallel(block_M, dim):
if bx * block_M + i >= q_current_seqlen:
Q_shared[i, d] = 0
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(k_current_seqlen, block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N,
kv_head_idx, :], K_shared)
for i, d in T.Parallel(block_N, dim):
if k * block_N + i >= k_current_seqlen:
K_shared[i, d] = 0
                if is_causal:
                    for i, j in T.Parallel(block_M, block_N):
                        # Keep a score only when it is causally visible and both indices are
                        # within the valid lengths of this sequence; otherwise pre-fill -inf.
                        acc_s[i, j] = T.if_then_else(
                            (bx * block_M + i >= k * block_N + j) and
                            (bx * block_M + i < q_current_seqlen) and
                            (k * block_N + j < k_current_seqlen), 0, -T.infinity(acc_s.dtype))
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
k * block_N + j >= k_current_seqlen),
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
T.copy(
V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N,
kv_head_idx, :], V_shared)
for i, d in T.Parallel(block_N, dim):
if k * block_N + i >= v_current_seqlen:
V_shared[i, d] = 0
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
for i, d in T.Parallel(block_M, dim):
if bx * block_M + i < q_current_seqlen:
Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d]
return main
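# Layout reminder for the kernel above: Q_unpad/K_unpad/V_unpad pack all sequences of the
# batch back-to-back along the first axis, and cu_seqlens_q/cu_seqlens_k hold the exclusive
# prefix sums of the per-sequence lengths, e.g. cu_seqlens_q = [0, 3, 7] means sequence 0
# occupies rows 0..2 and sequence 1 rows 3..6 of Q_unpad. Each program handles one
# (query-block, head, sequence) triple and masks rows/columns past the sequence's true length.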
def main(batch: int = 1,
heads: int = 64,
q_seqlen: int = 2048,
k_seqlen: int = 2048,
dim: int = 128,
groups: int = 16,
is_causal: bool = False):
assert heads % groups == 0, "heads must be divisible by groups"
flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim
total_flops = 2 * flops_per_matmul
    tilelang.testing.set_random_seed(0)
    if is_causal:
        total_flops *= 0.5
dtype = torch.float16
device = torch.device("cuda")
head_kv = heads // groups
q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True)
k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random")
key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random")
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
_,
_,
) = generate_qkv(
q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
UQ = q_unpad.shape[0]
UKV = k_unpad.shape[0]
kernel = flashattn(
batch,
groups,
UQ,
UKV,
heads,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128)
out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
out = output_pad_fn(out_unpad)
out_ref, _ = attention_ref(
q,
k,
v,
query_padding_mask=query_padding_mask,
key_padding_mask=key_padding_mask,
causal=is_causal,
)
torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
latency = do_bench(
lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q))
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=64, help='query heads')
parser.add_argument('--groups', type=int, default=16, help='groups')
parser.add_argument('--q_seqlen', type=int, default=2048, help='query sequence length')
parser.add_argument('--k_seqlen', type=int, default=2048, help='key/value sequence length')
parser.add_argument('--dim', type=int, default=128, help='head dim')
parser.add_argument('--is_causal', action='store_true', help='causal attention')
args = parser.parse_args()
main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups,
args.is_causal)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import argparse
@tilelang.jit(
out_idx=[3, 4], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_fwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
Output: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
# Q_local = T.alloc_fragment([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
return flash_fwd
@tilelang.jit(
out_idx=[2], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
shape = [batch, seq_len, heads, dim]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
return flash_bwd_prep
def make_dq_layout(dQ):
    # atomicAdd cannot be vectorized, so we reorder dQ to match the 8x8 GEMM accumulator fragment.
return T.Layout(dQ.shape,
lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
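# Rough intuition for the layout above: dQ elements are regrouped by (l // 8, d // 8) into 8x8
# tiles shaped like the MMA accumulator fragment, so the scalar atomicAdds issued from
# flash_bwd follow that fragment layout. flashattn_bwd_postprocess annotates the same layout
# on its input and copies the accumulated float32 buffer back out as a normally-strided fp16
# tensor.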
@tilelang.jit(
out_idx=[1], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
shape = [batch, seq_len, heads, dim]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(
dQ[bz, bx * blk:(bx + 1) * blk, by, :],
dQ_out[bz, bx * blk:(bx + 1) * blk, by, :],
)
return flash_bwd_post
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
# should not store K to local if dim is large
# K_local = T.alloc_fragment([block_M, dim], dtype)
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
# V_local = T.alloc_fragment([block_M, dim], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim], dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim], dtype)
dk_shared = T.alloc_shared([block_M, dim], dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim):
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :])
return flash_bwd
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal):
BATCH, N_CTX, H, D_HEAD = q.shape
block_M = 64
block_N = 64 if D_HEAD <= 128 else 32
o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD = q.shape
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 64
block_N = 64 if D_HEAD <= 64 else 32
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
delta = kernel_prep(o, do)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
shape = [BATCH, N_CTX, H, D_HEAD]
dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
dk = torch.empty(shape, dtype=torch.float16, device=q.device)
dv = torch.empty(shape, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = kernel_post(dq)
return dq, dk, dv, None
attention = _attention.apply
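# Usage sketch for this (non-GQA) wrapper, mirroring main() below:
#   q = torch.randn(B, T, H, D, dtype=torch.half, device="cuda", requires_grad=True)
#   k, v = torch.randn_like(q).requires_grad_(), torch.randn_like(q).requires_grad_()
#   out = attention(q, k, v, causal)
#   out.backward(torch.randn_like(out))
# dK/dV are written directly in fp16 by the kernel; dQ is accumulated in float32 via atomicAdd
# and converted back to fp16 by flashattn_bwd_postprocess.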
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
def main(
BATCH: int = 8,
H: int = 32,
N_CTX: int = 1024,
D_HEAD: int = 64,
causal: bool = False,
):
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 5 * flops_per_matmul
if causal:
total_flops *= 0.5
Q = (
torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half,
device="cuda").normal_().requires_grad_())
K = torch.empty_like(Q).normal_().requires_grad_()
V = torch.empty_like(Q).normal_().requires_grad_()
dO = torch.randn_like(Q)
O = attention(Q, K, V, causal)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
from tilelang.profiler import do_bench
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size')
parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
    parser.add_argument('--causal', action='store_true', help='Causal flag')
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import argparse
@tilelang.jit(
out_idx=[3, 4], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_fwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
Output: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
# Q_local = T.alloc_fragment([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
# T.copy(Q_shared, Q_local)
# for i, j in T.Parallel(block_M, dim):
# Q_local[i, j] *= scale
loop_range = (
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
return flash_fwd
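# The loop above is the standard online-softmax recurrence. With s = scale = log2(e) / sqrt(dim),
# running row max m and running denominator l, each incoming score tile S updates
#     m' = max(m, rowmax(S))
#     acc_o <- acc_o * exp2(m * s - m' * s) + exp2(S * s - m' * s) @ V
#     l    <- l * exp2(m * s - m' * s) + rowsum(exp2(S * s - m' * s))
# and the epilogue emits Output = acc_o / l and lse = log2(l) + m' * s (a base-2 log-sum-exp).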
@tilelang.jit(
out_idx=[2], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
shape = [batch, heads, seq_len, dim]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o)
T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
return flash_bwd_prep
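# flash_bwd_prep computes Delta[b, h, t] = sum_d O[b, h, t, d] * dO[b, h, t, d], the row-wise dot
# product of the output and its gradient; the main backward kernel uses it in
# dS = P * (dP - Delta), where dP = dO @ V^T (see the dsT_cast update below).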
def make_dq_layout(dQ):
# atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
return T.Layout(dQ.shape,
lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
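# Worked example of the index mapping above (illustrative indices): element (b, h, l, d) =
# (0, 0, 3, 5) maps to [0, 0, 3 // 8, 5 // 8, 5 % 2, 4 * (3 % 8) + (5 % 8) // 2] = [0, 0, 0, 0, 1, 14],
# i.e. dQ is tiled into 8x8 (l, d) blocks whose elements are ordered to follow the GEMM fragment
# mentioned in the comment above, presumably so the per-thread atomicAdds line up with that fragment.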
@tilelang.jit(
out_idx=[1], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
shape = [batch, heads, seq_len, dim]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(
dQ[bz, by, bx * blk:(bx + 1) * blk, :],
dQ_out[bz, by, bx * blk:(bx + 1) * blk, :],
)
return flash_bwd_post
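# dQ handling: the main backward kernel accumulates dQ in fp32 with T.atomic_add using the
# make_dq_layout ordering; flash_bwd_post above reads it back through the same layout annotation
# and emits a plain row-major fp16 dQ_out.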
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
# should not store K to local if dim is large
# K_local = T.alloc_fragment([block_M, dim], dtype)
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
# V_local = T.alloc_fragment([block_M, dim], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim], dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim], dtype)
dk_shared = T.alloc_shared([block_M, dim], dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
})
T.copy(K[bz, bx, by * block_M:(by + 1) * block_M, :], K_shared)
T.copy(V[bz, bx, by * block_M:(by + 1) * block_M, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim):
T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :])
T.copy(dk_shared, dK[bz, bx, by * block_M:(by + 1) * block_M, :])
return flash_bwd
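# Per tile, with i indexing the K/V rows resident in this block and j the streamed Q/dO rows, the
# kernel above implements the FlashAttention backward recurrences:
#     P[j, i]  = exp2(q_j . k_i * scale - lse[j])        (probabilities recomputed, never stored)
#     dV_i    += sum_j P[j, i] * dO_j
#     dS[j, i] = P[j, i] * (dO_j . v_i - Delta[j]) * sm_scale
#     dK_i    += sum_j dS[j, i] * q_j
#     dQ_j    += sum_i dS[j, i] * k_i                     (accumulated into global dQ via atomicAdd)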
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal):
BATCH, H, N_CTX, D_HEAD = q.shape
block_M = 64
block_N = 64 if D_HEAD <= 128 else 32
o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, H, N_CTX, D_HEAD = q.shape
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 64
block_N = 64 if D_HEAD <= 64 else 32
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
delta = kernel_prep(o, do)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
shape = [BATCH, H, N_CTX, D_HEAD]
dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
dk = torch.empty(shape, dtype=torch.float16, device=q.device)
dv = torch.empty(shape, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = kernel_post(dq)
return dq, dk, dv, None
attention = _attention.apply
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(2)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
return output
def main(
BATCH: int = 8,
H: int = 32,
N_CTX: int = 1024,
D_HEAD: int = 64,
causal: bool = False,
):
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 5 * flops_per_matmul
if causal:
total_flops *= 0.5
Q = (
torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half,
device="cuda").normal_().requires_grad_())
K = torch.empty_like(Q).normal_().requires_grad_()
V = torch.empty_like(Q).normal_().requires_grad_()
dO = torch.randn_like(Q)
O = attention(Q, K, V, causal)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
from tilelang.profiler import do_bench
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size')
parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
    parser.add_argument('--causal', action='store_true', help='Causal flag')
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.profiler import do_bench
import argparse
@tilelang.jit(
out_idx=[3, 4], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_fwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
Output: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
return flash_fwd
@tilelang.jit(
out_idx=[2], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
shape = [batch, seq_len, heads, dim]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
return flash_bwd_prep
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
# should not store K to local if dim is large
# K_local = T.alloc_fragment([block_M, dim], dtype)
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
# V_local = T.alloc_fragment([block_M, dim], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim], dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim], dtype)
dk_shared = T.alloc_shared([block_M, dim], dtype)
dq_shared = T.alloc_shared([block_N, dim], accum_dtype)
T.annotate_layout({
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(
K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(
V_shared,
do,
dsT,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
wg_wait=-1)
T.wait_wgmma(1)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.wait_wgmma(0)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
T.wait_wgmma(0)
T.copy(dq, dq_shared)
T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared)
T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :])
return flash_bwd
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal):
BATCH, N_CTX, H, D_HEAD = q.shape
block_M = 64
block_N = 64 if D_HEAD <= 128 else 32
mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)
o, lse = mod(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD = q.shape
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 128
block_N = 128 if D_HEAD <= 64 else 32
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
delta = mod_prep(o, do)
mod = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
shape = [BATCH, N_CTX, H, D_HEAD]
dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
dk = torch.empty(shape, dtype=torch.float16, device=q.device)
dv = torch.empty(shape, dtype=torch.float16, device=q.device)
mod(q, k, v, do, lse, delta, dq, dk, dv)
dq = dq.to(torch.float16)
return dq, dk, dv, None
attention = _attention.apply
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
def main(
BATCH: int = 8,
H: int = 32,
N_CTX: int = 1024,
D_HEAD: int = 64,
causal: bool = False,
):
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 5 * flops_per_matmul
if causal:
total_flops *= 0.5
Q = (
torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half,
device="cuda").normal_().requires_grad_())
K = torch.empty_like(Q).normal_().requires_grad_()
V = torch.empty_like(Q).normal_().requires_grad_()
dO = torch.randn_like(Q)
O = attention(Q, K, V, causal)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print('All checks passed.✅')
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size')
parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
    parser.add_argument('--causal', action='store_true', help='Causal flag')
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
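# get_configs() currently enumerates a single candidate; because the configs are built as a
# cartesian product, the search space can be widened by listing more values (hypothetical example):
#     iter_params = dict(block_M=[64, 128], block_N=[64, 128], num_stages=[1, 2], threads=[128, 256])
# which would give @autotune 16 configurations to benchmark.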
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
        # This process is called Check_inf in FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            # max * log_2(e)). This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv(
(bx + 1) * block_M +
past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return main
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_q = Q.size(2)
seq_kv = K.size(2)
mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
return output
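# The causal mask above is right-aligned to the end of the KV sequence: tril with
# diagonal = seq_kv - seq_q lets query t attend to keys 0..t + past_len. For example, with
# seq_q = 2 and seq_kv = 4 (past_len = 2), query 0 sees keys 0-2 and query 1 sees keys 0-3,
# matching q_idx = bx * block_M + i + past_len in MMA0.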
def main(
batch: int = 1,
heads: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 64,
is_causal: bool = False,
tune: bool = False,
):
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if (not tune):
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=1, help='heads')
parser.add_argument('--seq_q', type=int, default=256, help='query sequence length')
parser.add_argument('--seq_kv', type=int, default=256, help='key/value sequence length')
parser.add_argument('--dim', type=int, default=64, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
        # This process is called Check_inf in FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            # max * log_2(e)). This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv(
(bx + 1) * block_M +
past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
for k in T.Pipelined(
loop_range,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return main
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_q = Q.size(2)
seq_kv = K.size(2)
mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
return output
def main(
batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
is_causal: bool = False,
tune: bool = False,
):
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if (not tune):
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='query sequence length')
parser.add_argument('--seq_kv', type=int, default=4096, help='key/value sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
def get_configs():
iter_params = dict(block_M=[64], block_N=[64], num_stages=[1], threads=[128])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch,
heads,
seq_len,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
        # This process is called Check_inf in FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            # max * log_2(e)). This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return main
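# With block_M == block_N, the causal bound above simplifies to loop_range = bx + 1: query block bx
# only ever attends keys in blocks 0..bx, so later key blocks are skipped entirely; the T.min keeps
# the bound from exceeding ceildiv(seq_len, block_N).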
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
def main(
batch: int = 8,
heads: int = 32,
seq_len: int = 4096,
dim: int = 128,
is_causal: bool = False,
tune: bool = False,
):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if (not tune):
kernel = flashattn(
batch,
heads,
seq_len,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=1,
threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch,
heads,
seq_len,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
        # This process is called Check_inf in FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            # max * log_2(e)). This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(
loop_range,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return main
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
def main(
batch: int = 8,
heads: int = 32,
seq_len: int = 4096,
dim: int = 128,
is_causal: bool = False,
tune: bool = False,
):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if (not tune):
kernel = flashattn(
batch,
heads,
seq_len,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune)
# ruff: noqa
import torch
import tilelang
import tilelang.language as T
import tilelang.testing
import argparse
from einops import rearrange, repeat
from varlen_utils import generate_random_padding_mask, generate_qkv
def attention_ref(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
upcast=True,
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads_k, head_dim)
v: (batch_size, seqlen_k, nheads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
without changing the math. This is to estimate the numerical error from operation
reordering.
Output:
output: (batch_size, seqlen_q, nheads, head_dim)
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
"""
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
dim = q.shape[-1]
    scale = (1.0 / dim)**0.5  # softmax scaling only; no log2(e) factor in the reference
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
scores = torch.einsum("bthd,bshd->bhts", q, k)
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
# scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0)
scores = scores * scale
attention = torch.softmax(scores, dim=-1).to(v.dtype)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if query_padding_mask is not None:
attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
output = torch.einsum("bhts,bshd->bthd", attention, v)
if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
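# Hedged usage sketch (not part of the original file): calls attention_ref on tiny random CPU
# tensors to illustrate the shapes documented above; sizes are illustrative, not benchmarks.
def _attention_ref_demo():
    b, sq, sk, h, d = 2, 8, 8, 4, 16
    q = torch.randn(b, sq, h, d, dtype=torch.float16)
    k = torch.randn(b, sk, h, d, dtype=torch.float16)
    v = torch.randn(b, sk, h, d, dtype=torch.float16)
    out, attn = attention_ref(q, k, v, causal=True)
    assert out.shape == (b, sq, h, d)
    assert attn.shape == (b, h, sq, sk)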
@tilelang.jit(
out_idx=[6], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch_size,
UQ,
UKV,
heads,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=32):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [UQ, heads, dim]
k_shape = [UKV, heads, dim]
v_shape = [UKV, heads, dim]
o_shape = [UQ, heads, dim]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def main(
Q_unpad: T.Tensor(q_shape, dtype),
K_unpad: T.Tensor(k_shape, dtype),
V_unpad: T.Tensor(v_shape, dtype),
cu_seqlens_q: T.Tensor([batch_size + 1], "int32"),
cu_seqlens_k: T.Tensor([batch_size + 1], "int32"),
max_seqlen_q: T.int32,
Output_unpad: T.Tensor(o_shape, dtype),
):
with T.Kernel(
T.ceildiv(max_seqlen_q, block_M), heads, batch_size,
threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype, "shared")
K_shared = T.alloc_shared([block_N, dim], dtype, "shared")
V_shared = T.alloc_shared([block_N, dim], dtype, "shared")
O_shared = T.alloc_shared([block_M, dim], dtype, "shared")
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
batch_idx = bz
head_idx = by
q_start_idx = cu_seqlens_q[batch_idx]
k_start_idx = cu_seqlens_k[batch_idx]
v_start_idx = cu_seqlens_k[batch_idx]
q_end_idx = cu_seqlens_q[batch_idx + 1]
k_end_idx = cu_seqlens_k[batch_idx + 1]
v_end_idx = cu_seqlens_k[batch_idx + 1]
q_current_seqlen = q_end_idx - q_start_idx
k_current_seqlen = k_end_idx - k_start_idx
v_current_seqlen = v_end_idx - v_start_idx
for i, d in T.Parallel(block_M, dim):
if bx * block_M + i < q_current_seqlen:
Q_shared[i, d] = Q_unpad[q_start_idx + bx * block_M + i, head_idx, d]
else:
Q_shared[i, d] = 0
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(k_current_seqlen, block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages):
# Q * K
for i, d in T.Parallel(block_N, dim):
if k * block_N + i < k_current_seqlen:
K_shared[i, d] = K_unpad[k_start_idx + k * block_N + i, head_idx, d]
else:
K_shared[i, d] = 0
if is_causal:
for i, j in T.Parallel(block_M, block_N):
# Mask with -inf every position that violates causality or falls outside the valid
# query/key range; valid positions start from 0 before the QK^T accumulation.
acc_s[i, j] = T.if_then_else((bx * block_M + i < k * block_N + j) or
(bx * block_M + i >= q_current_seqlen) or
(k * block_N + j >= k_current_seqlen),
-T.infinity(acc_s.dtype), 0)
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
k * block_N + j >= k_current_seqlen),
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
# Softmax
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# For causal softmax we would need to reset scores_max to 0 when it is -inf.
# This step is called Check_inf in the FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) iterations.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) - max * log_2(e)).
# This allows the compiler to use an ffma instruction instead of a separate fadd and fmul
# (a small numeric check of this identity follows the kernel factory below).
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
# Rescale
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
# V * softmax(Q * K)
for i, d in T.grid(block_N, dim):
if k * block_N + i < v_current_seqlen:
V_shared[i, d] = V_unpad[v_start_idx + k * block_N + i, head_idx, d]
else:
V_shared[i, d] = 0
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
for i, d in T.Parallel(block_M, dim):
if bx * block_M + i < q_current_seqlen:
Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d]
return main
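# Hedged numeric sketch (not in the original file) of the exp2 / log2(e) rewrite used in the
# softmax above: exp(x - m) == exp2(x * log2(e) - m * log2(e)). Plain torch is used here only
# to check the identity; shapes and values are illustrative.
def _exp2_softmax_trick_demo():
    log2_e = 1.44269504
    x = torch.randn(4, 8, dtype=torch.float32)
    m = x.max(dim=-1, keepdim=True).values
    ref = torch.exp(x - m)
    via_exp2 = torch.exp2(x * log2_e - m * log2_e)
    torch.testing.assert_close(ref, via_exp2, rtol=1e-5, atol=1e-6)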
def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
tilelang.testing.set_random_seed(0)
causal = False
if causal:
total_flops *= 0.5
dtype = torch.float16
device = torch.device("cuda")
window_size = (-1, -1)
q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
v = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device)
query_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random")
key_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random")
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(
q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
UQ = q_unpad.shape[0] # unpadded query length
UK = k_unpad.shape[0] # unpadded key length
UKV = k_unpad.shape[0]  # total number of unpadded key/value tokens
kernel = flashattn(batch, UQ, UKV, heads, dim, causal)
out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
out = output_pad_fn(out_unpad)
out_ref, _ = attention_ref(
q,
k,
v,
query_padding_mask,
key_padding_mask,
causal=causal,
)
torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
import flash_attn
fla_out_unpad = flash_attn.flash_attn_varlen_func(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
0.0,
causal=causal,
)
fla_out = output_pad_fn(fla_out_unpad)
torch.testing.assert_close(out, fla_out, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=64, help='heads')
parser.add_argument('--seq_len', type=int, default=2048, help='sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim)
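# Hedged sketch (not in the original file) of the cu_seqlens_* layout consumed by the varlen
# kernel above: they are exclusive prefix sums of the per-sequence lengths, so sequence b owns
# rows cu[b]:cu[b+1] of the unpadded (total_tokens, heads, dim) buffers. Lengths are illustrative.
def _cu_seqlens_demo():
    lengths = torch.tensor([3, 5, 2], dtype=torch.int32)
    cu = torch.zeros(lengths.numel() + 1, dtype=torch.int32)
    cu[1:] = torch.cumsum(lengths, dim=0)
    assert cu.tolist() == [0, 3, 8, 10]
    assert int(cu[-1]) == int(lengths.sum())  # total unpadded tokens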
import tilelang.testing
import example_gqa_bwd
import example_gqa_bwd_wgmma_pipelined
import example_mha_bwd
import example_mha_bwd_bhsd
import example_mha_fwd_bhsd_wgmma_pipelined
import example_gqa_fwd_bshd
import example_mha_fwd_bshd
import example_gqa_fwd_bshd_wgmma_pipelined
import example_mha_fwd_bshd_wgmma_pipelined
import example_mha_fwd_varlen
import example_mha_bwd_wgmma_pipelined
import example_mha_fwd_bhsd
import example_gqa_bwd_tma_reduce_varlen
@tilelang.testing.requires_cuda
def test_example_gqa_bwd_tma_reduce_varlen():
example_gqa_bwd_tma_reduce_varlen.main()
@tilelang.testing.requires_cuda
def test_example_gqa_bwd():
example_gqa_bwd.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_gqa_bwd_wgmma_pipelined():
example_gqa_bwd_wgmma_pipelined.main()
@tilelang.testing.requires_cuda
def test_example_mha_bwd():
example_mha_bwd.main(BATCH=1)
@tilelang.testing.requires_cuda
def test_example_mha_bwd_bhsd():
example_mha_bwd_bhsd.main(BATCH=1)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_bwd_wgmma_pipelined():
example_mha_bwd_wgmma_pipelined.main(BATCH=1)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_gqa_fwd_bshd_wgmma_pipelined():
example_gqa_fwd_bshd_wgmma_pipelined.main(
batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)
@tilelang.testing.requires_cuda
def test_example_gqa_fwd_bshd():
example_gqa_fwd_bshd.main(
batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_fwd_bhsd_wgmma_pipelined():
example_mha_fwd_bhsd_wgmma_pipelined.main()
@tilelang.testing.requires_cuda
def test_example_mha_fwd_bhsd():
example_mha_fwd_bhsd.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_fwd_bshd_wgmma_pipelined():
example_mha_fwd_bshd_wgmma_pipelined.main(batch=1, heads=32, seq_len=256)
@tilelang.testing.requires_cuda
def test_example_mha_fwd_bshd():
example_mha_fwd_bshd.main(batch=1, seq_len=256)
@tilelang.testing.requires_cuda
def test_example_mha_fwd_varlen():
example_mha_fwd_varlen.main()
if __name__ == "__main__":
tilelang.testing.main()
# ruff: noqa
import torch
from einops import rearrange, repeat
from bert_padding import pad_input, unpad_input
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
assert mode in ["full", "random", "third"]
if mode == "full":
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
elif mode == "random":
lengths = torch.randint(
max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
elif mode == "third":
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
padding_mask = (
repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths)
return padding_mask
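# Hedged usage sketch: the mask returned above is a (batch_size, max_seqlen) bool tensor with
# True marking valid (non-padded) tokens; "full" mode keeps every position. CPU-only, illustrative.
def _padding_mask_demo():
    mask = generate_random_padding_mask(16, 2, torch.device("cpu"), mode="full")
    assert mask.shape == (2, 16) and mask.dtype == torch.bool and bool(mask.all())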
def generate_qkv(q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange(
output_unpad, "(b s) h d -> b s h d", b=batch_size)
if key_padding_mask is not None:
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device)
max_seqlen_k = seqlen_k
if qkvpacked:
assert (query_padding_mask == key_padding_mask).all()
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
if query_padding_mask is not None:
dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
else:
dqkv_pad_fn = lambda dqkv_unpad: rearrange(
dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
return (
qkv_unpad.detach().requires_grad_(),
cu_seqlens_q,
max_seqlen_q,
qkv.detach().requires_grad_(),
output_pad_fn,
dqkv_pad_fn,
)
elif kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
else:
dkv_pad_fn = lambda dkv_unpad: rearrange(
dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
return (
q_unpad.detach().requires_grad_(),
kv_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
kv.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dkv_pad_fn,
)
else:
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
else:
dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
return (
q_unpad.detach().requires_grad_(),
k_unpad.detach().requires_grad_(),
v_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
k.detach().requires_grad_(),
v.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
)
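# Hedged usage sketch (kvpacked=False path only, not part of the original file): unpads a tiny
# padded batch and checks the shapes implied by the docstring above. Assumes bert_padding's
# unpad_input/pad_input (imported at the top of this file) work on CPU tensors; sizes illustrative.
def _generate_qkv_demo():
    b, s, h, d = 2, 16, 4, 8
    q = torch.randn(b, s, h, d, dtype=torch.float16)
    k = torch.randn(b, s, h, d, dtype=torch.float16)
    v = torch.randn(b, s, h, d, dtype=torch.float16)
    q_mask = generate_random_padding_mask(s, b, torch.device("cpu"), mode="random")
    k_mask = generate_random_padding_mask(s, b, torch.device("cpu"), mode="random")
    (q_unpad, k_unpad, v_unpad, cu_q, cu_k, max_q, max_k, *_rest) = generate_qkv(
        q, k, v, q_mask, k_mask, kvpacked=False)
    assert cu_q.shape == (b + 1,) and cu_k.shape == (b + 1,)
    assert max_q <= s and max_k <= s
    assert q_unpad.shape[0] == int(q_mask.sum()) and q_unpad.shape[1:] == (h, d)
    assert k_unpad.shape[0] == int(k_mask.sum()) and v_unpad.shape == k_unpad.shape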
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse
import itertools
from functools import lru_cache
from typing import Tuple, Dict
torch.random.manual_seed(0)
def get_configs():
block_N = [64, 128]
block_H = [64]
num_split = [2, 4, 8]
num_stages = [1, 2, 3]
threads = [128]
_configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
configs = [{
'block_N': c[0],
'block_H': c[1],
'num_split': c[2],
'num_stages': c[3],
'threads': c[4]
} for c in _configs]
return configs
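# Hedged sketch: get_configs enumerates the full Cartesian product above, i.e.
# 2 (block_N) * 1 (block_H) * 3 (num_split) * 3 (num_stages) * 1 (threads) = 18 candidates.
def _configs_demo():
    cfgs = get_configs()
    assert len(cfgs) == 18
    assert set(cfgs[0]) == {"block_N", "block_H", "num_split", "num_stages", "threads"}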
@lru_cache(maxsize=1)
def get_heuristic_config() -> Tuple[Dict, int]:
# Get CUDA device properties
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
device = torch.cuda.current_device()
sm_major, sm_minor = torch.cuda.get_device_capability(device)
sm_version = sm_major * 10 + sm_minor
print(f"CUDA device capability: {sm_version}")
if sm_version == 89:
cfg = dict(block_N=128, block_H=64, num_split=16, num_stages=0, threads=128)
else:
cfg = dict(block_N=128, block_H=64, num_split=16, num_stages=2, threads=128)
return cfg, sm_version
# TODO(lei): fix warp specialized and tma lower pass
def get_pass_configs():
return {
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
}
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[6], pass_configs=get_pass_configs())
def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages,
threads):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, dim]
shape_k = [batch, seqlen_kv, groups, dim]
shape_v = [batch, seqlen_kv, groups, dim]
shape_o = [batch, heads, dim]
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // groups
part_shape = [batch, heads, num_split, dim]
valid_block_H = min(block_H, kv_group_num)
valid_block_N = min(block_N, seqlen_kv // num_split)
@T.macro
def flash_attn(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([valid_block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
mask_local = T.alloc_fragment([block_N], "uint8")
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
bid = bx
hid = by
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared)
T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j],
-T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.copy(V[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])
@T.macro
def flash_attn_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([valid_block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
mask_local = T.alloc_fragment([block_N], "uint8")
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
bid = bx
hid = by
sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
K[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head, :], K_shared)
T.copy(
mask[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head], mask_local)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split),
acc_s[i, j], -T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.copy(
V[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
for i in T.Parallel(block_H):
if i < valid_block_H:
glse[bid, hid * valid_block_H + i, sid] = logsum[i]
T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output_partial[bid, hid * valid_block_H:(hid + 1) * valid_block_H,
sid, :])
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim], dtype)
o_accum_local = T.alloc_fragment([dim], accum_dtype)
lse_local = T.alloc_fragment([num_split, 128], dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_fragment([128], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
# lse_local: (local_id, thread_id)
lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)),
})
T.clear(lse_logsum_local)
T.clear(o_accum_local)
for k, j in T.Parallel(num_split, 128):
lse_local[k, j] = glse[bz, by, k]
T.reduce_max(lse_local, lse_max_local, dim=0, clear=True)
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dim):
Output[bz, by, i] = o_accum_local[i]
@T.prim_func
def flashattn_gqa_decode_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype),
):
flash_attn_split(Q, K, V, mask, glse, Output_partial)
combine(glse, Output_partial, Output)
@T.prim_func
def flashattn_gqa_decode_no_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype),
):
flash_attn(Q, K, V, mask, Output)
if num_split > 1:
return flashattn_gqa_decode_split
else:
return flashattn_gqa_decode_no_split
def ref_program(query, key, value, mask, glse, Output_partial):
# """
# Inputs:
# - query (Tensor): [batch, heads, dim]
# - key (Tensor): [batch, seqlen_kv, groups, dim]
# - value (Tensor): [batch, seqlen_kv, groups, dim]
# - mask (Tensor): [batch, seqlen_kv, groups]
# Outputs:
# - output (Tensor): [batch, heads, dim]
# """
dim = query.shape[-1]
num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim]
value = rearrange(value, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim]
query = rearrange(
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
if mask is not None:
mask = rearrange(mask, 'b s h -> b h s')
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, value,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out
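# Hedged sketch (not in the original file) of the GQA head grouping used in ref_program above:
# the `heads` query heads are viewed as (groups, heads // groups), so each of the `groups` KV
# heads serves heads // groups query heads. einops.rearrange only reshapes; values are untouched.
def _gqa_grouping_demo():
    batch, heads, groups, dim = 1, 8, 2, 4
    q = torch.randn(batch, heads, dim)
    g = heads // groups  # query heads per KV head
    q_grouped = rearrange(q, 'b (h g) d -> b g h d', g=g)
    assert q_grouped.shape == (batch, g, groups, dim)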
def flash_split_ref(Q, K, V, mask):
num_split = 16
batch = Q.size(0)
nheads = Q.size(1)
groups = K.size(2)
dim = Q.size(-1)
block_N = 32
seqlen_kv = K.size(1)
num_head_groups = nheads // groups
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
acc_s = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float)
acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N),
device="cuda",
dtype=torch.float16)
acc_o = torch.empty((batch, num_head_groups, groups, dim), device="cuda", dtype=torch.float)
scores_max = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
scores_max_prev = torch.empty((batch, num_head_groups, groups),
device="cuda",
dtype=torch.float)
scores_scale = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
scores_sum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
logsum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
gacc_o = torch.empty((num_split, batch, nheads, dim), device="cuda", dtype=torch.float)
glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float)
Q_ = Q * scale
Q_ = rearrange(Q_, 'b (h g) d -> b g h d', g=num_head_groups)
for ks in range(num_split):
acc_o.fill_(0)
logsum.fill_(0)
scores_max.fill_(float('-inf'))
scores_max_prev.fill_(float('-inf'))
for i in range(int((seqlen_kv // num_split) / block_N)):
acc_s.fill_(0)
acc_s = torch.einsum('bghd,bkhd->bghk', Q_,
K[:, (seqlen_kv // num_split) * ks +
i * block_N:(seqlen_kv // num_split) * ks +
(i + 1) * block_N, :, :]) # [batch, nheads, block_N]
if mask is not None:
mask_local = mask[:, (seqlen_kv // num_split) * ks +
i * block_N:(seqlen_kv // num_split) * ks + (i + 1) * block_N, :]
mask_local = rearrange(mask_local, 'b s h -> b h s')
mask_local = mask_local.unsqueeze(1)
acc_s = acc_s.masked_fill(mask_local == 0, float('-inf'))
scores_max_prev = scores_max
scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads]
scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads]
acc_o *= scores_scale[:, :, :, None]
acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N]
acc_o += torch.einsum(
'bghk,bkhd->bghd', acc_s_cast,
V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
(i + 1) * block_N, :, :])
scores_sum = acc_s.sum(dim=-1, keepdim=False)
logsum = logsum * scores_scale + scores_sum
acc_o_out = rearrange(acc_o, 'b g h d->b (h g) d')
logsum_out = rearrange(logsum, 'b g h->b (h g)')
acc_o_out /= logsum_out[:, :, None]
logsum_out = torch.log2(logsum_out) + rearrange(scores_max, 'b g h->b (h g)')
gacc_o[ks, :, :, :] = acc_o_out
glogsum[ks, :, :] = logsum_out
return glogsum.to(torch.float16).permute(1, 2, 0), gacc_o.to(torch.float16).permute(1, 2, 0, 3)
def reduce_ref(Q, K, V, mask, glse, Output_partial):
num_split = 16
o = torch.empty_like(Output_partial[:, :, 0, :]).fill_(0)
lse_logsum = torch.empty_like(glse[:, :, 0]).fill_(0) # [batch, heads]
lse_max = glse.max(dim=2, keepdim=False).values
for ks in range(num_split):
lse = glse[:, :, ks]
lse_logsum += torch.exp2(lse - lse_max)
lse_logsum = torch.log2(lse_logsum) + lse_max
for ks in range(num_split):
lse = glse[:, :, ks]
scale = torch.exp2(lse - lse_logsum) # [batch, heads]
o += Output_partial[:, :, ks, :] * scale[:, :, None]
return o.to(torch.float16)
def ref_split_program(Q, K, V, mask, glse=None, Output_partial=None):
glse_, Output_partial_ = flash_split_ref(Q, K, V, mask)
return reduce_ref(Q, K, V, mask, glse_, Output_partial_)
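# Hedged numeric sketch (not in the original file) of the split-KV recombination performed by
# flash_split_ref/reduce_ref above: each split keeps a normalized partial output plus its
# log-sum-exp, and weighting the partials by exp(lse_split - lse_total) reproduces the full
# softmax-weighted sum. Natural logs are used here for brevity; the kernels use the equivalent
# exp2/log2 form. Shapes and the split count are illustrative.
def _split_softmax_recombine_demo():
    torch.manual_seed(0)
    scores = torch.randn(64, dtype=torch.float64)
    values = torch.randn(64, 4, dtype=torch.float64)
    full = torch.softmax(scores, dim=0) @ values
    partials, lses = [], []
    for s_chunk, v_chunk in zip(scores.chunk(4), values.chunk(4)):
        m = s_chunk.max()
        p = torch.exp(s_chunk - m)
        partials.append((p @ v_chunk) / p.sum())  # per-split normalized output
        lses.append(m + torch.log(p.sum()))       # per-split log-sum-exp
    lses = torch.stack(lses)
    weights = torch.exp(lses - torch.logsumexp(lses, dim=0))
    combined = sum(w * o for w, o in zip(weights, partials))
    torch.testing.assert_close(combined, full)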
def print_red_warning(msg):
print(f"\033[91m{msg}\033[0m")
def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print_red_warning(f'{name} all zero')
return 1
sim = 2 * (x * y).sum() / denominator
return sim
def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True):
sim = calc_sim(x, y, name)
diff = 1. - sim
if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff}')
if assert_:
raise AssertionError(f'{name} Error: {diff}')
else:
if print_:
print(f'passed: {name} diff={diff}')
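# Hedged sketch: for identical tensors the similarity above is 2*(x.x) / (x.x + x.x) = 1, so
# assert_similar reports diff == 0; a tiny CPU-only check, illustrative only.
def _calc_sim_demo():
    t = torch.randn(8, 8)
    assert abs(float(calc_sim(t, t.clone(), name="demo")) - 1.0) < 1e-12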
def main(batch: int = 1,
heads: int = 32,
groups: int = 8,
kv_seqlen: int = 8192,
dim: int = 128,
tune: bool = False):
qk_flops = 2 * batch * heads * kv_seqlen * dim
pv_flops = 2 * batch * heads * kv_seqlen * dim
total_flops = qk_flops + pv_flops
if (not tune):
config, sm_version = get_heuristic_config()
kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
q = torch.randn(batch, heads, dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16)
mask = torch.randint(0, 2, (batch, kv_seqlen, groups), device="cuda", dtype=torch.uint8)
glse = torch.empty(batch, heads, 16, device="cuda", dtype=torch.float16)
Output_partial = torch.empty(batch, heads, 16, dim, device="cuda", dtype=torch.float16)
o = kernel(q, k, v, mask, glse, Output_partial)
o_ref = ref_program(q, k, v, mask, glse, Output_partial)
o_ref_split = ref_split_program(q, k, v, mask, glse, Output_partial)
print(o)
print(o_ref)
assert_similar(o, o_ref, name="o_ref")
assert_similar(o_ref_split, o_ref, name="o_ref_split")
print("All checks pass.")
latency = profiler.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
kernel = flashattn(batch, heads, groups, kv_seqlen, dim)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--groups', type=int, default=8, help='groups')
parser.add_argument('--kv_seqlen', type=int, default=8192, help='kv sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from functools import partial
num_split = 4
@tilelang.jit(out_idx=[5])
def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape_q = [batch, seqlen_q, heads, dim]
shape_kv = [batch, seqlen_kv, heads, dim]
part_shape = [batch, seqlen_q, heads, num_split, dim]
dtype = "float16"
accum_dtype = "float"
@T.macro
def MMA0(
K: T.Tensor(shape_kv, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
mid: T.int32,
hid: T.int32,
bid: T.int32,
sid: T.int32,
):
T.copy(
K[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid +
(k + 1) * block_N, hid, :], K_shared)
# TODO: Handle causal split case
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(shape_kv, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
hid: T.int32,
bid: T.int32,
sid: T.int32,
):
T.copy(
V[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid +
(k + 1) * block_N, hid, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# For causal softmax we would need to reset scores_max to 0 when it is -inf.
# This step is called Check_inf in the FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) iterations.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) - max * log_2(e)).
# This allows the compiler to use an ffma instruction instead of a separate fadd and fmul.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.macro
def flash_attn_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_kv, dtype),
V: T.Tensor(shape_kv, dtype),
glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
Output_partial: T.Tensor(part_shape, dtype),
):
with T.Kernel(
T.ceildiv(seqlen_q, block_M), heads * batch, num_split,
threads=128) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
mid = bx
hid = by % heads
bid = by // heads
sid = bz
# NOTE(wt): TMA barriers currently have problems with padded dimensions (seq_q here),
# so the relevant TMA copies are disabled and SIMT copies are used as a fallback for now.
T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared, disable_tma=True)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
# TODO: Handle causal split case
loop_range = (
T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv(
(mid + 1) * block_M, block_N)) if is_causal else T.ceildiv(
(seqlen_kv // num_split), block_N))
for k in T.Pipelined(loop_range, num_stages=2):
MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bid, hid, sid, mid * block_M:(mid + 1) * block_M])
T.copy(acc_o, O_shared)
T.copy(
O_shared,
Output_partial[bid, mid * block_M:(mid + 1) * block_M, hid, sid, :],
disable_tma=True)
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_q, dtype),
):
with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz):
po_local = T.alloc_fragment([block_M, dim], dtype)
po_shared = T.alloc_shared([block_M, dim], dtype)
o_accum_local = T.alloc_fragment([block_M, dim], accum_dtype)
o_shared = T.alloc_shared([block_M, dim], dtype)
lse_local = T.alloc_fragment([num_split, block_M], dtype)
lse_local_split = T.alloc_fragment([block_M], accum_dtype)
lse_logsum_local = T.alloc_fragment([block_M], accum_dtype)
lse_max_local = T.alloc_fragment([block_M], accum_dtype)
scale_local = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({
o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i),
o_shared: tilelang.layout.make_swizzled_layout(o_shared),
po_shared: tilelang.layout.make_swizzled_layout(po_shared),
})
T.clear(lse_logsum_local)
T.clear(o_accum_local)
T.copy(glse[
bz,
by,
:,
bx * block_M:(bx + 1) * block_M,
], lse_local)
T.reduce_max(lse_local, lse_max_local, dim=0, clear=False)
for k in T.Pipelined(num_split):
T.copy(lse_local[k, :], lse_local_split)
for i in T.Parallel(block_M):
lse_logsum_local[i] += T.exp2(lse_local_split[i] - lse_max_local[i])
for i in T.Parallel(block_M):
lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i]
for k in T.Pipelined(num_split, num_stages=2):
T.copy(
Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :],
po_shared,
disable_tma=True)
T.copy(po_shared, po_local)
for i in T.Parallel(block_M):
lse_local_split[i] = lse_local[k, i]
for i in T.Parallel(block_M):
scale_local[i] = T.exp2(lse_local_split[i] - lse_logsum_local[i])
for i, j in T.Parallel(block_M, dim):
o_accum_local[i, j] += po_local[i, j] * scale_local[i]
T.copy(o_accum_local, o_shared)
T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :], disable_tma=True)
@T.prim_func
def flashattn_mha_inference(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_kv, dtype),
V: T.Tensor(shape_kv, dtype),
glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim]
Output: T.Tensor(shape_q, dtype),
):
flash_attn_split(Q, K, V, glse, Output_partial)
combine(glse, Output_partial, Output)
return flashattn_mha_inference
def ref_program(Q, K, V, glse, Output_partial, causal):
assert causal is False
dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
def reduce_ref(Q, K, V, glse, Output_partial, causal):
o = torch.empty_like(Output_partial[:, :, :, 0, :]).fill_(0)
lse_logsum = torch.empty_like(glse[:, :, 0, :]).fill_(0) # [batch, seqlen_q, heads]
lse_max = glse.max(dim=2, keepdim=False).values
for ks in range(num_split):
lse = glse[:, :, ks, :]
lse_logsum += torch.exp2(lse - lse_max)
lse_logsum = torch.log2(lse_logsum) + lse_max
for ks in range(num_split):
lse = glse[:, :, ks, :]
scale = torch.exp2(lse - lse_logsum) # [batch, heads, seqlen_q]
o += Output_partial[:, :, :, ks, :] * scale[:, :, :, None].transpose(1, 2)
return o.to(torch.float16)
def flash_split_ref(Q, K, V, causal):
# [batch, seqlen_q, heads, dim]
batch = Q.size(0)
block_M = Q.size(1)
nheads = Q.size(2)
dim = Q.size(3)
block_N = 128
seqlen_kv = K.size(1)
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float)
acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16)
acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float)
scores_max = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
scores_max_prev = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
scores_scale = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
scores_sum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
logsum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
gacc_o = torch.empty((num_split, batch, block_M, nheads, dim), device="cuda", dtype=torch.float)
glogsum = torch.empty((num_split, batch, nheads, block_M), device="cuda", dtype=torch.float)
Q_ = Q * scale
for ks in range(num_split):
acc_o.fill_(0)
logsum.fill_(0)
scores_max.fill_(float('-inf'))
scores_max_prev.fill_(float('-inf'))
for i in range(int((seqlen_kv // num_split) / block_N)):
acc_s.fill_(0)
acc_s = torch.einsum('bqhd,bkhd->bhqk', Q_,
K[:, (seqlen_kv // num_split) * ks +
i * block_N:(seqlen_kv // num_split) * ks +
(i + 1) * block_N, :, :]) # [batch, seqlen, nheads, block_N]
scores_max_prev = scores_max
scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM]
scores_scale = torch.exp2(scores_max_prev - scores_max)
acc_o *= scores_scale[:, :, :, None].transpose(1, 2)
acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
acc_s_cast = acc_s.to(torch.float16)
acc_o += torch.einsum(
'bhqk,bkhd->bqhd', acc_s_cast,
V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
(i + 1) * block_N, :, :])
scores_sum = acc_s.sum(dim=-1, keepdim=False)
logsum = logsum * scores_scale + scores_sum
acc_o /= logsum[:, :, :, None].transpose(1, 2)
logsum = torch.log2(logsum) + scores_max
gacc_o[ks, :, :, :, :] = acc_o
glogsum[ks, :, :, :] = logsum
return glogsum.to(torch.float16).permute(1, 2, 0,
3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)
def main():
BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128
causal = False
flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
total_flops = 2 * flops_per_matmul
if causal:
total_flops *= 0.5
BLOCK_M = 128
BLOCK_N = 64 # if D_HEAD <= 128 else 32
kernel = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N)
ref_fn = partial(ref_program, causal=causal)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_fn, rtol=0.01, atol=0.01)
print("All checks passed!")
latency = profiler.do_bench(ref_fn, warmup=500)
print("{:.2f} ms".format(latency))
print("{:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(n_warmup=10, n_repeat=10)
print("{:.4f} ms".format(latency))
print("{:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
main()