Commit 2add9fa3 authored by wangkx1's avatar wangkx1
Browse files

add tilelang

parent f5bc26c2
# ruff: noqa
import torch
import tilelang
from tilelang import language as T
from index import prepare_token_indices
from utils import assert_tensors_similar
@tilelang.jit(
out_idx=[-2, -1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)
def sparse_mla_fwd(
heads,
dim,
tail_dim,
topk,
kv_group=1,
sm_scale=None,
is_causal=True,
CP0=True,
block_I=32,
num_stages=2,
threads=128,
):
assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}"
assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
assert is_causal == True, "non-casual is not supported"
assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
if sm_scale is None:
sm_scale = (1.0 / (dim + tail_dim)) ** 0.5
else:
sm_scale = sm_scale
batch_plus_one = T.symbolic("batch_plus_one")
seq_len = T.symbolic("seq_len")
head_kv = heads // kv_group
q_shape = [seq_len, heads, dim + tail_dim]
kv_shape = [seq_len, kv_group, dim + tail_dim]
o_shape = [seq_len, heads, dim]
indices_shape = [seq_len, kv_group, topk]
lse_shape = [seq_len, heads]
offsets_shape = [batch_plus_one]
token_indices_shape = [seq_len, 2]
indices_dtype = T.int32
dtype = T.bfloat16
accum_dtype = T.float32
G = kv_group
H = head_kv
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
if padded_H != H:
assert kv_group == 1, (
"here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
)
BI = block_I
NI = tilelang.cdiv(topk, block_I)
D = dim
D_tail = tail_dim
if head_kv > 64:
assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
REPLICATE_H = head_kv // 64
else:
REPLICATE_H = 1
H_per_block = padded_H if REPLICATE_H == 1 else 64
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore
TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore
Output: T.Tensor(o_shape, dtype), # type: ignore
Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore
):
with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as (
bx,
by,
):
Q_shared = T.alloc_shared([H_per_block, D], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
KV_shared = T.alloc_shared([BI, D], dtype)
K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
mask = T.alloc_fragment([BI], "bool")
acc_o = T.alloc_fragment([H_per_block, D], accum_dtype)
acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
S_shared = T.alloc_shared([H_per_block, BI], dtype)
sumexp = T.alloc_fragment([H_per_block], accum_dtype)
sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
alpha = T.alloc_fragment([H_per_block], accum_dtype)
m_i = T.alloc_fragment([H_per_block], accum_dtype)
m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
T.fill(acc_o, 0)
T.fill(sumexp, 0)
T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan
b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1]
bos, eos = Offsets[b_i], Offsets[b_i + 1]
g_i = by
q_i = s_i
max_kv_i = q_i
H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
H1 = H0 + H_per_block
T.copy(Q[bos + s_i, H0:H1, :D], Q_shared)
T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared)
for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i in T.Parallel(BI):
mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)
for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i]
for bi_i, d_i in T.Parallel(BI, D_tail):
K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i]
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype))
T.gemm(
Q_shared,
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.gemm(
Q_tail_shared,
K_tail_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
alpha[h_i] = T.exp((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale)
T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator?
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(H_per_block, D):
acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
# Rescale
for h_i, d_i in T.Parallel(H_per_block, D):
acc_o[h_i, d_i] /= sumexp[h_i]
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = T.log(sumexp[h_i]) + m_i[h_i] * sm_scale
T.copy(acc_o, Output[bos + s_i, H0:H1, :])
T.copy(sumexp, Lse[bos + s_i, H0:H1])
return main
def sparse_mla_fwd_interface(
q, kv, indices, offsets, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=32, num_stages=2, threads=128
):
is_casual = True
assert return_p_sum == False, "This kernel file is for fwd only"
assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
seq_len, heads, dim_plus_tail_dim = q.shape
seq_len_kv, kv_group, _ = kv.shape
assert seq_len == seq_len_kv
assert dim_plus_tail_dim == 576, "you should assign dim otherwise"
dim = d_v
assert kv.shape[-1] == dim_plus_tail_dim
tail_dim = dim_plus_tail_dim - dim
_, _, topk = indices.shape
assert indices.shape == (seq_len, kv_group, topk)
token_indices = prepare_token_indices(offsets)
kernel = sparse_mla_fwd(
heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads
)
out, lse = kernel(q, kv, indices, offsets, token_indices)
return out, lse
def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casual=True):
Q = Q.float()
KV = KV.float()
all_o = []
for i in range(offsets.shape[0] - 1):
q = Q[None, offsets[i] : offsets[i + 1]]
kv = KV[None, offsets[i] : offsets[i + 1]]
indices = Indices[None, offsets[i] : offsets[i + 1]].clone()
indices = indices.transpose(1, 2)
b, sq, h, dim_q = q.shape
b, sk, g, _ = kv.shape
assert kv.shape[-1] == 576, "you should assign dim otherwise"
dim = 512
k = kv
v = kv[..., :dim]
b, _, _, dim_v = v.shape
g_index = g
h_index = h // g
compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange(
1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda"
).view(1, -1)
indices[indices > sk] = sk
mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1)
mask = mask[..., :-1]
mask = mask & compressed_casual_mask.view(1, 1, sq, sk)
mask[:, :, : 1 - 1, 0] = True
mask = mask.view(b, g_index, 1, sq, sk)
q = q.view(b, sq, g, -1, dim_q)
score = torch.einsum("bmghd,bngd->bghmn", q, k)
sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale
score = score.masked_fill(~mask, float("-inf")).mul(sm_scale)
p = score.softmax(dim=-1)
p = p.view(b, g_index, h_index, -1, sq, sk)
p = p.view(b, g, -1, sq, sk)
o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v)
o = o.reshape(b, sq, h, dim_v)
all_o.append(o.squeeze(0))
o = torch.cat(all_o, dim=0)
return o.to(torch.bfloat16)
def test_sparse_mla_fwd(
B=1,
S=4096,
H=128,
HKV=1,
DQK=576,
DV=512,
topk=2048,
dtype=torch.bfloat16,
check_correctness=True,
block_I=64,
num_stages=2,
threads=256,
):
torch.random.manual_seed(0)
q = torch.randn((S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True)
kv = torch.randn((S, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True)
offsets = torch.tensor([0, S // 2 - 1, S], dtype=torch.int32, device="cuda")
indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda")
for i in range(offsets.shape[0] - 1):
seq_len = (offsets[i + 1] - offsets[i]).item()
assert seq_len >= topk
for t in range(seq_len):
for h in range(HKV):
i_i = torch.randperm(max(1, t))[:topk]
indices[offsets[i] + t, h, : len(i_i)] = i_i
tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads)
if check_correctness:
# otherwise may cause out of memory
ref_out = ref_sparse_mla_fwd_interface(q, kv, indices, offsets)
assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out")
print("assert_tensors_similar passed")
def fn():
return sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads)
from tilelang.profiler import do_bench
ms = do_bench(
fn,
rep=100,
warmup=250,
)
print(f"Average time: {ms:.3f} ms")
print("fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)
if __name__ == "__main__":
test_sparse_mla_fwd(
B=1,
S=4096,
H=128,
HKV=1,
DQK=576,
DV=512,
topk=1024,
dtype=torch.bfloat16,
check_correctness=True,
block_I=64,
num_stages=2,
threads=256,
)
# ruff: noqa
import torch
import torch.nn as nn
import torch.nn.functional as F
import tilelang
from tilelang import language as T
from einops import repeat, rearrange, einsum
from index import prepare_token_indices
from utils import get_abs_err, get_err_ratio
BF16 = T.bfloat16
FP32 = T.float32
INT32 = T.int32
pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
@tilelang.jit(pass_configs=pass_configs)
def tl_sparse_mla_topk_reducesum_impl(
heads,
dim,
tail_dim,
topk,
kv_group=1,
sm_scale=None,
block_I=32,
num_stages=2,
threads=128,
):
assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}"
assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
if sm_scale is None:
sm_scale = (1.0 / (dim + tail_dim)) ** 0.5
batch_plus_one = T.symbolic("batch_plus_one")
seq_len = T.symbolic("seq_len")
seq_len_kv = T.symbolic("seq_len_kv")
head_kv = heads // kv_group
indices_dtype = T.int32
dtype = T.bfloat16
accum_dtype = T.float32
G = kv_group
H = head_kv
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
if padded_H != H:
assert kv_group == 1, (
"here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
)
BI = block_I
NI = tilelang.cdiv(topk, block_I)
D = dim
D_tail = tail_dim
if head_kv > 64:
assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
REPLICATE_H = head_kv // 64
else:
REPLICATE_H = 1
H_per_block = padded_H if REPLICATE_H == 1 else 64
q_shape = [seq_len, heads, dim + tail_dim]
kv_shape = [seq_len_kv, kv_group, dim + tail_dim]
indices_shape = [seq_len, kv_group, topk]
lse_shape = [seq_len, heads]
reducesum_shape = [seq_len, kv_group, REPLICATE_H, topk]
offsets_shape = [batch_plus_one]
token_indices_shape = [seq_len, 2]
@T.prim_func
def tl_sparse_mla_topk_reducesum_kernel(
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore
Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore
TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore
ReduceSum: T.Tensor(reducesum_shape, accum_dtype), # type: ignore
):
with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as (
bx,
by,
):
Q_shared = T.alloc_shared([H_per_block, D], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
KV_shared = T.alloc_shared([BI, D], dtype)
K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
mask = T.alloc_fragment([BI], "bool")
acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
reducesum = T.alloc_fragment([BI], accum_dtype)
lse = T.alloc_fragment([H_per_block], accum_dtype)
T.fill(lse, 0)
b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1]
bos, eos = Offsets[b_i], Offsets[b_i + 1]
r_i = bx % REPLICATE_H
g_i = by
q_i = s_i
max_kv_i = q_i
H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
H1 = H0 + H_per_block
T.copy(Q[bos + s_i, H0:H1, :D], Q_shared)
T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared)
T.copy(Lse[bos + s_i, H0:H1], lse)
for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i in T.Parallel(BI):
mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)
for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i]
for bi_i, d_i in T.Parallel(BI, D_tail):
K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i]
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype))
T.gemm(
Q_shared,
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.gemm(
Q_tail_shared,
K_tail_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - lse[h_i])
T.reduce_sum(acc_s, reducesum, dim=0)
T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI : i_i * BI + BI])
return tl_sparse_mla_topk_reducesum_kernel
def sparse_mla_topk_reducesum_interface(
q: torch.Tensor,
kv: torch.Tensor,
topk_indices: torch.Tensor,
lse: torch.Tensor,
offsets: torch.Tensor,
dim_v: int,
):
assert kv.shape[-2] == 1
seq_len, heads, dim_plus_tail_dim, topk = *q.shape, topk_indices.shape[-1]
REPLICATE_H = max(heads // 64, 1)
tail_dim = dim_plus_tail_dim - dim_v
token_indices = prepare_token_indices(offsets)
reducesum = torch.zeros([seq_len, 1, REPLICATE_H, topk], dtype=torch.float32, device=q.device)
kernel = tl_sparse_mla_topk_reducesum_impl(heads=heads, dim=dim_v, tail_dim=tail_dim, topk=topk)
kernel(q, kv, topk_indices, lse, offsets, token_indices, reducesum)
reducesum = reducesum.sum(dim=-2) # [batch, seq_len, 1, RH, topk] -> [batch, seq_len, 1, topk]
attn_score = reducesum / reducesum.sum(dim=-1, keepdim=True)
return attn_score
def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, offsets: torch.Tensor):
# q: [batch, seq_len, heads, dim]
# k: [batch, seq_len, dim]
sm_scale = Q.shape[-1] ** -0.5
all_lse = []
all_topk_score = []
for i in range(offsets.shape[0] - 1):
q = Q[offsets[i] : offsets[i + 1]]
k = K[offsets[i] : offsets[i + 1]]
topk_indices = TopkIndices[offsets[i] : offsets[i + 1]]
seq_len = q.shape[0]
mask = (torch.arange(seq_len)[:, None] >= torch.arange(seq_len)[None, :]).unsqueeze(-2).cuda()
logits = einsum(q, k, "s1 h d, s2 d -> s1 h s2") * sm_scale
logits = torch.where(mask, logits, float("-inf"))
score = F.softmax(logits, dim=-1, dtype=torch.float32)
score_sum = score.sum(dim=-2)
topk_score = torch.gather(score_sum, dim=-1, index=topk_indices.to(torch.int64))
topk_score = topk_score / topk_score.sum(dim=-1, keepdim=True)
max_logits = logits.amax(dim=-1).to(torch.float32)
lse = torch.log((logits - max_logits.unsqueeze(-1).to(torch.float32)).exp().sum(dim=-1)) + max_logits
all_lse.append(lse)
all_topk_score.append(topk_score)
lse = torch.cat(all_lse, dim=0)
topk_score = torch.cat(all_topk_score, dim=0)
return lse, topk_score
def test_kernel(
B=1,
S=2048,
H=16,
D=512,
tail_D=64,
topk=128,
):
torch.manual_seed(42)
q = torch.randn((S, H, D + tail_D)).cuda().bfloat16()
kv = torch.randn((S, D + tail_D)).cuda().bfloat16()
offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda()
topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous()
lse, ref_attn_score = ref_mla_topk_softmax(q, kv, topk_indices, offsets)
kv = kv.unsqueeze(-2)
topk_indices = topk_indices.unsqueeze(-2)
attn_score = sparse_mla_topk_reducesum_interface(q, kv, topk_indices, lse, offsets, dim_v=D).squeeze(-2)
print(f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}")
if __name__ == "__main__":
test_kernel()
import torch
def get_abs_err(y, x):
x = x.to(torch.float32)
y = y.to(torch.float32)
return (x - y).flatten().abs().max().item()
def get_err_ratio(y, x):
x = x.to(torch.float32)
y = y.to(torch.float32)
err = (x - y).flatten().square().mean().sqrt().item()
base = (x).flatten().square().mean().sqrt().item()
return err / base
def calculate_tensor_similarity(x, y, name="tensor"):
"""
Calculate similarity between two tensors using a normalized dot product metric.
Unlike torch.testing.assert_close which uses absolute/relative tolerance based on
element-wise differences, this function computes a global similarity score:
sim = 2 * <x, y> / (||x||^2 + ||y||^2)
This metric is scale-invariant and measures the cosine-like similarity normalized
by the magnitude of both tensors. It returns 1 for identical tensors and values
closer to 0 for dissimilar ones. This is particularly useful for comparing tensors
with varying magnitudes where relative errors matter more than absolute differences.
Args:
x: First tensor to compare
y: Second tensor to compare
name: Name of the tensor for logging purposes
Returns:
Similarity score in range [0, 1] where 1 means identical
"""
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print(f"\033[33mWARNING: {name} all zero\033[0m")
return 1
sim = 2 * (x * y).sum() / denominator
return sim
def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
"""
Assert that two tensors are similar using a global similarity metric.
Key differences from torch.testing.assert_close:
- torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking
that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers
and requires all elements to satisfy the tolerance.
- assert_tensors_similar: Uses a single global similarity score (1 - sim) where sim is the
normalized dot product. It's more robust to outliers and focuses on overall
tensor similarity rather than element-wise precision. This is better suited for
comparing large tensors where a few outlier elements shouldn't fail the test.
Args:
x: First tensor to compare
y: Second tensor to compare
eps: Maximum allowed difference (1 - similarity), default 1e-8
name: Name of the tensor for error messages
raise_assert: Whether to raise assertion error on failure
"""
sim = calculate_tensor_similarity(x, y, name)
diff = 1.0 - sim
if not (0 <= diff <= eps):
print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m")
if raise_assert:
assert False # noqa: B011
import argparse
import itertools
import torch
import tilelang
import tilelang.language as T
def ref_program(x, y):
return x + y
def get_configs():
block_M = [64, 128, 256]
block_N = [64, 128, 256]
threads = [64, 128, 256]
configs = list(itertools.product(block_M, block_N, threads))
return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs]
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
@T.prim_func
def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), in_dtype)
B_shared = T.alloc_shared((block_M, block_N), in_dtype)
C_local = T.alloc_fragment((block_M, block_N), out_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(B[by * block_M, bx * block_N], B_shared)
for local_y, local_x in T.Parallel(block_M, block_N):
C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return elem_add
def main(M=1024, N=1024, use_autotune=False):
a = torch.randn(M, N, dtype=torch.float32, device="cuda")
b = torch.randn(M, N, dtype=torch.float32, device="cuda")
if use_autotune:
kernel = elementwise_add(M, N, in_dtype=T.float32, out_dtype=T.float32)
else:
# Default config
config = {"block_M": 32, "block_N": 32, "threads": 128}
kernel = elementwise_add(M, N, **config, in_dtype=T.float32, out_dtype=T.float32)
out = kernel(a, b)
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=1024)
parser.add_argument("--n", type=int, default=1024)
parser.add_argument("--use_autotune", action="store_true", default=False)
args, _ = parser.parse_known_args()
main(args.m, args.n, args.use_autotune)
import tilelang.testing
import example_elementwise_add
def test_example_elementwise_add():
example_elementwise_add.main()
def test_example_elementwise_add_autotune():
example_elementwise_add.main(use_autotune=True)
if __name__ == "__main__":
tilelang.testing.main()
# FlashAttention
Using tile-lang, we can define buffers at different memory layers. For instance, `Q_shared`, `K_shared`, and `V_shared` can be defined in shared memory, while `acc_s` and `acc_o` can be placed in registers. This flexibility allows us to represent a complex fusion pattern like FlashAttention in a simple way.
```python
@T.prim_func
def flash_attention(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype),
):
# Launch a specialized T.Kernel with 3D mapping: (bx, by, bz)
# bx: block index in sequence dimension
# by: block index in "heads" dimension
# bz: block index in "batch" dimension
# threads=thread_num means how many threads per block
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):
# Allocate shared memory for Q, K, V to reduce global memory accesses
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
# Allocate buffers on register
# acc_s: buffer to hold intermediate attention scores
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
# acc_s_cast: buffer for storing casted/adjusted scores
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
# acc_o: partial accumulation of output
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
# Buffers to track per-row maximum score and related stats
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
# Annotate layout for Q_shared, e.g., use a swizzled layout to optimize memory access
T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)})
# Copy a block of Q from global memory to Q_shared
T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
# Initialize accumulators
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
)
# Pipeline the loop to overlap copies/gemm stages
for k in T.Pipelined(loop_range, num_stages=num_stages):
# Copy K block into shared memory
T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)
)
else:
T.clear(acc_s)
# Perform the Q*K^T multiplication, Here, transpose_B=True indicates that K_shared is transposed,
# policy=T.GemmWarpPolicy.FullRow means each warp is responsible for computing an entire row
# of acc_s, and the resulting acc_s is retained in registers.
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
# Copy V block into shared memory
T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
for i, j in T.Parallel(block_M, dim):
acc_s[i, j] *= scale
# Save old scores_max, then reset scores_max
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
# Compute the maximum value per row on dimension 1 (block_N)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# Compute the factor by which we need to rescale previous partial sums
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])
# Rescale the partial output accumulation to keep exponents consistent
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
# Exponentiate (scores - max) for the new block
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])
# Make a cast of acc_s to fp16 for the next GEMM
T.copy(acc_s, acc_s_cast)
# Multiply the attention acc_s_cast by V and add to partial output (acc_o)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
# Update the "logsum" tracker with the newly accumulated sum
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
# Final step: divide each partial output by logsum (completing the softmax)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
# Write back the final output block from acc_o to the Output buffer
T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
```
\ No newline at end of file
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
# ruff: noqa
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
class IndexFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape)
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
grad_output = rearrange(grad_output, "b ... -> b (...)")
grad_input = torch.zeros(
[ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device,
dtype=grad_output.dtype,
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
return output
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
return grad_values, None, None
index_put_first_axis = IndexPutFirstAxis.apply
class IndexFirstAxisResidual(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
output = input[indices]
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
# memory format to channel_first. In other words, input might not be contiguous.
# If we don't detach, Pytorch complains about output being a view and is being modified inplace
return output, input.detach()
@staticmethod
def backward(ctx, grad_output, grad_residual):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
assert grad_residual.shape[1:] == other_shape
grad_input = grad_residual
# grad_input[indices] += grad_output
indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
indices = indices.expand_as(grad_output)
grad_input.scatter_add_(0, indices, grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis_residual = IndexFirstAxisResidual.apply
def unpad_input(hidden_states, attention_mask):
"""
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
"""
Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
```
[
[2, 3, 0, 0, 0, 0],
[3, 2, 0, 0, 0, 0],
[6, 0, 0, 0, 0, 0]
]
```
, which refers to the 3D-attention mask:
```
[
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1]
]
]
```.
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
length = attention_mask_in_length.sum(dim=-1)
seqlen = attention_mask_in_length.size(-1)
attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz)
Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[-1]
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, "(b s) ... -> b s ...", b=batch)
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
import argparse
@tilelang.jit(
out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
K_shared = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd
@tilelang.jit(
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim_v]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim_v, blk)):
T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep
def make_dq_layout(dQ):
# atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
@tilelang.jit(
out_idx=[1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim_qk]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(
dQ[bz, bx * blk : (bx + 1) * blk, by, :],
dQ_out[bz, bx * blk : (bx + 1) * blk, by, :],
)
return flash_bwd_post
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
sm_scale = (1.0 / dim_qk) ** 0.5
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
T.annotate_layout(
{
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
}
)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk):
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared)
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared)
return flash_bwd
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
sm_scale = (1.0 / dim_qk) ** 0.5
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dk_shape = [groups, batch, seq_len, head_kv, dim_qk] # sum after kernel
dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(dk_shape, dtype), # type: ignore
dV: T.Tensor(dv_shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], dtype)
T.annotate_layout(
{
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}
)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk):
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
T.copy(dk, dk_shared)
T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
return flash_bwd
@torch.compile
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
BATCH, N_CTX, H, D_HEAD_QK = q.shape
D_HEAD_V = v.shape[-1]
block_M = 128
block_N = 64
mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
o, lse = mod(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
ctx.use_atomic = use_atomic
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD_QK = q.shape
(
HEAD_KV,
D_HEAD_V,
) = v.shape[-2], v.shape[-1]
groups = H // HEAD_KV
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 128
block_N = 32
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK)
delta = mod_prep(o, do)
if ctx.use_atomic:
kernel = flashattn_bwd_atomic_add(
BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = mod_post(dq)
dk = dk.to(torch.float16)
dv = dv.to(torch.float16)
else:
kernel = flashattn_bwd_split(
BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel
shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = mod_post(dq)
dk, dv = dk.sum(0), dv.sum(0)
return dq, dk, dv, None, None, None
attention = _attention.apply
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D_QK]
# K: [B, T, HK, D_QK]
# V: [B, T, HV, D_V]
# HQ = HKV * groups
assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim_qk = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output
def main(
BATCH: int = 1,
H: int = 32,
N_CTX: int = 256,
D_HEAD_QK: int = 192,
D_HEAD_V: int = 128,
groups: int = 16,
causal: bool = False,
use_atomic: bool = True,
):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v
if causal:
total_flops *= 0.5
Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
head_kv = H // groups
K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
O = attention(Q, K, V, causal, groups, use_atomic)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal, groups)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
from tilelang.profiler import do_bench
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="Batch size")
parser.add_argument("--h", type=int, default=32, help="Number of heads")
parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
parser.add_argument("--causal", action="store_true", help="Causal flag")
parser.add_argument("--groups", type=int, default=16, help="groups")
parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV")
parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV")
args = parser.parse_args()
# Handle backward compatibility and logic
if args.use_split:
use_atomic = False
elif args.use_atomic:
use_atomic = True
else:
# Default: use atomic
use_atomic = True
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.contrib import nvcc
import argparse
tilelang.disable_cache()
@tilelang.jit(
out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
K_shared = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
# Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops
# We should set it to negative large number instead
T.fill(scores_max, T.Cast(accum_dtype, -1e30))
loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, T.Cast(accum_dtype, -1e30))
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd
@tilelang.jit(
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim_v]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim_v, blk)):
T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep
def make_dq_layout(dQ):
# bshd -> bhld to use tma reduction instruction
return T.Layout(dQ.shape, lambda b, l, h, d: [b, h, l, d])
@tilelang.jit(
out_idx=[3, 4, 5],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v):
dtype = T.float16
accum_dtype = T.float32
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(q_shape, dtype), # type: ignore
dK_out: T.Tensor(k_shape, dtype), # type: ignore
dV_out: T.Tensor(v_shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(dQ[bz, bx * blk : (bx + 1) * blk, by, :], dQ_out[bz, bx * blk : (bx + 1) * blk, by, :])
with T.Kernel(T.ceildiv(seq_len, blk), head_kv, batch, threads=128) as (bx, by, bz):
T.annotate_layout(
{
dK: make_dq_layout(dK),
dV: make_dq_layout(dV),
}
)
T.copy(dK[bz, bx * blk : (bx + 1) * blk, by, :], dK_out[bz, bx * blk : (bx + 1) * blk, by, :])
T.copy(dV[bz, bx * blk : (bx + 1) * blk, by, :], dV_out[bz, bx * blk : (bx + 1) * blk, by, :])
return flash_bwd_post
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
sm_scale = (1.0 / dim_qk) ** 0.5
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)
T.annotate_layout(
{
dQ: make_dq_layout(dQ),
dK: make_dq_layout(dK),
dV: make_dq_layout(dV),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
}
)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
T.copy(dq, dq_shared)
T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared, use_tma=True)
T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True)
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True)
return flash_bwd
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd_split_novarlen(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
sm_scale = (1.0 / dim_qk) ** 0.5
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dk_shape = [groups, batch, seq_len, head_kv, dim_qk] # sum after kernel
dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(dk_shape, dtype), # type: ignore
dV: T.Tensor(dv_shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], dtype)
T.annotate_layout(
{
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}
)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk):
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
T.copy(dk, dk_shared)
T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
return flash_bwd
@torch.compile
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
BATCH, N_CTX, H, D_HEAD_QK = q.shape
D_HEAD_V = v.shape[-1]
block_M = 128
block_N = 64
mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
o, lse = mod(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
ctx.use_atomic = use_atomic
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD_QK = q.shape
(
HEAD_KV,
D_HEAD_V,
) = v.shape[-2], v.shape[-1]
groups = H // HEAD_KV
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 128
block_N = 32
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
mod_post = flashattn_bwd_postprocess(BATCH, H, HEAD_KV, N_CTX, D_HEAD_QK, D_HEAD_V)
delta = mod_prep(o, do)
if ctx.use_atomic:
kernel = flashattn_bwd_atomic_add(
BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq, dk, dv = mod_post(dq, dk, dv)
else:
kernel = flashattn_bwd_split_novarlen(
BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel
shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32))
dk, dv = dk.sum(0), dv.sum(0)
return dq, dk, dv, None, None, None
attention = _attention.apply
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D_QK]
# K: [B, T, HK, D_QK]
# V: [B, T, HV, D_V]
# HQ = HKV * groups
assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim_qk = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output
def main(
BATCH: int = 1,
H: int = 32,
N_CTX: int = 256,
D_HEAD_QK: int = 192,
D_HEAD_V: int = 128,
groups: int = 16,
causal: bool = False,
use_atomic: bool = True,
):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v
if causal:
total_flops *= 0.5
Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
head_kv = H // groups
K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
O = attention(Q, K, V, causal, groups, use_atomic)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal, groups)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
from tilelang.profiler import do_bench
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
arch = nvcc.get_target_compute_version()
print(f"Detected GPU compute capability: {arch}")
assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0"
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="Batch size")
parser.add_argument("--h", type=int, default=32, help="Number of heads")
parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
parser.add_argument("--causal", action="store_true", help="Causal flag")
parser.add_argument("--groups", type=int, default=16, help="groups")
parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV")
parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV")
args = parser.parse_args()
# Handle backward compatibility and logic
if args.use_split:
use_atomic = False
elif args.use_atomic:
use_atomic = True
else:
# Default: use atomic
use_atomic = True
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.contrib import nvcc
import argparse
from einops import rearrange, repeat
from bert_padding import pad_input, unpad_input
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
assert mode in ["full", "random", "third"]
if mode == "full":
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
elif mode == "random":
lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
elif mode == "third":
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
return padding_mask
@tilelang.jit(
out_idx=[5, 6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_fwd(batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [total_q, heads, dim_qk]
k_shape = [total_kv, head_kv, dim_qk]
v_shape = [total_kv, head_kv, dim_v]
o_shape = [total_q, heads, dim_v]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore
cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore
Output: T.Tensor(o_shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(max_seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
K_shared = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
q_start_idx = cu_seqlens_q[bz]
k_start_idx = cu_seqlens_k[bz]
q_end_idx = cu_seqlens_q[bz + 1]
k_end_idx = cu_seqlens_k[bz + 1]
q_current_seqlen = q_end_idx - q_start_idx
k_current_seqlen = k_end_idx - k_start_idx
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
for i, d in T.Parallel(block_M, dim_qk):
if bx * block_M + i < q_current_seqlen:
Q_shared[i, d] = Q[q_start_idx + bx * block_M + i, by, d]
else:
Q_shared[i, d] = 0.0
T.fill(acc_o, 0.0)
T.fill(logsum, 0.0)
# Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops
# We should set it to negative large number instead
T.fill(scores_max, T.Cast(accum_dtype, -1e30))
loop_range = T.ceildiv(k_current_seqlen, block_N)
for k in T.Pipelined(loop_range, num_stages=1):
for i, d in T.Parallel(block_N, dim_qk):
if k * block_N + i < k_current_seqlen:
K_shared[i, d] = K[k_start_idx + k * block_N + i, by // groups, d]
else:
K_shared[i, d] = 0.0
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
(bx * block_M + i >= k * block_N + j)
and (bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen),
0,
T.Cast(accum_dtype, -1e30),
)
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen, 0, T.Cast(accum_dtype, -1e30)
)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, d in T.Parallel(block_N, dim_v):
if k * block_N + i < k_current_seqlen:
V_shared[i, d] = V[k_start_idx + k * block_N + i, by // groups, d]
else:
V_shared[i, d] = 0.0
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] /= logsum[i]
for i, d in T.Parallel(block_M, dim_v):
if bx * block_M + i < q_current_seqlen:
Output[q_start_idx + bx * block_M + i, by, d] = acc_o[i, d]
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
if bx * block_M + i < q_current_seqlen:
lse[bz, by, bx * block_M + i] = logsum[i]
return flash_fwd
@tilelang.jit(
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v):
dtype = T.float16
accum_dtype = T.float32
shape = [total_q, heads, dim_v]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore
Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
q_start_idx = cu_seqlens_q[bz]
q_end_idx = cu_seqlens_q[bz + 1]
q_current_seqlen = q_end_idx - q_start_idx
T.clear(acc)
for k in range(T.ceildiv(dim_v, blk)):
for i, j in T.Parallel(blk, blk):
if by * blk + i < q_current_seqlen and k * blk + j < dim_v:
o[i, j] = O[q_start_idx + by * blk + i, bx, k * blk + j]
do[i, j] = dO[q_start_idx + by * blk + i, bx, k * blk + j]
else:
o[i, j] = 0.0
do[i, j] = 0.0
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
for i in T.Parallel(blk):
if by * blk + i < q_current_seqlen:
Delta[bz, bx, by * blk + i] = delta[i]
return flash_bwd_prep
def make_dq_layout(dQ):
# bshd -> bhsd to use tma reduction instruction
return T.Layout(dQ.shape, lambda l, h, d: [h, l, d])
@tilelang.jit(
out_idx=[3, 4, 5],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v):
dtype = T.float16
accum_dtype = T.float32
q_shape = [total_q, heads, dim_qk]
k_shape = [total_kv, head_kv, dim_qk]
v_shape = [total_kv, head_kv, dim_v]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(q_shape, dtype), # type: ignore
dK_out: T.Tensor(k_shape, dtype), # type: ignore
dV_out: T.Tensor(v_shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(total_q, blk), heads, threads=128) as (bx, by):
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(dQ[bx * blk : (bx + 1) * blk, by, :], dQ_out[bx * blk : (bx + 1) * blk, by, :])
with T.Kernel(T.ceildiv(total_kv, blk), head_kv, threads=128) as (bx, by):
T.annotate_layout(
{
dK: make_dq_layout(dK),
dV: make_dq_layout(dV),
}
)
T.copy(dK[bx * blk : (bx + 1) * blk, by, :], dK_out[bx * blk : (bx + 1) * blk, by, :])
T.copy(dV[bx * blk : (bx + 1) * blk, by, :], dV_out[bx * blk : (bx + 1) * blk, by, :])
return flash_bwd_post
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd_atomic_add(
batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1
):
sm_scale = (1.0 / dim_qk) ** 0.5
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [total_q, heads, dim_qk]
k_shape = [total_kv, head_kv, dim_qk]
v_shape = [total_kv, head_kv, dim_v]
do_shape = [total_q, heads, dim_v]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor(do_shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore
cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)
q_start_idx = cu_seqlens_q[bz]
k_start_idx = cu_seqlens_k[bz]
q_end_idx = cu_seqlens_q[bz + 1]
k_end_idx = cu_seqlens_k[bz + 1]
q_current_seqlen = q_end_idx - q_start_idx
k_current_seqlen = k_end_idx - k_start_idx
T.annotate_layout(
{
dQ: make_dq_layout(dQ),
dK: make_dq_layout(dK),
dV: make_dq_layout(dV),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
}
)
T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0
loop_ed = T.ceildiv(q_current_seqlen, block_N)
for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(
(by * block_M + i <= k_base * block_N + j)
and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen),
qkT[i, j],
0,
)
else:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(
by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0
)
T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do)
T.clear(dsT)
# dsT: (block_kv, block_q)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
T.copy(dq, dq_shared)
T.atomic_add(
dQ[q_start_idx + k_base * block_N : q_start_idx + k_base * block_N + block_N, bx, :],
dq_shared,
memory_order="relaxed",
use_tma=True,
)
T.copy(dv, dv_shared)
T.atomic_add(
dV[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :],
dv_shared,
memory_order="relaxed",
use_tma=True,
)
T.copy(dk, dk_shared)
T.atomic_add(
dK[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :],
dk_shared,
memory_order="relaxed",
use_tma=True,
)
return flash_bwd
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd_split(
batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1
):
sm_scale = (1.0 / dim_qk) ** 0.5
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [total_q, heads, dim_qk]
k_shape = [total_kv, head_kv, dim_qk]
v_shape = [total_kv, head_kv, dim_v]
do_shape = [total_q, heads, dim_v]
dk_shape = [groups, total_kv, head_kv, dim_qk] # sum after kernel
dv_shape = [groups, total_kv, head_kv, dim_v] # sum after kernel
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor(do_shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore
cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(dk_shape, dtype), # type: ignore
dV: T.Tensor(dv_shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], dtype)
q_start_idx = cu_seqlens_q[bz]
k_start_idx = cu_seqlens_k[bz]
q_end_idx = cu_seqlens_q[bz + 1]
k_end_idx = cu_seqlens_k[bz + 1]
q_current_seqlen = q_end_idx - q_start_idx
k_current_seqlen = k_end_idx - k_start_idx
T.annotate_layout(
{
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}
)
T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0
loop_ed = T.ceildiv(q_current_seqlen, block_N)
for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
# Note: The padding zero of varlen should be considered in T.copy
T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(
(by * block_M + i <= k_base * block_N + j)
and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen),
qkT[i, j],
0,
)
else:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(
by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0
)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk):
if k_base * block_N + i < q_current_seqlen:
T.atomic_add(dQ[q_start_idx + k_base * block_N + i, bx, j], dq[i, j], memory_order="relaxed")
T.copy(dv, dv_shared)
T.copy(dv_shared, dV[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :])
T.copy(dk, dk_shared)
T.copy(dk_shared, dK[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :])
return flash_bwd
@torch.compile
class _attention(torch.autograd.Function):
@staticmethod
def forward(
ctx, q, k, v, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups=1, use_atomic=True
):
BATCH, N_CTX, H, D_HEAD_QK = q.shape
D_HEAD_V = v.shape[-1]
block_M = 128
block_N = 64
q_unpad, indices_q, _, _ = unpad_input(q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1)))
k_unpad, indices_k, _, _ = unpad_input(k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1)))
v_unpad, _, _, _ = unpad_input(v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1)))
total_q = q_unpad.shape[0]
total_kv = k_unpad.shape[0]
mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k)
o = pad_input(o_unpad, indices_q, BATCH, N_CTX)
ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k)
ctx.batch = BATCH
ctx.causal = causal
ctx.use_atomic = use_atomic
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.indices_q = indices_q
ctx.indices_k = indices_k
return o
@staticmethod
def backward(ctx, do):
N_CTX = do.shape[1]
q, k, v, o, lse_clone, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
# lse_clone = lse.clone()
do_unpad, _, _, _ = unpad_input(do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1)))
total_q, H, D_HEAD_QK = q.shape
total_kv, HEAD_KV, D_HEAD_V = v.shape
groups = H // HEAD_KV
BATCH = len(cu_seqlens_q) - 1
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do_unpad, q, k, v, o)]
block_M = 128
block_N = 32
mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, N_CTX, ctx.max_seqlen_q, D_HEAD_V)
mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V)
delta = mod_prep(o, do, cu_seqlens_q)
if ctx.use_atomic:
kernel = flashattn_bwd_atomic_add(
BATCH,
total_q,
total_kv,
N_CTX,
H,
ctx.max_seqlen_q,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups,
)
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.zeros_like(k, dtype=torch.float32)
dv = torch.zeros_like(v, dtype=torch.float32)
kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
dq, dk, dv = mod_post(dq, dk, dv)
else:
kernel = flashattn_bwd_split(
BATCH,
total_q,
total_kv,
N_CTX,
H,
ctx.max_seqlen_q,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups,
)
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device)
dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32))
dk, dv = dk.sum(0), dv.sum(0)
dq = pad_input(dq, ctx.indices_q, BATCH, N_CTX)
dk = pad_input(dk, ctx.indices_k, BATCH, N_CTX)
dv = pad_input(dv, ctx.indices_k, BATCH, N_CTX)
return dq, dk, dv, None, None, None, None, None, None, None, None, None
attention = _attention.apply
def ref_program(Q, K, V, padding_mask, is_causal, groups=1):
# Q: [B, T, HQ, D_QK]
# K: [B, T, HK, D_QK]
# V: [B, T, HV, D_V]
# HQ = HKV * groups
# To handle precision issue
Q, K, V = Q.float(), K.float(), V.float()
assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim_qk = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
if padding_mask is not None:
scores.masked_fill_(rearrange(~padding_mask, "b s -> b 1 1 s"), float("-inf"))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
if padding_mask is not None:
output.masked_fill_(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0)
return output
def main(
BATCH: int = 1,
H: int = 32,
N_CTX: int = 256,
D_HEAD_QK: int = 192,
D_HEAD_V: int = 128,
groups: int = 16,
causal: bool = False,
use_atomic: bool = True,
):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v
if causal:
total_flops *= 0.5
Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
head_kv = H // groups
K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
padding_mask = generate_random_padding_mask(N_CTX, BATCH, "cuda", mode="random")
seqlens_q = padding_mask.sum(dim=-1, dtype=torch.int32)
cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
max_seqlen_q = seqlens_q.max().item()
# In training backward pass, seqlens_k should be the same as seqlens_q
seqlens_k, cu_seqlens_k, max_seqlen_k = seqlens_q, cu_seqlens_q, max_seqlen_q
O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups, use_atomic)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, padding_mask, causal, groups)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
from tilelang.profiler import do_bench
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
print(
"Note: this varlen kernel performance is as good as the non-varlen kernel shown in Nsight-Compute. As you may observe that the TFLOPS is a bit lower, that's because the unpad operation is included in the above benchmark."
)
if __name__ == "__main__":
arch = nvcc.get_target_compute_version()
print(f"Detected GPU compute capability: {arch}")
assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0"
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="Batch size")
parser.add_argument("--h", type=int, default=32, help="Number of heads")
parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
parser.add_argument("--causal", action="store_true", help="Causal flag")
parser.add_argument("--groups", type=int, default=16, help="groups")
parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV")
parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV")
args = parser.parse_args()
# Can be set to True/False for testing
args.causal = True
# Handle backward compatibility and logic
if args.use_split:
use_atomic = False
elif args.use_atomic:
use_atomic = True
else:
# Default: use atomic
use_atomic = True
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
import argparse
@tilelang.jit(
out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
K_shared = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd
@tilelang.jit(
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim_v]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim_v, blk)):
T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
sm_scale = (1.0 / dim_qk) ** 0.5
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)
T.annotate_layout(
{
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
}
)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.wait_wgmma(1)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.wait_wgmma(0)
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
T.wait_wgmma(0)
T.copy(dq, dq_shared)
T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared)
T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared)
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared)
return flash_bwd
@torch.compile
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
BATCH, N_CTX, H, D_HEAD_QK = q.shape
D_HEAD_V = v.shape[-1]
block_M = 128
block_N = 64
mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
o, lse = mod(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
ctx.use_atomic = use_atomic
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD_QK = q.shape
(
HEAD_KV,
D_HEAD_V,
) = v.shape[-2], v.shape[-1]
groups = H // HEAD_KV
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 128
block_N = 32
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
delta = mod_prep(o, do)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = dq.to(torch.float16)
dk = dk.to(torch.float16)
dv = dv.to(torch.float16)
return dq, dk, dv, None, None, None
attention = _attention.apply
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D_QK]
# K: [B, T, HK, D_QK]
# V: [B, T, HV, D_V]
# HQ = HKV * groups
assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim_qk = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output
def main(BATCH: int = 1, H: int = 32, N_CTX: int = 256, D_HEAD_QK: int = 192, D_HEAD_V: int = 128, groups: int = 16, causal: bool = False):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v
if causal:
total_flops *= 0.5
Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
head_kv = H // groups
K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
O = attention(Q, K, V, causal, groups)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal, groups)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
from tilelang.profiler import do_bench
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="Batch size")
parser.add_argument("--h", type=int, default=32, help="Number of heads")
parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
parser.add_argument("--causal", action="store_true", help="Causal flag")
parser.add_argument("--groups", type=int, default=16, help="groups")
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
class FlashAttentionTuneSpace:
def __init__(
self,
block_sizes=(64, 128, 256),
thread_options=(128, 256, 512),
num_stages_range=(2, 3),
max_shared_mem=100 * 1024,
warp_alignment=16,
dim=128,
dtype_bytes=2,
):
self.block_sizes = block_sizes
self.thread_options = thread_options
self.num_stages_range = num_stages_range
self.max_shared_mem = max_shared_mem
self.warp_alignment = warp_alignment
self.dim = dim
self.dtype_bytes = dtype_bytes
def get_configs(user_config=None):
config = user_config or FlashAttentionTuneSpace()
valid_configs = []
for block_M, block_N in itertools.product(config.block_sizes, repeat=2):
for threads in config.thread_options:
assert threads % 32 == 0
warp_count = threads // 32
warp_M = block_M // warp_count
warp_N = block_N // warp_count
if warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0:
continue
shared_mem = 2 * config.dtype_bytes * config.dim * (block_M + block_N)
if shared_mem > config.max_shared_mem:
continue
for num_stages in config.num_stages_range:
valid_configs.append(
{
"block_M": block_M,
"block_N": block_N,
"num_stages": num_stages,
"threads": threads,
}
)
return valid_configs
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn(batch, heads, seq_len, dim, is_causal, groups=1, block_M=64, block_N=64, num_stages=0, threads=128):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = T.float16
accum_dtype = T.float32
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
)
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
return main
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D]
# K: [B, T, HK, D]
# V: [B, T, HV, D]
# HQ = HKV * groups
assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output
def main(
batch: int = 1, heads: int = 64, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 16, tune: bool = False
):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if not tune:
kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=64, block_N=64, num_stages=2, threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=64, help="heads")
parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument("--tune", action="store_true", help="tune configs")
parser.add_argument("--groups", type=int, default=16, help="groups")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
def get_configs():
iter_params = dict(
block_M=[128],
block_N=[128],
num_stages=[2],
threads=[256],
)
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(
configs=get_configs(),
warmup=10,
rep=10,
)
@tilelang.jit(
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn(
batch,
heads,
seq_len,
dim,
is_causal,
groups=1,
block_M=64,
block_N=64,
num_stages=0,
threads=128,
):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = T.float16
accum_dtype = T.float32
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
)
for k in T.Pipelined(
loop_range,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
return main
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D]
# K: [B, T, HK, D]
# V: [B, T, HV, D]
# HQ = HKV * groups
assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output
def main(
batch: int = 1,
heads: int = 64,
seq_len: int = 4096,
dim: int = 128,
is_causal: bool = False,
groups: int = 16,
tune: bool = False,
):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if not tune:
kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=64, help="heads")
parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument("--tune", action="store_true", help="tune configs")
parser.add_argument("--groups", type=int, default=16, help="groups")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune)
# ruff: noqa
import argparse
import torch
import tilelang
import tilelang.language as T
import tilelang.testing
from einops import rearrange, repeat
from tilelang.profiler import do_bench
from varlen_utils import generate_random_padding_mask, generate_qkv
def attention_ref(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
causal=False,
window_size=(-1, -1),
upcast=True,
):
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
b, T, Hq, D = q.shape
S = k.shape[1]
scale = (1.0 / D) ** 0.5
k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2])
scores = torch.einsum("bthd,bshd->bhts", q, k)
left, right = window_size
left = S if left is None or left < 0 else int(left)
right = S if right is None or right < 0 else int(right)
t_idx = torch.arange(T, device=scores.device)[:, None]
s_idx = torch.arange(S, device=scores.device)[None, :]
visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
visible_mask = visible_ts.unsqueeze(0).unsqueeze(0)
if key_padding_mask is not None:
k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s")
visible_mask = visible_mask & k_keep
neg_inf = torch.finfo(scores.dtype).min
scores = scores * scale
scores = scores.masked_fill(~visible_mask, neg_inf)
attention = torch.softmax(scores, dim=-1).to(v.dtype)
if query_padding_mask is not None:
q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1")
attention = attention.masked_fill(~q_keep, 0.0)
output = torch.einsum("bhts,bshd->bthd", attention, v)
if query_padding_mask is not None:
output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
@tilelang.jit(
out_idx=[6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn(batch_size, groups, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [UQ, heads, dim]
kv_shape = [UKV, head_kv, dim]
o_shape = [UQ, heads, dim]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def main(
Q_unpad: T.Tensor(q_shape, dtype),
K_unpad: T.Tensor(kv_shape, dtype),
V_unpad: T.Tensor(kv_shape, dtype),
cu_seqlens_q: T.Tensor([batch_size + 1], T.int32),
cu_seqlens_k: T.Tensor([batch_size + 1], T.int32),
max_seqlen_q: T.int32,
Output_unpad: T.Tensor(o_shape, dtype),
):
with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout(
{
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
}
)
batch_idx = bz
head_idx = by
kv_head_idx = head_idx // groups
q_start_idx = cu_seqlens_q[batch_idx]
kv_start_idx = cu_seqlens_k[batch_idx]
q_end_idx = cu_seqlens_q[batch_idx + 1]
k_end_idx = cu_seqlens_k[batch_idx + 1]
q_current_seqlen = q_end_idx - q_start_idx
kv_current_seqlen = k_end_idx - kv_start_idx
T.copy(Q_unpad[q_start_idx + bx * block_M : q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
if is_causal
else T.ceildiv(kv_current_seqlen, block_N)
)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K_unpad[kv_start_idx + k * block_N : kv_start_idx + (k + 1) * block_N, kv_head_idx, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
(bx * block_M + i < k * block_N + j)
or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen),
-1e9,
0,
)
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
(bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), -1e9, 0
)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
T.copy(V_unpad[kv_start_idx + k * block_N : kv_start_idx + (k + 1) * block_N, kv_head_idx, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
for i, d in T.Parallel(block_M, dim):
if bx * block_M + i < q_current_seqlen:
Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d]
return main
def main(
batch: int = 1, heads: int = 64, q_seqlen: int = 2048, k_seqlen: int = 2048, dim: int = 128, groups: int = 16, is_causal: bool = False
):
assert heads % groups == 0, "heads must be divisible by groups"
flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim
total_flops = 2 * flops_per_matmul
tilelang.testing.set_random_seed(0)
if is_causal:
total_flops *= 0.5
tilelang.testing.set_random_seed(0)
dtype = torch.float16
device = torch.device("cuda")
head_kv = heads // groups
q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device)
k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random")
key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random")
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
_,
_,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
UQ = q_unpad.shape[0]
UKV = k_unpad.shape[0]
kernel = flashattn(batch, groups, UQ, UKV, heads, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256)
out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
out = output_pad_fn(out_unpad)
out_ref, _ = attention_ref(
q,
k,
v,
query_padding_mask=query_padding_mask,
key_padding_mask=key_padding_mask,
causal=is_causal,
)
torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
latency = do_bench(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q), _n_warmup=5, _n_repeat=5)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=64, help="query heads")
parser.add_argument("--groups", type=int, default=16, help="groups")
parser.add_argument("--q_seqlen", type=int, default=2048, help="query sequence length")
parser.add_argument("--k_seqlen", type=int, default=2048, help="key/value sequence length")
parser.add_argument("--dim", type=int, default=128, help="head dim")
parser.add_argument("--is_causal", action="store_true", help="causal attention")
args = parser.parse_args()
main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, args.is_causal)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import argparse
@tilelang.jit(
out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_fwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
Output: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
# Q_local = T.alloc_fragment([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
# T.copy(Q_shared, Q_local)
# for i, j in T.Parallel(block_M, dim):
# Q_local[i, j] *= scale
loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd
@tilelang.jit(
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = T.float16
accum_dtype = T.float32
shape = [batch, heads, seq_len, dim]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep
def make_dq_layout(dQ):
# atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
@tilelang.jit(
out_idx=[1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
dtype = T.float16
accum_dtype = T.float32
shape = [batch, heads, seq_len, dim]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(
dQ[bz, by, bx * blk : (bx + 1) * blk, :],
dQ_out[bz, by, bx * blk : (bx + 1) * blk, :],
)
return flash_bwd_post
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
sm_scale = (1.0 / dim) ** 0.5
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
# should not store K to local if dim is large
# K_local = T.alloc_fragment([block_M, dim], dtype)
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
# V_local = T.alloc_fragment([block_M, dim], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim], dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim], dtype)
dk_shared = T.alloc_shared([block_M, dim], dtype)
T.annotate_layout(
{
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}
)
T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared)
T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
# We don't need to handle OOB positions for non-causal cases,
# since OOB values won't affect other positions here.
T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim):
T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * block_M, :])
T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :])
return flash_bwd
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal):
BATCH, H, N_CTX, D_HEAD = q.shape
block_M = 64
block_N = 64 if D_HEAD <= 128 else 32
o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, H, N_CTX, D_HEAD = q.shape
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 64
block_N = 64 if D_HEAD <= 64 else 32
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
delta = kernel_prep(o, do)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
shape = [BATCH, H, N_CTX, D_HEAD]
dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
dk = torch.empty(shape, dtype=torch.float16, device=q.device)
dv = torch.empty(shape, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = kernel_post(dq)
return dq, dk, dv, None
attention = _attention.apply
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum("bhqd,bhkd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(2)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V)
return output
def main(
BATCH: int = 8,
H: int = 32,
N_CTX: int = 1024,
D_HEAD: int = 64,
causal: bool = False,
):
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 5 * flops_per_matmul
if causal:
total_flops *= 0.5
Q = torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_()
K = torch.empty_like(Q).normal_().requires_grad_()
V = torch.empty_like(Q).normal_().requires_grad_()
dO = torch.randn_like(Q)
O = attention(Q, K, V, causal)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
from tilelang.profiler import do_bench
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="Batch size")
parser.add_argument("--h", type=int, default=32, help="Number of heads")
parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
parser.add_argument("--d_head", type=int, default=64, help="Head dimension")
parser.add_argument("--causal", type=bool, default=False, help="Causal flag")
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import argparse
@tilelang.jit(
out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_fwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
Output: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
# Q_local = T.alloc_fragment([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd
@tilelang.jit(
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep
def make_dq_layout(dQ):
# atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
@tilelang.jit(
out_idx=[1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(
dQ[bz, bx * blk : (bx + 1) * blk, by, :],
dQ_out[bz, bx * blk : (bx + 1) * blk, by, :],
)
return flash_bwd_post
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
sm_scale = (1.0 / dim) ** 0.5
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
# should not store K to local if dim is large
# K_local = T.alloc_fragment([block_M, dim], dtype)
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
# V_local = T.alloc_fragment([block_M, dim], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim], dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim], dtype)
dk_shared = T.alloc_shared([block_M, dim], dtype)
T.annotate_layout(
{
dQ: make_dq_layout(dQ),
}
)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
# We don't need to handle OOB positions for non-causal cases,
# since OOB values won't affect other positions here.
T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim):
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :])
T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :])
return flash_bwd
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal):
BATCH, N_CTX, H, D_HEAD = q.shape
block_M = 64
block_N = 64 if D_HEAD <= 128 else 32
o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD = q.shape
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 64
block_N = 64 if D_HEAD <= 64 else 32
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
delta = kernel_prep(o, do)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
shape = [BATCH, N_CTX, H, D_HEAD]
dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
dk = torch.empty(shape, dtype=torch.float16, device=q.device)
dv = torch.empty(shape, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = kernel_post(dq)
return dq, dk, dv, None
attention = _attention.apply
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output
def main(
BATCH: int = 8,
H: int = 32,
N_CTX: int = 1024,
D_HEAD: int = 64,
causal: bool = False,
):
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 5 * flops_per_matmul
if causal:
total_flops *= 0.5
Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_()
K = torch.empty_like(Q).normal_().requires_grad_()
V = torch.empty_like(Q).normal_().requires_grad_()
dO = torch.randn_like(Q)
O = attention(Q, K, V, causal)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
from tilelang.profiler import do_bench
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="Batch size")
parser.add_argument("--h", type=int, default=32, help="Number of heads")
parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
parser.add_argument("--d_head", type=int, default=64, help="Head dimension")
parser.add_argument("--causal", type=bool, default=False, help="Causal flag")
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.profiler import do_bench
import argparse
@tilelang.jit(
out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_fwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
Output: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd
@tilelang.jit(
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
sm_scale = (1.0 / dim) ** 0.5
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
# should not store K to local if dim is large
# K_local = T.alloc_fragment([block_M, dim], dtype)
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
# V_local = T.alloc_fragment([block_M, dim], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim], dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim], dtype)
dk_shared = T.alloc_shared([block_M, dim], dtype)
dq_shared = T.alloc_shared([block_N, dim], accum_dtype)
T.annotate_layout(
{
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
}
)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.wait_wgmma(1)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
# We don't need to handle OOB positions for non-causal cases,
# since OOB values won't affect other positions here.
T.wait_wgmma(0)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
T.wait_wgmma(0)
T.copy(dq, dq_shared)
T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared)
T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :])
T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :])
return flash_bwd
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal):
BATCH, N_CTX, H, D_HEAD = q.shape
block_M = 64
block_N = 64 if D_HEAD <= 128 else 32
mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)
o, lse = mod(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD = q.shape
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 128
block_N = 128 if D_HEAD <= 64 else 32
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
delta = mod_prep(o, do)
mod = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
shape = [BATCH, N_CTX, H, D_HEAD]
dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
dk = torch.empty(shape, dtype=torch.float16, device=q.device)
dv = torch.empty(shape, dtype=torch.float16, device=q.device)
mod(q, k, v, do, lse, delta, dq, dk, dv)
dq = dq.to(torch.float16)
return dq, dk, dv, None
attention = _attention.apply
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output
def main(
BATCH: int = 8,
H: int = 32,
N_CTX: int = 1024,
D_HEAD: int = 64,
causal: bool = False,
):
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 5 * flops_per_matmul
if causal:
total_flops *= 0.5
Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_()
K = torch.empty_like(Q).normal_().requires_grad_()
V = torch.empty_like(Q).normal_().requires_grad_()
dO = torch.randn_like(Q)
O = attention(Q, K, V, causal)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="Batch size")
parser.add_argument("--h", type=int, default=32, help="Number of heads")
parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
parser.add_argument("--d_head", type=int, default=64, help="Head dimension")
parser.add_argument("--causal", type=bool, default=False, help="Causal flag")
args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = T.float16
accum_dtype = T.float32
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
# We shall fill -inf for OOB positions
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
if is_causal
else T.ceildiv(seq_kv, block_N)
)
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum("bhqd,bhkd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_q = Q.size(2)
seq_kv = K.size(2)
mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V)
return output
def main(
batch: int = 1,
heads: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 64,
is_causal: bool = False,
tune: bool = False,
):
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if not tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=1, help="heads")
parser.add_argument("--seq_q", type=int, default=256, help="query sequence length")
parser.add_argument("--seq_kv", type=int, default=256, help="key/value sequence length")
parser.add_argument("--dim", type=int, default=64, help="dim")
parser.add_argument("--is_causal", action="store_true", help="causal", default=False)
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = T.float16
accum_dtype = T.float32
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
# We shall fill -inf for OOB positions
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
if is_causal
else T.ceildiv(seq_kv, block_N)
)
for k in T.Pipelined(
loop_range,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum("bhqd,bhkd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_q = Q.size(2)
seq_kv = K.size(2)
mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V)
return output
def main(
batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
is_causal: bool = False,
tune: bool = False,
):
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if not tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--seq_q", type=int, default=4096, help="query sequence length")
parser.add_argument("--seq_kv", type=int, default=4096, help="key/value sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
def get_configs():
iter_params = dict(block_M=[64], block_N=[64], num_stages=[1], threads=[128])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn(batch, heads, seq_len, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = T.float16
accum_dtype = T.float32
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
# We shall fill -inf for OOB positions
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
)
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
return main
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output
def main(
batch: int = 8,
heads: int = 32,
seq_len: int = 4096,
dim: int = 128,
is_causal: bool = False,
tune: bool = False,
):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if not tune:
kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=1, threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment