Unverified Commit 29051439 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
...@@ -27,15 +27,12 @@ def sparse_mla_fwd( ...@@ -27,15 +27,12 @@ def sparse_mla_fwd(
num_stages=2, num_stages=2,
threads=128, threads=128,
): ):
assert dim == tilelang.math.next_power_of_2( assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}"
dim), f"haven't check padding correctness yet, dim={dim}" assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
assert tail_dim == tilelang.math.next_power_of_2(
tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
assert is_causal == True, "non-casual is not supported" assert is_causal == True, "non-casual is not supported"
assert (topk % assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded"
if sm_scale is None: if sm_scale is None:
sm_scale = (1.0 / (dim + tail_dim))**0.5 sm_scale = (1.0 / (dim + tail_dim)) ** 0.5
else: else:
sm_scale = sm_scale sm_scale = sm_scale
...@@ -58,9 +55,9 @@ def sparse_mla_fwd( ...@@ -58,9 +55,9 @@ def sparse_mla_fwd(
H = head_kv H = head_kv
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
if padded_H != H: if padded_H != H:
assert ( assert kv_group == 1, (
kv_group == 1 "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
), "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" )
BI = block_I BI = block_I
NI = tilelang.cdiv(topk, block_I) NI = tilelang.cdiv(topk, block_I)
D = dim D = dim
...@@ -76,19 +73,18 @@ def sparse_mla_fwd( ...@@ -76,19 +73,18 @@ def sparse_mla_fwd(
@T.prim_func @T.prim_func
def main( def main(
Q: T.Tensor(q_shape, dtype), # type: ignore Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore
TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore
Output: T.Tensor(o_shape, dtype), # type: ignore Output: T.Tensor(o_shape, dtype), # type: ignore
Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore
): ):
with T.Kernel( with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as (
seq_len * REPLICATE_H, kv_group, threads=threads) as ( bx,
bx, by,
by, ):
):
Q_shared = T.alloc_shared([H_per_block, D], dtype) Q_shared = T.alloc_shared([H_per_block, D], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
KV_shared = T.alloc_shared([BI, D], dtype) KV_shared = T.alloc_shared([BI, D], dtype)
...@@ -122,17 +118,13 @@ def sparse_mla_fwd( ...@@ -122,17 +118,13 @@ def sparse_mla_fwd(
T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared) T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared)
for i_i in T.Pipelined(NI, num_stages=num_stages): for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i in T.Parallel(BI): for bi_i in T.Parallel(BI):
mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & ( mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)
Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)
for bi_i, d_i in T.Parallel(BI, D): for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i]
d_i]
for bi_i, d_i in T.Parallel(BI, D_tail): for bi_i, d_i in T.Parallel(BI, D_tail):
K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i]
g_i, D + d_i]
for h_i, bi_i in T.Parallel(H_per_block, BI): for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype))
...@@ -177,16 +169,9 @@ def sparse_mla_fwd( ...@@ -177,16 +169,9 @@ def sparse_mla_fwd(
return main return main
def sparse_mla_fwd_interface(q, def sparse_mla_fwd_interface(
kv, q, kv, indices, offsets, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=32, num_stages=2, threads=128
indices, ):
offsets,
sm_scale=None,
return_p_sum: bool = False,
d_v=512,
block_I=32,
num_stages=2,
threads=128):
is_casual = True is_casual = True
assert return_p_sum == False, "This kernel file is for fwd only" assert return_p_sum == False, "This kernel file is for fwd only"
assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
...@@ -205,16 +190,8 @@ def sparse_mla_fwd_interface(q, ...@@ -205,16 +190,8 @@ def sparse_mla_fwd_interface(q,
token_indices = prepare_token_indices(offsets) token_indices = prepare_token_indices(offsets)
kernel = sparse_mla_fwd( kernel = sparse_mla_fwd(
heads, heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads
dim, )
tail_dim,
topk,
kv_group,
sm_scale,
is_casual,
block_I=block_I,
num_stages=num_stages,
threads=threads)
out, lse = kernel(q, kv, indices, offsets, token_indices) out, lse = kernel(q, kv, indices, offsets, token_indices)
return out, lse return out, lse
...@@ -224,9 +201,9 @@ def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casu ...@@ -224,9 +201,9 @@ def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casu
KV = KV.float() KV = KV.float()
all_o = [] all_o = []
for i in range(offsets.shape[0] - 1): for i in range(offsets.shape[0] - 1):
q = Q[None, offsets[i]:offsets[i + 1]] q = Q[None, offsets[i] : offsets[i + 1]]
kv = KV[None, offsets[i]:offsets[i + 1]] kv = KV[None, offsets[i] : offsets[i + 1]]
indices = Indices[None, offsets[i]:offsets[i + 1]].clone() indices = Indices[None, offsets[i] : offsets[i + 1]].clone()
indices = indices.transpose(1, 2) indices = indices.transpose(1, 2)
b, sq, h, dim_q = q.shape b, sq, h, dim_q = q.shape
...@@ -240,15 +217,15 @@ def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casu ...@@ -240,15 +217,15 @@ def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casu
b, _, _, dim_v = v.shape b, _, _, dim_v = v.shape
g_index = g g_index = g
h_index = h // g h_index = h // g
compressed_casual_mask = torch.arange( compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange(
0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda"
1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda").view(1, -1) ).view(1, -1)
indices[indices > sk] = sk indices[indices > sk] = sk
mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1)
mask = mask[..., :-1] mask = mask[..., :-1]
mask = mask & compressed_casual_mask.view(1, 1, sq, sk) mask = mask & compressed_casual_mask.view(1, 1, sq, sk)
mask[:, :, :1 - 1, 0] = True mask[:, :, : 1 - 1, 0] = True
mask = mask.view(b, g_index, 1, sq, sk) mask = mask.view(b, g_index, 1, sq, sk)
q = q.view(b, sq, g, -1, dim_q) q = q.view(b, sq, g, -1, dim_q)
...@@ -265,18 +242,20 @@ def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casu ...@@ -265,18 +242,20 @@ def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casu
return o.to(torch.bfloat16) return o.to(torch.bfloat16)
def test_sparse_mla_fwd(B=1, def test_sparse_mla_fwd(
S=4096, B=1,
H=128, S=4096,
HKV=1, H=128,
DQK=576, HKV=1,
DV=512, DQK=576,
topk=2048, DV=512,
dtype=torch.bfloat16, topk=2048,
check_correctness=True, dtype=torch.bfloat16,
block_I=64, check_correctness=True,
num_stages=2, block_I=64,
threads=256): num_stages=2,
threads=256,
):
torch.random.manual_seed(0) torch.random.manual_seed(0)
q = torch.randn((S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) q = torch.randn((S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True)
kv = torch.randn((S, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) kv = torch.randn((S, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True)
...@@ -289,10 +268,9 @@ def test_sparse_mla_fwd(B=1, ...@@ -289,10 +268,9 @@ def test_sparse_mla_fwd(B=1,
for t in range(seq_len): for t in range(seq_len):
for h in range(HKV): for h in range(HKV):
i_i = torch.randperm(max(1, t))[:topk] i_i = torch.randperm(max(1, t))[:topk]
indices[offsets[i] + t, h, :len(i_i)] = i_i indices[offsets[i] + t, h, : len(i_i)] = i_i
tl_out, tl_lse = sparse_mla_fwd_interface( tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads)
q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads)
if check_correctness: if check_correctness:
# otherwise may cause out of memory # otherwise may cause out of memory
...@@ -301,8 +279,7 @@ def test_sparse_mla_fwd(B=1, ...@@ -301,8 +279,7 @@ def test_sparse_mla_fwd(B=1,
print("assert_tensors_similar passed") print("assert_tensors_similar passed")
def fn(): def fn():
return sparse_mla_fwd_interface( return sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads)
q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads)
from tilelang.profiler import do_bench from tilelang.profiler import do_bench
...@@ -329,4 +306,5 @@ if __name__ == "__main__": ...@@ -329,4 +306,5 @@ if __name__ == "__main__":
check_correctness=True, check_correctness=True,
block_I=64, block_I=64,
num_stages=2, num_stages=2,
threads=256) threads=256,
)
...@@ -30,14 +30,11 @@ def tl_sparse_mla_topk_reducesum_impl( ...@@ -30,14 +30,11 @@ def tl_sparse_mla_topk_reducesum_impl(
num_stages=2, num_stages=2,
threads=128, threads=128,
): ):
assert dim == tilelang.math.next_power_of_2( assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}"
dim), f"haven't check padding correctness yet, dim={dim}" assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
assert tail_dim == tilelang.math.next_power_of_2( assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
assert (topk %
block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded"
if sm_scale is None: if sm_scale is None:
sm_scale = (1.0 / (dim + tail_dim))**0.5 sm_scale = (1.0 / (dim + tail_dim)) ** 0.5
batch_plus_one = T.symbolic("batch_plus_one") batch_plus_one = T.symbolic("batch_plus_one")
seq_len = T.symbolic("seq_len") seq_len = T.symbolic("seq_len")
...@@ -52,9 +49,9 @@ def tl_sparse_mla_topk_reducesum_impl( ...@@ -52,9 +49,9 @@ def tl_sparse_mla_topk_reducesum_impl(
H = head_kv H = head_kv
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
if padded_H != H: if padded_H != H:
assert ( assert kv_group == 1, (
kv_group == 1 "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
), "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" )
BI = block_I BI = block_I
NI = tilelang.cdiv(topk, block_I) NI = tilelang.cdiv(topk, block_I)
D = dim D = dim
...@@ -78,19 +75,18 @@ def tl_sparse_mla_topk_reducesum_impl( ...@@ -78,19 +75,18 @@ def tl_sparse_mla_topk_reducesum_impl(
@T.prim_func @T.prim_func
def tl_sparse_mla_topk_reducesum_kernel( def tl_sparse_mla_topk_reducesum_kernel(
Q: T.Tensor(q_shape, dtype), # type: ignore Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore
Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore
TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore
ReduceSum: T.Tensor(reducesum_shape, accum_dtype), # type: ignore ReduceSum: T.Tensor(reducesum_shape, accum_dtype), # type: ignore
): ):
with T.Kernel( with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as (
seq_len * REPLICATE_H, kv_group, threads=threads) as ( bx,
bx, by,
by, ):
):
Q_shared = T.alloc_shared([H_per_block, D], dtype) Q_shared = T.alloc_shared([H_per_block, D], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
KV_shared = T.alloc_shared([BI, D], dtype) KV_shared = T.alloc_shared([BI, D], dtype)
...@@ -119,17 +115,13 @@ def tl_sparse_mla_topk_reducesum_impl( ...@@ -119,17 +115,13 @@ def tl_sparse_mla_topk_reducesum_impl(
T.copy(Lse[bos + s_i, H0:H1], lse) T.copy(Lse[bos + s_i, H0:H1], lse)
for i_i in T.Pipelined(NI, num_stages=num_stages): for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i in T.Parallel(BI): for bi_i in T.Parallel(BI):
mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & ( mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)
Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)
for bi_i, d_i in T.Parallel(BI, D): for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i]
d_i]
for bi_i, d_i in T.Parallel(BI, D_tail): for bi_i, d_i in T.Parallel(BI, D_tail):
K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i]
g_i, D + d_i]
for h_i, bi_i in T.Parallel(H_per_block, BI): for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype))
...@@ -150,7 +142,7 @@ def tl_sparse_mla_topk_reducesum_impl( ...@@ -150,7 +142,7 @@ def tl_sparse_mla_topk_reducesum_impl(
for h_i, bi_i in T.Parallel(H_per_block, BI): for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - lse[h_i]) acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - lse[h_i])
T.reduce_sum(acc_s, reducesum, dim=0) T.reduce_sum(acc_s, reducesum, dim=0)
T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI:i_i * BI + BI]) T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI : i_i * BI + BI])
return tl_sparse_mla_topk_reducesum_kernel return tl_sparse_mla_topk_reducesum_kernel
...@@ -178,29 +170,26 @@ def sparse_mla_topk_reducesum_interface( ...@@ -178,29 +170,26 @@ def sparse_mla_topk_reducesum_interface(
return attn_score return attn_score
def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, offsets: torch.Tensor):
offsets: torch.Tensor):
# q: [batch, seq_len, heads, dim] # q: [batch, seq_len, heads, dim]
# k: [batch, seq_len, dim] # k: [batch, seq_len, dim]
sm_scale = Q.shape[-1]**-0.5 sm_scale = Q.shape[-1] ** -0.5
all_lse = [] all_lse = []
all_topk_score = [] all_topk_score = []
for i in range(offsets.shape[0] - 1): for i in range(offsets.shape[0] - 1):
q = Q[offsets[i]:offsets[i + 1]] q = Q[offsets[i] : offsets[i + 1]]
k = K[offsets[i]:offsets[i + 1]] k = K[offsets[i] : offsets[i + 1]]
topk_indices = TopkIndices[offsets[i]:offsets[i + 1]] topk_indices = TopkIndices[offsets[i] : offsets[i + 1]]
seq_len = q.shape[0] seq_len = q.shape[0]
mask = (torch.arange(seq_len)[:, None] mask = (torch.arange(seq_len)[:, None] >= torch.arange(seq_len)[None, :]).unsqueeze(-2).cuda()
>= torch.arange(seq_len)[None, :]).unsqueeze(-2).cuda() logits = einsum(q, k, "s1 h d, s2 d -> s1 h s2") * sm_scale
logits = einsum(q, k, 's1 h d, s2 d -> s1 h s2') * sm_scale logits = torch.where(mask, logits, float("-inf"))
logits = torch.where(mask, logits, float('-inf'))
score = F.softmax(logits, dim=-1, dtype=torch.float32) score = F.softmax(logits, dim=-1, dtype=torch.float32)
score_sum = score.sum(dim=-2) score_sum = score.sum(dim=-2)
topk_score = torch.gather(score_sum, dim=-1, index=topk_indices.to(torch.int64)) topk_score = torch.gather(score_sum, dim=-1, index=topk_indices.to(torch.int64))
topk_score = topk_score / topk_score.sum(dim=-1, keepdim=True) topk_score = topk_score / topk_score.sum(dim=-1, keepdim=True)
max_logits = logits.amax(dim=-1).to(torch.float32) max_logits = logits.amax(dim=-1).to(torch.float32)
lse = torch.log( lse = torch.log((logits - max_logits.unsqueeze(-1).to(torch.float32)).exp().sum(dim=-1)) + max_logits
(logits - max_logits.unsqueeze(-1).to(torch.float32)).exp().sum(dim=-1)) + max_logits
all_lse.append(lse) all_lse.append(lse)
all_topk_score.append(topk_score) all_topk_score.append(topk_score)
lse = torch.cat(all_lse, dim=0) lse = torch.cat(all_lse, dim=0)
...@@ -222,20 +211,16 @@ def test_kernel( ...@@ -222,20 +211,16 @@ def test_kernel(
kv = torch.randn((S, D + tail_D)).cuda().bfloat16() kv = torch.randn((S, D + tail_D)).cuda().bfloat16()
offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda() offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda()
topk_indices = repeat( topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous()
torch.arange(topk, dtype=torch.int32).cuda(), 'k -> s k', s=S).contiguous()
lse, ref_attn_score = ref_mla_topk_softmax(q, kv, topk_indices, offsets) lse, ref_attn_score = ref_mla_topk_softmax(q, kv, topk_indices, offsets)
kv = kv.unsqueeze(-2) kv = kv.unsqueeze(-2)
topk_indices = topk_indices.unsqueeze(-2) topk_indices = topk_indices.unsqueeze(-2)
attn_score = sparse_mla_topk_reducesum_interface( attn_score = sparse_mla_topk_reducesum_interface(q, kv, topk_indices, lse, offsets, dim_v=D).squeeze(-2)
q, kv, topk_indices, lse, offsets, dim_v=D).squeeze(-2) print(f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}")
print(
f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}"
)
if __name__ == '__main__': if __name__ == "__main__":
test_kernel() test_kernel()
...@@ -66,10 +66,8 @@ def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): ...@@ -66,10 +66,8 @@ def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
raise_assert: Whether to raise assertion error on failure raise_assert: Whether to raise assertion error on failure
""" """
sim = calculate_tensor_similarity(x, y, name) sim = calculate_tensor_similarity(x, y, name)
diff = 1. - sim diff = 1.0 - sim
if not (0 <= diff <= eps): if not (0 <= diff <= eps):
print( print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m")
f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m"
)
if raise_assert: if raise_assert:
assert False # noqa: B011 assert False # noqa: B011
...@@ -29,9 +29,9 @@ def matmul_dynamic_mnk( ...@@ -29,9 +29,9 @@ def matmul_dynamic_mnk(
@T.prim_func @T.prim_func
def dynamic_matmul( def dynamic_matmul(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -53,15 +53,14 @@ def matmul_dynamic_mnk( ...@@ -53,15 +53,14 @@ def matmul_dynamic_mnk(
return dynamic_matmul return dynamic_matmul
def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads):
accum_dtype, num_stages, threads):
print( print(
f"M: {M}, N: {N}, K: {K}, block_M: {block_M}, block_N: {block_N}, block_K: {block_K}, trans_A: {trans_A}, trans_B: {trans_B}, in_dtype: {in_dtype}, out_dtype: {out_dtype}, accum_dtype: {accum_dtype}, num_stages: {num_stages}, threads: {threads}" f"M: {M}, N: {N}, K: {K}, block_M: {block_M}, block_N: {block_N}, block_K: {block_K}, trans_A: {trans_A}, trans_B: {trans_B}, in_dtype: {in_dtype}, out_dtype: {out_dtype}, accum_dtype: {accum_dtype}, num_stages: {num_stages}, threads: {threads}"
) )
kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads)
accum_dtype, num_stages, threads)
import torch import torch
if trans_A: if trans_A:
A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype)) A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
else: else:
...@@ -103,8 +102,7 @@ def main(M=16384, N=16384, K=16384): ...@@ -103,8 +102,7 @@ def main(M=16384, N=16384, K=16384):
accum_dtype = "float32" accum_dtype = "float32"
num_stages = 3 num_stages = 3
threads = 128 threads = 128
matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads)
accum_dtype, num_stages, threads)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -12,10 +12,8 @@ def ref_program(x, y): ...@@ -12,10 +12,8 @@ def ref_program(x, y):
@tilelang.jit(out_idx=[-1]) @tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
@T.prim_func @T.prim_func
def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor( def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)):
(M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), in_dtype) A_shared = T.alloc_shared((block_M, block_N), in_dtype)
B_shared = T.alloc_shared((block_M, block_N), in_dtype) B_shared = T.alloc_shared((block_M, block_N), in_dtype)
...@@ -24,7 +22,7 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): ...@@ -24,7 +22,7 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
T.copy(A[by * block_M, bx * block_N], A_shared) T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(B[by * block_M, bx * block_N], B_shared) T.copy(B[by * block_M, bx * block_N], B_shared)
for (local_y, local_x) in T.Parallel(block_M, block_N): for local_y, local_x in T.Parallel(block_M, block_N):
C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x] C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
T.copy(C_local, C_shared) T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N]) T.copy(C_shared, C[by * block_M, bx * block_N])
...@@ -41,19 +39,21 @@ def get_configs(M, N): ...@@ -41,19 +39,21 @@ def get_configs(M, N):
def get_best_config(M, N): def get_best_config(M, N):
def kernel(block_M=None, block_N=None, threads=None): def kernel(block_M=None, block_N=None, threads=None):
return elementwise_add(M, N, block_M, block_N, "float32", "float32", threads) return elementwise_add(M, N, block_M, block_N, "float32", "float32", threads)
autotuner = AutoTuner.from_kernel( autotuner = (
kernel=kernel, configs=get_configs(M, N)).set_compile_args( AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N))
.set_compile_args(
out_idx=[-1], out_idx=[-1],
target="cuda", target="cuda",
).set_profile_args( )
.set_profile_args(
supply_type=tilelang.TensorSupplyType.Auto, supply_type=tilelang.TensorSupplyType.Auto,
ref_prog=ref_program, ref_prog=ref_program,
skip_check=False, skip_check=False,
) )
)
return autotuner.run(warmup=3, rep=20) return autotuner.run(warmup=3, rep=20)
......
...@@ -6,7 +6,6 @@ from einops import rearrange, repeat ...@@ -6,7 +6,6 @@ from einops import rearrange, repeat
class IndexFirstAxis(torch.autograd.Function): class IndexFirstAxis(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, input, indices): def forward(ctx, input, indices):
ctx.save_for_backward(indices) ctx.save_for_backward(indices)
...@@ -15,9 +14,7 @@ class IndexFirstAxis(torch.autograd.Function): ...@@ -15,9 +14,7 @@ class IndexFirstAxis(torch.autograd.Function):
second_dim = other_shape.numel() second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices] # return input[indices]
return torch.gather( return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape)
rearrange(input, "b ... -> b (...)"), 0,
repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape)
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
...@@ -40,14 +37,12 @@ index_first_axis = IndexFirstAxis.apply ...@@ -40,14 +37,12 @@ index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function): class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, values, indices, first_axis_dim): def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices) ctx.save_for_backward(indices)
assert indices.ndim == 1 assert indices.ndim == 1
assert values.ndim >= 2 assert values.ndim >= 2
output = torch.zeros( output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
...@@ -66,7 +61,6 @@ index_put_first_axis = IndexPutFirstAxis.apply ...@@ -66,7 +61,6 @@ index_put_first_axis = IndexPutFirstAxis.apply
class IndexFirstAxisResidual(torch.autograd.Function): class IndexFirstAxisResidual(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, input, indices): def forward(ctx, input, indices):
ctx.save_for_backward(indices) ctx.save_for_backward(indices)
...@@ -128,7 +122,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng ...@@ -128,7 +122,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng
""" """
Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model). Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286). The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
``` ```
[ [
...@@ -177,9 +171,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng ...@@ -177,9 +171,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng
""" """
length = attention_mask_in_length.sum(dim=-1) length = attention_mask_in_length.sum(dim=-1)
seqlen = attention_mask_in_length.size(-1) seqlen = attention_mask_in_length.size(-1)
attention_mask_2d = torch.arange( attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
seqlen, device=length.device, dtype=length.dtype).expand(len(length),
seqlen) < length.unsqueeze(1)
real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten() real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx] seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten() indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
......
...@@ -6,11 +6,13 @@ import argparse ...@@ -6,11 +6,13 @@ import argparse
@tilelang.jit( @tilelang.jit(
out_idx=[3, 4], pass_configs={ out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk] q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk]
...@@ -20,11 +22,11 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc ...@@ -20,11 +22,11 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
@T.prim_func @T.prim_func
def flash_fwd( def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore V: T.Tensor(v_shape, dtype), # type: ignore
Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
): ):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype) Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
...@@ -40,25 +42,21 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc ...@@ -40,25 +42,21 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = ( loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=1): for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
-T.infinity(acc_s.dtype))
else: else:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
...@@ -76,18 +74,20 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc ...@@ -76,18 +74,20 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim_v): for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd return flash_fwd
@tilelang.jit( @tilelang.jit(
out_idx=[2], pass_configs={ out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
...@@ -96,9 +96,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): ...@@ -96,9 +96,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
@T.prim_func @T.prim_func
def flash_bwd_prep( def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype) o = T.alloc_fragment([blk, blk], dtype)
...@@ -107,26 +107,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): ...@@ -107,26 +107,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
delta = T.alloc_fragment([blk], accum_dtype) delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc) T.clear(acc)
for k in range(T.ceildiv(dim_v, blk)): for k in range(T.ceildiv(dim_v, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk): for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j] acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1) T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep return flash_bwd_prep
def make_dq_layout(dQ):
    """Return the reordered ``T.Layout`` used for the dQ accumulation buffer.

    atomicAdd can not be vectorized, so dQ is reordered to match the 8x8 gemm
    fragment.

    Each 4-D logical index ``(i0, i1, i2, i3)`` of ``dQ.shape`` is mapped to the
    6-D index ``[i0, i1 // 8, i2, i3 // 8, i3 % 2, 4 * (i1 % 8) + (i3 % 8) // 2]``.
    The kernels in this file pass a buffer declared with shape
    ``[batch, seq_len, heads, dim_qk]``, which is where the index names below
    come from.
    """

    def _reorder(batch_idx, seq_idx, head_idx, dim_idx):
        # Coordinates of the 8x8 tile that contains (seq_idx, dim_idx).
        seq_tile = seq_idx // 8
        dim_tile = dim_idx // 8
        # Position inside the tile: parity of the dim index, then a slot that
        # combines the row within the tile with the dim pair within the tile.
        dim_parity = dim_idx % 2
        slot = 4 * (seq_idx % 8) + (dim_idx % 8) // 2
        return [batch_idx, seq_tile, head_idx, dim_tile, dim_parity, slot]

    return T.Layout(dQ.shape, _reorder)
@tilelang.jit( @tilelang.jit(
out_idx=[1], pass_configs={ out_idx=[1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk): def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
...@@ -135,35 +136,27 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk): ...@@ -135,35 +136,27 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
@T.prim_func @T.prim_func
def flash_bwd_post( def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore dQ_out: T.Tensor(shape, dtype), # type: ignore
): ):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)}) T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy( T.copy(
dQ[bz, bx * blk:(bx + 1) * blk, by, :], dQ[bz, bx * blk : (bx + 1) * blk, by, :],
dQ_out[bz, bx * blk:(bx + 1) * blk, by, :], dQ_out[bz, bx * blk : (bx + 1) * blk, by, :],
) )
return flash_bwd_post return flash_bwd_post
@tilelang.jit(pass_configs={ @tilelang.jit(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, pass_configs={
}) tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
def flashattn_bwd_atomic_add(batch, }
heads, )
seq_len, def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
dim_qk, sm_scale = (1.0 / dim_qk) ** 0.5
dim_v, scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk] q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk]
...@@ -173,15 +166,15 @@ def flashattn_bwd_atomic_add(batch, ...@@ -173,15 +166,15 @@ def flashattn_bwd_atomic_add(batch,
@T.prim_func @T.prim_func
def flash_bwd( def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore dV: T.Tensor(v_shape, accum_dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype) K_shared = T.alloc_shared([block_M, dim_qk], dtype)
...@@ -201,35 +194,36 @@ def flashattn_bwd_atomic_add(batch, ...@@ -201,35 +194,36 @@ def flashattn_bwd_atomic_add(batch,
dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype) dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
T.annotate_layout({ T.annotate_layout(
dQ: make_dq_layout(dQ), {
K_shared: tilelang.layout.make_swizzled_layout(K_shared), dQ: make_dq_layout(dQ),
}) K_shared: tilelang.layout.make_swizzled_layout(K_shared),
}
)
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv) T.clear(dv)
T.clear(dk) T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N) loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT) T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
0) T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT) T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast) T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
...@@ -241,29 +235,21 @@ def flashattn_bwd_atomic_add(batch, ...@@ -241,29 +235,21 @@ def flashattn_bwd_atomic_add(batch,
for i, j in T.Parallel(block_N, dim_qk): for i, j in T.Parallel(block_N, dim_qk):
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared) T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared)
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared) T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared)
return flash_bwd return flash_bwd
@tilelang.jit(pass_configs={ @tilelang.jit(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, pass_configs={
}) tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
def flashattn_bwd_split(batch, }
heads, )
seq_len, def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
dim_qk, sm_scale = (1.0 / dim_qk) ** 0.5
dim_v, scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk] q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk]
...@@ -275,15 +261,15 @@ def flashattn_bwd_split(batch, ...@@ -275,15 +261,15 @@ def flashattn_bwd_split(batch,
@T.prim_func @T.prim_func
def flash_bwd( def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(dk_shape, dtype), # type: ignore dK: T.Tensor(dk_shape, dtype), # type: ignore
dV: T.Tensor(dv_shape, dtype), # type: ignore dV: T.Tensor(dv_shape, dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype) K_shared = T.alloc_shared([block_M, dim_qk], dtype)
...@@ -303,37 +289,38 @@ def flashattn_bwd_split(batch, ...@@ -303,37 +289,38 @@ def flashattn_bwd_split(batch,
dv_shared = T.alloc_shared([block_M, dim_v], dtype) dv_shared = T.alloc_shared([block_M, dim_v], dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], dtype) dk_shared = T.alloc_shared([block_M, dim_qk], dtype)
T.annotate_layout({ T.annotate_layout(
dQ: make_dq_layout(dQ), {
K_shared: tilelang.layout.make_swizzled_layout(K_shared), dQ: make_dq_layout(dQ),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
}) dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}
)
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv) T.clear(dv)
T.clear(dk) T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N) loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT) T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.clear(dsT) T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
0)
T.copy(qkT, qkT_cast) T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
...@@ -346,16 +333,15 @@ def flashattn_bwd_split(batch, ...@@ -346,16 +333,15 @@ def flashattn_bwd_split(batch,
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
return flash_bwd return flash_bwd
@torch.compile @torch.compile
class _attention(torch.autograd.Function): class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
BATCH, N_CTX, H, D_HEAD_QK = q.shape BATCH, N_CTX, H, D_HEAD_QK = q.shape
...@@ -373,7 +359,10 @@ class _attention(torch.autograd.Function): ...@@ -373,7 +359,10 @@ class _attention(torch.autograd.Function):
def backward(ctx, do): def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD_QK = q.shape BATCH, N_CTX, H, D_HEAD_QK = q.shape
HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1] (
HEAD_KV,
D_HEAD_V,
) = v.shape[-2], v.shape[-1]
groups = H // HEAD_KV groups = H // HEAD_KV
def maybe_contiguous(x): def maybe_contiguous(x):
...@@ -390,17 +379,8 @@ class _attention(torch.autograd.Function): ...@@ -390,17 +379,8 @@ class _attention(torch.autograd.Function):
if ctx.use_atomic: if ctx.use_atomic:
kernel = flashattn_bwd_atomic_add( kernel = flashattn_bwd_atomic_add(
BATCH, BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
H, )
N_CTX,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
...@@ -413,17 +393,8 @@ class _attention(torch.autograd.Function): ...@@ -413,17 +393,8 @@ class _attention(torch.autograd.Function):
dv = dv.to(torch.float16) dv = dv.to(torch.float16)
else: else:
kernel = flashattn_bwd_split( kernel = flashattn_bwd_split(
BATCH, BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
H, )
N_CTX,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel
shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel
...@@ -445,53 +416,45 @@ def ref_program(Q, K, V, is_causal, groups=1): ...@@ -445,53 +416,45 @@ def ref_program(Q, K, V, is_causal, groups=1):
# K: [B, T, HK, D_QK] # K: [B, T, HK, D_QK]
# V: [B, T, HV, D_V] # V: [B, T, HV, D_V]
# HQ = HKV * groups # HQ = HKV * groups
assert Q.size(2) == K.size( assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(
2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim_qk = Q.size(-1) dim_qk = Q.size(-1)
K = K.repeat_interleave(groups, dim=2) K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
if is_causal: if is_causal:
seq_len = Q.size(1) seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0) mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf')) scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1) attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output return output
def main(BATCH: int = 1, def main(
H: int = 32, BATCH: int = 1,
N_CTX: int = 256, H: int = 32,
D_HEAD_QK: int = 192, N_CTX: int = 256,
D_HEAD_V: int = 128, D_HEAD_QK: int = 192,
groups: int = 16, D_HEAD_V: int = 128,
causal: bool = False, groups: int = 16,
use_atomic: bool = True): causal: bool = False,
use_atomic: bool = True,
):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v total_flops = 3 * flops_per_qk + 2 * flops_per_v
if causal: if causal:
total_flops *= 0.5 total_flops *= 0.5
Q = ( Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half,
device="cuda").normal_().requires_grad_())
head_kv = H // groups head_kv = H // groups
K = ( K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
device="cuda").normal_().requires_grad_()) dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
V = (
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
dO = (
torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
O = attention(Q, K, V, causal, groups, use_atomic) O = attention(Q, K, V, causal, groups, use_atomic)
O.backward(dO, retain_graph=True) O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None dQ, Q.grad = Q.grad.clone(), None
...@@ -508,7 +471,7 @@ def main(BATCH: int = 1, ...@@ -508,7 +471,7 @@ def main(BATCH: int = 1,
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print('All checks passed.✅') print("All checks passed.✅")
def run(): def run():
O_ref.backward(dO, retain_graph=True) O_ref.backward(dO, retain_graph=True)
...@@ -528,17 +491,15 @@ def main(BATCH: int = 1, ...@@ -528,17 +491,15 @@ def main(BATCH: int = 1,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size') parser.add_argument("--batch", type=int, default=8, help="Batch size")
parser.add_argument('--h', type=int, default=32, help='Number of heads') parser.add_argument("--h", type=int, default=32, help="Number of heads")
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
parser.add_argument('--causal', action='store_true', help='Causal flag') parser.add_argument("--causal", action="store_true", help="Causal flag")
parser.add_argument('--groups', type=int, default=16, help='groups') parser.add_argument("--groups", type=int, default=16, help="groups")
parser.add_argument( parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV")
'--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV")
parser.add_argument(
'--use_split', action='store_true', default=False, help='Use split for dK/dV')
args = parser.parse_args() args = parser.parse_args()
# Handle backward compatibility and logic # Handle backward compatibility and logic
...@@ -550,5 +511,4 @@ if __name__ == "__main__": ...@@ -550,5 +511,4 @@ if __name__ == "__main__":
# Default: use atomic # Default: use atomic
use_atomic = True use_atomic = True
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)
use_atomic)
...@@ -9,11 +9,13 @@ tilelang.disable_cache() ...@@ -9,11 +9,13 @@ tilelang.disable_cache()
@tilelang.jit( @tilelang.jit(
out_idx=[3, 4], pass_configs={ out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk] q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk]
...@@ -23,11 +25,11 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc ...@@ -23,11 +25,11 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
@T.prim_func @T.prim_func
def flash_fwd( def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore V: T.Tensor(v_shape, dtype), # type: ignore
Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
): ):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype) Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
...@@ -43,27 +45,23 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc ...@@ -43,27 +45,23 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
# Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops
# We should set it to negative large number instead # We should set it to negative large number instead
T.fill(scores_max, T.Cast(accum_dtype, -1e30)) T.fill(scores_max, T.Cast(accum_dtype, -1e30))
loop_range = ( loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=1): for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, T.Cast(accum_dtype, -1e30))
T.Cast(accum_dtype, -1e30))
else: else:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
...@@ -81,18 +79,20 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc ...@@ -81,18 +79,20 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim_v): for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd return flash_fwd
@tilelang.jit( @tilelang.jit(
out_idx=[2], pass_configs={ out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
...@@ -101,9 +101,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): ...@@ -101,9 +101,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
@T.prim_func @T.prim_func
def flash_bwd_prep( def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype) o = T.alloc_fragment([blk, blk], dtype)
...@@ -112,12 +112,12 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): ...@@ -112,12 +112,12 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
delta = T.alloc_fragment([blk], accum_dtype) delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc) T.clear(acc)
for k in range(T.ceildiv(dim_v, blk)): for k in range(T.ceildiv(dim_v, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk): for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j] acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1) T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep return flash_bwd_prep
...@@ -128,9 +128,11 @@ def make_dq_layout(dQ): ...@@ -128,9 +128,11 @@ def make_dq_layout(dQ):
@tilelang.jit( @tilelang.jit(
out_idx=[3, 4, 5], pass_configs={ out_idx=[3, 4, 5],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v): def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
...@@ -141,46 +143,37 @@ def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v): ...@@ -141,46 +143,37 @@ def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v):
@T.prim_func @T.prim_func
def flash_bwd_post( def flash_bwd_post(
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore dV: T.Tensor(v_shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(q_shape, dtype), # type: ignore dQ_out: T.Tensor(q_shape, dtype), # type: ignore
dK_out: T.Tensor(k_shape, dtype), # type: ignore dK_out: T.Tensor(k_shape, dtype), # type: ignore
dV_out: T.Tensor(v_shape, dtype), # type: ignore dV_out: T.Tensor(v_shape, dtype), # type: ignore
): ):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)}) T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(dQ[bz, bx * blk:(bx + 1) * blk, by, :], dQ_out[bz, bx * blk:(bx + 1) * blk, T.copy(dQ[bz, bx * blk : (bx + 1) * blk, by, :], dQ_out[bz, bx * blk : (bx + 1) * blk, by, :])
by, :])
with T.Kernel(T.ceildiv(seq_len, blk), head_kv, batch, threads=128) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, blk), head_kv, batch, threads=128) as (bx, by, bz):
T.annotate_layout({ T.annotate_layout(
dK: make_dq_layout(dK), {
dV: make_dq_layout(dV), dK: make_dq_layout(dK),
}) dV: make_dq_layout(dV),
T.copy(dK[bz, bx * blk:(bx + 1) * blk, by, :], dK_out[bz, bx * blk:(bx + 1) * blk, }
by, :]) )
T.copy(dV[bz, bx * blk:(bx + 1) * blk, by, :], dV_out[bz, bx * blk:(bx + 1) * blk, T.copy(dK[bz, bx * blk : (bx + 1) * blk, by, :], dK_out[bz, bx * blk : (bx + 1) * blk, by, :])
by, :]) T.copy(dV[bz, bx * blk : (bx + 1) * blk, by, :], dV_out[bz, bx * blk : (bx + 1) * blk, by, :])
return flash_bwd_post return flash_bwd_post
@tilelang.jit(pass_configs={ @tilelang.jit(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, pass_configs={
}) tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
def flashattn_bwd_atomic_add(batch, }
heads, )
seq_len, def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
dim_qk, sm_scale = (1.0 / dim_qk) ** 0.5
dim_v, scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk] q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk]
...@@ -190,15 +183,15 @@ def flashattn_bwd_atomic_add(batch, ...@@ -190,15 +183,15 @@ def flashattn_bwd_atomic_add(batch,
@T.prim_func @T.prim_func
def flash_bwd( def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore dV: T.Tensor(v_shape, accum_dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype) K_shared = T.alloc_shared([block_M, dim_qk], dtype)
...@@ -219,37 +212,38 @@ def flashattn_bwd_atomic_add(batch, ...@@ -219,37 +212,38 @@ def flashattn_bwd_atomic_add(batch,
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)
T.annotate_layout({ T.annotate_layout(
dQ: make_dq_layout(dQ), {
dK: make_dq_layout(dK), dQ: make_dq_layout(dQ),
dV: make_dq_layout(dV), dK: make_dq_layout(dK),
K_shared: tilelang.layout.make_swizzled_layout(K_shared), dV: make_dq_layout(dV),
}) K_shared: tilelang.layout.make_swizzled_layout(K_shared),
}
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) )
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv) T.clear(dv)
T.clear(dk) T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N) loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT) T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
0) T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT) T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast) T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
...@@ -259,33 +253,23 @@ def flashattn_bwd_atomic_add(batch, ...@@ -259,33 +253,23 @@ def flashattn_bwd_atomic_add(batch,
T.clear(dq) T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
T.copy(dq, dq_shared) T.copy(dq, dq_shared)
T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared, use_tma=True) T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared, use_tma=True)
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.atomic_add( T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True)
dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True)
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
T.atomic_add( T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True)
dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True)
return flash_bwd return flash_bwd
@tilelang.jit(pass_configs={ @tilelang.jit(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, pass_configs={
}) tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
def flashattn_bwd_split_novarlen(batch, }
heads, )
seq_len, def flashattn_bwd_split_novarlen(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
dim_qk, sm_scale = (1.0 / dim_qk) ** 0.5
dim_v, scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk] q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk]
...@@ -297,15 +281,15 @@ def flashattn_bwd_split_novarlen(batch, ...@@ -297,15 +281,15 @@ def flashattn_bwd_split_novarlen(batch,
@T.prim_func @T.prim_func
def flash_bwd( def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(dk_shape, dtype), # type: ignore dK: T.Tensor(dk_shape, dtype), # type: ignore
dV: T.Tensor(dv_shape, dtype), # type: ignore dV: T.Tensor(dv_shape, dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype) K_shared = T.alloc_shared([block_M, dim_qk], dtype)
...@@ -325,37 +309,38 @@ def flashattn_bwd_split_novarlen(batch, ...@@ -325,37 +309,38 @@ def flashattn_bwd_split_novarlen(batch,
dv_shared = T.alloc_shared([block_M, dim_v], dtype) dv_shared = T.alloc_shared([block_M, dim_v], dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], dtype) dk_shared = T.alloc_shared([block_M, dim_qk], dtype)
T.annotate_layout({ T.annotate_layout(
dQ: make_dq_layout(dQ), {
K_shared: tilelang.layout.make_swizzled_layout(K_shared), dQ: make_dq_layout(dQ),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
}) dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) )
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv) T.clear(dv)
T.clear(dk) T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N) loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT) T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.clear(dsT) T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
0)
T.copy(qkT, qkT_cast) T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
...@@ -368,16 +353,15 @@ def flashattn_bwd_split_novarlen(batch, ...@@ -368,16 +353,15 @@ def flashattn_bwd_split_novarlen(batch,
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
T.copy(dk, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :]) T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
return flash_bwd return flash_bwd
@torch.compile @torch.compile
class _attention(torch.autograd.Function): class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
BATCH, N_CTX, H, D_HEAD_QK = q.shape BATCH, N_CTX, H, D_HEAD_QK = q.shape
...@@ -395,7 +379,10 @@ class _attention(torch.autograd.Function): ...@@ -395,7 +379,10 @@ class _attention(torch.autograd.Function):
def backward(ctx, do): def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD_QK = q.shape BATCH, N_CTX, H, D_HEAD_QK = q.shape
HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1] (
HEAD_KV,
D_HEAD_V,
) = v.shape[-2], v.shape[-1]
groups = H // HEAD_KV groups = H // HEAD_KV
def maybe_contiguous(x): def maybe_contiguous(x):
...@@ -412,17 +399,8 @@ class _attention(torch.autograd.Function): ...@@ -412,17 +399,8 @@ class _attention(torch.autograd.Function):
if ctx.use_atomic: if ctx.use_atomic:
kernel = flashattn_bwd_atomic_add( kernel = flashattn_bwd_atomic_add(
BATCH, BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
H, )
N_CTX,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
...@@ -433,17 +411,8 @@ class _attention(torch.autograd.Function): ...@@ -433,17 +411,8 @@ class _attention(torch.autograd.Function):
dq, dk, dv = mod_post(dq, dk, dv) dq, dk, dv = mod_post(dq, dk, dv)
else: else:
kernel = flashattn_bwd_split_novarlen( kernel = flashattn_bwd_split_novarlen(
BATCH, BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
H, )
N_CTX,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel
shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel
...@@ -451,8 +420,7 @@ class _attention(torch.autograd.Function): ...@@ -451,8 +420,7 @@ class _attention(torch.autograd.Function):
dk = torch.empty(shape_k, dtype=torch.float16, device=q.device) dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
dv = torch.empty(shape_v, dtype=torch.float16, device=q.device) dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv) kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32))
torch.zeros_like(v, dtype=torch.float32))
dk, dv = dk.sum(0), dv.sum(0) dk, dv = dk.sum(0), dv.sum(0)
return dq, dk, dv, None, None, None return dq, dk, dv, None, None, None
...@@ -466,53 +434,45 @@ def ref_program(Q, K, V, is_causal, groups=1): ...@@ -466,53 +434,45 @@ def ref_program(Q, K, V, is_causal, groups=1):
# K: [B, T, HK, D_QK] # K: [B, T, HK, D_QK]
# V: [B, T, HV, D_V] # V: [B, T, HV, D_V]
# HQ = HKV * groups # HQ = HKV * groups
assert Q.size(2) == K.size( assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(
2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim_qk = Q.size(-1) dim_qk = Q.size(-1)
K = K.repeat_interleave(groups, dim=2) K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
if is_causal: if is_causal:
seq_len = Q.size(1) seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0) mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf')) scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1) attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output return output
def main(BATCH: int = 1, def main(
H: int = 32, BATCH: int = 1,
N_CTX: int = 256, H: int = 32,
D_HEAD_QK: int = 192, N_CTX: int = 256,
D_HEAD_V: int = 128, D_HEAD_QK: int = 192,
groups: int = 16, D_HEAD_V: int = 128,
causal: bool = False, groups: int = 16,
use_atomic: bool = True): causal: bool = False,
use_atomic: bool = True,
):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v total_flops = 3 * flops_per_qk + 2 * flops_per_v
if causal: if causal:
total_flops *= 0.5 total_flops *= 0.5
Q = ( Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half,
device="cuda").normal_().requires_grad_())
head_kv = H // groups head_kv = H // groups
K = ( K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
device="cuda").normal_().requires_grad_()) dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
V = (
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
dO = (
torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
O = attention(Q, K, V, causal, groups, use_atomic) O = attention(Q, K, V, causal, groups, use_atomic)
O.backward(dO, retain_graph=True) O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None dQ, Q.grad = Q.grad.clone(), None
...@@ -529,7 +489,7 @@ def main(BATCH: int = 1, ...@@ -529,7 +489,7 @@ def main(BATCH: int = 1,
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print('All checks passed.✅') print("All checks passed.✅")
def run(): def run():
O_ref.backward(dO, retain_graph=True) O_ref.backward(dO, retain_graph=True)
...@@ -552,17 +512,15 @@ if __name__ == "__main__": ...@@ -552,17 +512,15 @@ if __name__ == "__main__":
print(f"Detected GPU compute capability: {arch}") print(f"Detected GPU compute capability: {arch}")
assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0" assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0"
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size') parser.add_argument("--batch", type=int, default=8, help="Batch size")
parser.add_argument('--h', type=int, default=32, help='Number of heads') parser.add_argument("--h", type=int, default=32, help="Number of heads")
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
parser.add_argument('--causal', action='store_true', help='Causal flag') parser.add_argument("--causal", action="store_true", help="Causal flag")
parser.add_argument('--groups', type=int, default=16, help='groups') parser.add_argument("--groups", type=int, default=16, help="groups")
parser.add_argument( parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV")
'--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV")
parser.add_argument(
'--use_split', action='store_true', default=False, help='Use split for dK/dV')
args = parser.parse_args() args = parser.parse_args()
# Handle backward compatibility and logic # Handle backward compatibility and logic
...@@ -574,5 +532,4 @@ if __name__ == "__main__": ...@@ -574,5 +532,4 @@ if __name__ == "__main__":
# Default: use atomic # Default: use atomic
use_atomic = True use_atomic = True
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)
use_atomic)
...@@ -15,32 +15,21 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): ...@@ -15,32 +15,21 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
if mode == "full": if mode == "full":
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
elif mode == "random": elif mode == "random":
lengths = torch.randint( lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
elif mode == "third": elif mode == "third":
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
padding_mask = ( padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths)
return padding_mask return padding_mask
@tilelang.jit( @tilelang.jit(
out_idx=[5, 6], pass_configs={ out_idx=[5, 6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
def flashattn_fwd(batch, )
total_q, def flashattn_fwd(batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
total_kv, scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
N_CTX,
heads,
max_seq_len,
dim_qk,
dim_v,
is_causal,
block_M,
block_N,
groups=1):
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [total_q, heads, dim_qk] q_shape = [total_q, heads, dim_qk]
k_shape = [total_kv, head_kv, dim_qk] k_shape = [total_kv, head_kv, dim_qk]
...@@ -51,13 +40,13 @@ def flashattn_fwd(batch, ...@@ -51,13 +40,13 @@ def flashattn_fwd(batch,
@T.prim_func @T.prim_func
def flash_fwd( def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore V: T.Tensor(v_shape, dtype), # type: ignore
cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore
cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore
Output: T.Tensor(o_shape, dtype), # type: ignore Output: T.Tensor(o_shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
): ):
with T.Kernel(T.ceildiv(max_seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): with T.Kernel(T.ceildiv(max_seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype) Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
...@@ -102,15 +91,17 @@ def flashattn_fwd(batch, ...@@ -102,15 +91,17 @@ def flashattn_fwd(batch,
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and acc_s[i, j] = T.if_then_else(
(bx * block_M + i < q_current_seqlen and (bx * block_M + i >= k * block_N + j)
k * block_N + j < k_current_seqlen), 0, and (bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen),
T.Cast(accum_dtype, -1e30)) 0,
T.Cast(accum_dtype, -1e30),
)
else: else:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else( acc_s[i, j] = T.if_then_else(
bx * block_M + i < q_current_seqlen and bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen, 0, T.Cast(accum_dtype, -1e30)
k * block_N + j < k_current_seqlen, 0, T.Cast(accum_dtype, -1e30)) )
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, d in T.Parallel(block_N, dim_v): for i, d in T.Parallel(block_N, dim_v):
if k * block_N + i < k_current_seqlen: if k * block_N + i < k_current_seqlen:
...@@ -148,9 +139,11 @@ def flashattn_fwd(batch, ...@@ -148,9 +139,11 @@ def flashattn_fwd(batch,
@tilelang.jit( @tilelang.jit(
out_idx=[3], pass_configs={ out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v): def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
...@@ -159,10 +152,10 @@ def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v): ...@@ -159,10 +152,10 @@ def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v):
@T.prim_func @T.prim_func
def flash_bwd_prep( def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore dO: T.Tensor(shape, dtype), # type: ignore
cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore
Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype) o = T.alloc_fragment([blk, blk], dtype)
...@@ -201,9 +194,11 @@ def make_dq_layout(dQ): ...@@ -201,9 +194,11 @@ def make_dq_layout(dQ):
@tilelang.jit( @tilelang.jit(
out_idx=[3, 4, 5], pass_configs={ out_idx=[3, 4, 5],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v): def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
...@@ -214,46 +209,39 @@ def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v): ...@@ -214,46 +209,39 @@ def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v):
@T.prim_func @T.prim_func
def flash_bwd_post( def flash_bwd_post(
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore dV: T.Tensor(v_shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(q_shape, dtype), # type: ignore dQ_out: T.Tensor(q_shape, dtype), # type: ignore
dK_out: T.Tensor(k_shape, dtype), # type: ignore dK_out: T.Tensor(k_shape, dtype), # type: ignore
dV_out: T.Tensor(v_shape, dtype), # type: ignore dV_out: T.Tensor(v_shape, dtype), # type: ignore
): ):
with T.Kernel(T.ceildiv(total_q, blk), heads, threads=128) as (bx, by): with T.Kernel(T.ceildiv(total_q, blk), heads, threads=128) as (bx, by):
T.annotate_layout({dQ: make_dq_layout(dQ)}) T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(dQ[bx * blk:(bx + 1) * blk, by, :], dQ_out[bx * blk:(bx + 1) * blk, by, :]) T.copy(dQ[bx * blk : (bx + 1) * blk, by, :], dQ_out[bx * blk : (bx + 1) * blk, by, :])
with T.Kernel(T.ceildiv(total_kv, blk), head_kv, threads=128) as (bx, by): with T.Kernel(T.ceildiv(total_kv, blk), head_kv, threads=128) as (bx, by):
T.annotate_layout({ T.annotate_layout(
dK: make_dq_layout(dK), {
dV: make_dq_layout(dV), dK: make_dq_layout(dK),
}) dV: make_dq_layout(dV),
T.copy(dK[bx * blk:(bx + 1) * blk, by, :], dK_out[bx * blk:(bx + 1) * blk, by, :]) }
T.copy(dV[bx * blk:(bx + 1) * blk, by, :], dV_out[bx * blk:(bx + 1) * blk, by, :]) )
T.copy(dK[bx * blk : (bx + 1) * blk, by, :], dK_out[bx * blk : (bx + 1) * blk, by, :])
T.copy(dV[bx * blk : (bx + 1) * blk, by, :], dV_out[bx * blk : (bx + 1) * blk, by, :])
return flash_bwd_post return flash_bwd_post
@tilelang.jit(pass_configs={ @tilelang.jit(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, pass_configs={
}) tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
def flashattn_bwd_atomic_add(batch, }
total_q, )
total_kv, def flashattn_bwd_atomic_add(
N_CTX, batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1
heads, ):
max_seq_len, sm_scale = (1.0 / dim_qk) ** 0.5
dim_qk, scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [total_q, heads, dim_qk] q_shape = [total_q, heads, dim_qk]
k_shape = [total_kv, head_kv, dim_qk] k_shape = [total_kv, head_kv, dim_qk]
...@@ -264,20 +252,19 @@ def flashattn_bwd_atomic_add(batch, ...@@ -264,20 +252,19 @@ def flashattn_bwd_atomic_add(batch,
@T.prim_func @T.prim_func
def flash_bwd( def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor(do_shape, dtype), # type: ignore dO: T.Tensor(do_shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore
cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore dV: T.Tensor(v_shape, accum_dtype), # type: ignore
): ):
with T.Kernel( with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz):
heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype) K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype) dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype) q = T.alloc_shared([block_N, dim_qk], dtype)
...@@ -303,58 +290,54 @@ def flashattn_bwd_atomic_add(batch, ...@@ -303,58 +290,54 @@ def flashattn_bwd_atomic_add(batch,
q_current_seqlen = q_end_idx - q_start_idx q_current_seqlen = q_end_idx - q_start_idx
k_current_seqlen = k_end_idx - k_start_idx k_current_seqlen = k_end_idx - k_start_idx
T.annotate_layout({ T.annotate_layout(
dQ: make_dq_layout(dQ), {
dK: make_dq_layout(dK), dQ: make_dq_layout(dQ),
dV: make_dq_layout(dV), dK: make_dq_layout(dK),
K_shared: tilelang.layout.make_swizzled_layout(K_shared), dV: make_dq_layout(dV),
}) K_shared: tilelang.layout.make_swizzled_layout(K_shared),
}
)
T.copy(K[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :], T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared)
K_shared) T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared)
T.copy(V[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :],
V_shared)
T.clear(dv) T.clear(dv)
T.clear(dk) T.clear(dk)
loop_st = T.min( loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0
T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen,
block_N)) if is_causal else 0
loop_ed = T.ceildiv(q_current_seqlen, block_N) loop_ed = T.ceildiv(q_current_seqlen, block_N)
for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy( T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q)
Q[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :],
q)
T.clear(qkT) T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k_base * block_N:(k_base + 1) * block_N], lse_shared) T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else((by * block_M + i <= k_base * block_N + j) and qkT[i, j] = T.if_then_else(
(by * block_M + i < k_current_seqlen and (by * block_M + i <= k_base * block_N + j)
k_base * block_N + j < q_current_seqlen), and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen),
qkT[i, j], 0) qkT[i, j],
0,
)
else: else:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else( qkT[i, j] = T.if_then_else(
by * block_M + i < k_current_seqlen and by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0
k_base * block_N + j < q_current_seqlen, qkT[i, j], 0) )
T.copy( T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do)
dO[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :],
do)
T.clear(dsT) T.clear(dsT)
# dsT: (block_kv, block_q) # dsT: (block_kv, block_q)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast) T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k_base * block_N:(k_base + 1) * block_N], delta) T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
...@@ -364,49 +347,40 @@ def flashattn_bwd_atomic_add(batch, ...@@ -364,49 +347,40 @@ def flashattn_bwd_atomic_add(batch,
T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
T.copy(dq, dq_shared) T.copy(dq, dq_shared)
T.atomic_add( T.atomic_add(
dQ[q_start_idx + k_base * block_N:q_start_idx + k_base * block_N + block_N, dQ[q_start_idx + k_base * block_N : q_start_idx + k_base * block_N + block_N, bx, :],
bx, :],
dq_shared, dq_shared,
memory_order="relaxed", memory_order="relaxed",
use_tma=True) use_tma=True,
)
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.atomic_add( T.atomic_add(
dV[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, dV[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :],
bx // groups, :],
dv_shared, dv_shared,
memory_order="relaxed", memory_order="relaxed",
use_tma=True) use_tma=True,
)
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
T.atomic_add( T.atomic_add(
dK[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M, dK[k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :],
bx // groups, :],
dk_shared, dk_shared,
memory_order="relaxed", memory_order="relaxed",
use_tma=True) use_tma=True,
)
return flash_bwd return flash_bwd
@tilelang.jit(pass_configs={ @tilelang.jit(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, pass_configs={
}) tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
def flashattn_bwd_split(batch, }
total_q, )
total_kv, def flashattn_bwd_split(
N_CTX, batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1
heads, ):
max_seq_len, sm_scale = (1.0 / dim_qk) ** 0.5
dim_qk, scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
dim_v,
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [total_q, heads, dim_qk] q_shape = [total_q, heads, dim_qk]
k_shape = [total_kv, head_kv, dim_qk] k_shape = [total_kv, head_kv, dim_qk]
...@@ -419,20 +393,19 @@ def flashattn_bwd_split(batch, ...@@ -419,20 +393,19 @@ def flashattn_bwd_split(batch,
@T.prim_func @T.prim_func
def flash_bwd( def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor(do_shape, dtype), # type: ignore dO: T.Tensor(do_shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore
cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore
cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(dk_shape, dtype), # type: ignore dK: T.Tensor(dk_shape, dtype), # type: ignore
dV: T.Tensor(dv_shape, dtype), # type: ignore dV: T.Tensor(dv_shape, dtype), # type: ignore
): ):
with T.Kernel( with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz):
heads, T.ceildiv(max_seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype) K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype) dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype) q = T.alloc_shared([block_N, dim_qk], dtype)
...@@ -457,59 +430,55 @@ def flashattn_bwd_split(batch, ...@@ -457,59 +430,55 @@ def flashattn_bwd_split(batch,
q_current_seqlen = q_end_idx - q_start_idx q_current_seqlen = q_end_idx - q_start_idx
k_current_seqlen = k_end_idx - k_start_idx k_current_seqlen = k_end_idx - k_start_idx
T.annotate_layout({ T.annotate_layout(
dQ: make_dq_layout(dQ), {
K_shared: tilelang.layout.make_swizzled_layout(K_shared), dQ: make_dq_layout(dQ),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
}) dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}
)
T.copy(K[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :], T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared)
K_shared) T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared)
T.copy(V[k_start_idx + by * block_M:k_start_idx + (by + 1) * block_M, bx // groups, :],
V_shared)
T.clear(dv) T.clear(dv)
T.clear(dk) T.clear(dk)
loop_st = T.min( loop_st = T.min(T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen, block_N)) if is_causal else 0
T.floordiv(by * block_M, block_N), T.floordiv(q_current_seqlen,
block_N)) if is_causal else 0
loop_ed = T.ceildiv(q_current_seqlen, block_N) loop_ed = T.ceildiv(q_current_seqlen, block_N)
for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): for k_base in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
# Note: The padding zero of varlen should be considered in T.copy # Note: The padding zero of varlen should be considered in T.copy
T.copy( T.copy(Q[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], q)
Q[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :],
q)
T.clear(qkT) T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy( T.copy(dO[q_start_idx + k_base * block_N : q_start_idx + (k_base + 1) * block_N, bx, :], do)
dO[q_start_idx + k_base * block_N:q_start_idx + (k_base + 1) * block_N, bx, :],
do)
T.clear(dsT) T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k_base * block_N:(k_base + 1) * block_N], lse_shared) T.copy(lse[bz, bx, k_base * block_N : (k_base + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else((by * block_M + i <= k_base * block_N + j) and qkT[i, j] = T.if_then_else(
(by * block_M + i < k_current_seqlen and (by * block_M + i <= k_base * block_N + j)
k_base * block_N + j < q_current_seqlen), and (by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen),
qkT[i, j], 0) qkT[i, j],
0,
)
else: else:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else( qkT[i, j] = T.if_then_else(
by * block_M + i < k_current_seqlen and by * block_M + i < k_current_seqlen and k_base * block_N + j < q_current_seqlen, qkT[i, j], 0
k_base * block_N + j < q_current_seqlen, qkT[i, j], 0) )
T.copy(qkT, qkT_cast) T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k_base * block_N:(k_base + 1) * block_N], delta) T.copy(Delta[bz, bx, k_base * block_N : (k_base + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
...@@ -520,62 +489,37 @@ def flashattn_bwd_split(batch, ...@@ -520,62 +489,37 @@ def flashattn_bwd_split(batch,
T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk): for i, j in T.Parallel(block_N, dim_qk):
if k_base * block_N + i < q_current_seqlen: if k_base * block_N + i < q_current_seqlen:
T.atomic_add( T.atomic_add(dQ[q_start_idx + k_base * block_N + i, bx, j], dq[i, j], memory_order="relaxed")
dQ[q_start_idx + k_base * block_N + i, bx, j],
dq[i, j],
memory_order="relaxed")
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.copy( T.copy(dv_shared, dV[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :])
dv_shared,
dV[bx % groups, k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
bx // groups, :])
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
T.copy( T.copy(dk_shared, dK[bx % groups, k_start_idx + by * block_M : k_start_idx + by * block_M + block_M, bx // groups, :])
dk_shared,
dK[bx % groups, k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
bx // groups, :])
return flash_bwd return flash_bwd
@torch.compile @torch.compile
class _attention(torch.autograd.Function): class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, def forward(
q, ctx, q, k, v, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups=1, use_atomic=True
k, ):
v,
seqlens_q,
seqlens_k,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
causal,
groups=1,
use_atomic=True):
BATCH, N_CTX, H, D_HEAD_QK = q.shape BATCH, N_CTX, H, D_HEAD_QK = q.shape
D_HEAD_V = v.shape[-1] D_HEAD_V = v.shape[-1]
block_M = 128 block_M = 128
block_N = 64 block_N = 64
q_unpad, indices_q, _, _ = unpad_input( q_unpad, indices_q, _, _ = unpad_input(q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1)))
q, (torch.arange(N_CTX, device=q.device).unsqueeze(0) < seqlens_q.unsqueeze(1))) k_unpad, indices_k, _, _ = unpad_input(k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1)))
k_unpad, indices_k, _, _ = unpad_input( v_unpad, _, _, _ = unpad_input(v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1)))
k, (torch.arange(N_CTX, device=k.device).unsqueeze(0) < seqlens_k.unsqueeze(1)))
v_unpad, _, _, _ = unpad_input(
v, (torch.arange(N_CTX, device=v.device).unsqueeze(0) < seqlens_k.unsqueeze(1)))
total_q = q_unpad.shape[0] total_q = q_unpad.shape[0]
total_kv = k_unpad.shape[0] total_kv = k_unpad.shape[0]
mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
causal, block_M, block_N, groups)
o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k) o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k)
o = pad_input(o_unpad, indices_q, BATCH, N_CTX) o = pad_input(o_unpad, indices_q, BATCH, N_CTX)
ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k)
cu_seqlens_q, cu_seqlens_k)
ctx.batch = BATCH ctx.batch = BATCH
ctx.causal = causal ctx.causal = causal
ctx.use_atomic = use_atomic ctx.use_atomic = use_atomic
...@@ -590,8 +534,7 @@ class _attention(torch.autograd.Function): ...@@ -590,8 +534,7 @@ class _attention(torch.autograd.Function):
N_CTX = do.shape[1] N_CTX = do.shape[1]
q, k, v, o, lse_clone, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors q, k, v, o, lse_clone, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
# lse_clone = lse.clone() # lse_clone = lse.clone()
do_unpad, _, _, _ = unpad_input( do_unpad, _, _, _ = unpad_input(do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1)))
do, (torch.arange(N_CTX, device=do.device).unsqueeze(0) < seqlens_q.unsqueeze(1)))
total_q, H, D_HEAD_QK = q.shape total_q, H, D_HEAD_QK = q.shape
total_kv, HEAD_KV, D_HEAD_V = v.shape total_kv, HEAD_KV, D_HEAD_V = v.shape
groups = H // HEAD_KV groups = H // HEAD_KV
...@@ -624,7 +567,8 @@ class _attention(torch.autograd.Function): ...@@ -624,7 +567,8 @@ class _attention(torch.autograd.Function):
block_N, block_N,
threads=256, threads=256,
num_stages=2, num_stages=2,
groups=groups) groups=groups,
)
dq = torch.zeros_like(q, dtype=torch.float32) dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.zeros_like(k, dtype=torch.float32) dk = torch.zeros_like(k, dtype=torch.float32)
dv = torch.zeros_like(v, dtype=torch.float32) dv = torch.zeros_like(v, dtype=torch.float32)
...@@ -645,13 +589,13 @@ class _attention(torch.autograd.Function): ...@@ -645,13 +589,13 @@ class _attention(torch.autograd.Function):
block_N, block_N,
threads=256, threads=256,
num_stages=2, num_stages=2,
groups=groups) groups=groups,
)
dq = torch.zeros_like(q, dtype=torch.float32) dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device) dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device)
dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device) dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv) kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32))
torch.zeros_like(v, dtype=torch.float32))
dk, dv = dk.sum(0), dv.sum(0) dk, dv = dk.sum(0), dv.sum(0)
dq = pad_input(dq, ctx.indices_q, BATCH, N_CTX) dq = pad_input(dq, ctx.indices_q, BATCH, N_CTX)
...@@ -670,15 +614,13 @@ def ref_program(Q, K, V, padding_mask, is_causal, groups=1): ...@@ -670,15 +614,13 @@ def ref_program(Q, K, V, padding_mask, is_causal, groups=1):
# HQ = HKV * groups # HQ = HKV * groups
# To handle precision issue # To handle precision issue
Q, K, V = Q.float(), K.float(), V.float() Q, K, V = Q.float(), K.float(), V.float()
assert Q.size(2) == K.size( assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(
2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim_qk = Q.size(-1) dim_qk = Q.size(-1)
K = K.repeat_interleave(groups, dim=2) K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
if padding_mask is not None: if padding_mask is not None:
scores.masked_fill_(rearrange(~padding_mask, "b s -> b 1 1 s"), float("-inf")) scores.masked_fill_(rearrange(~padding_mask, "b s -> b 1 1 s"), float("-inf"))
...@@ -686,41 +628,35 @@ def ref_program(Q, K, V, padding_mask, is_causal, groups=1): ...@@ -686,41 +628,35 @@ def ref_program(Q, K, V, padding_mask, is_causal, groups=1):
seq_len = Q.size(1) seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0) mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf')) scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1) attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
if padding_mask is not None: if padding_mask is not None:
output.masked_fill_(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0) output.masked_fill_(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0)
return output return output
def main(BATCH: int = 1, def main(
H: int = 32, BATCH: int = 1,
N_CTX: int = 256, H: int = 32,
D_HEAD_QK: int = 192, N_CTX: int = 256,
D_HEAD_V: int = 128, D_HEAD_QK: int = 192,
groups: int = 16, D_HEAD_V: int = 128,
causal: bool = False, groups: int = 16,
use_atomic: bool = True): causal: bool = False,
use_atomic: bool = True,
):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v total_flops = 3 * flops_per_qk + 2 * flops_per_v
if causal: if causal:
total_flops *= 0.5 total_flops *= 0.5
Q = ( Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half,
device="cuda").normal_().requires_grad_())
head_kv = H // groups head_kv = H // groups
K = ( K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
device="cuda").normal_().requires_grad_()) dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
V = (
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
dO = (
torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
padding_mask = generate_random_padding_mask(N_CTX, BATCH, "cuda", mode="random") padding_mask = generate_random_padding_mask(N_CTX, BATCH, "cuda", mode="random")
seqlens_q = padding_mask.sum(dim=-1, dtype=torch.int32) seqlens_q = padding_mask.sum(dim=-1, dtype=torch.int32)
cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0)) cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
...@@ -729,8 +665,7 @@ def main(BATCH: int = 1, ...@@ -729,8 +665,7 @@ def main(BATCH: int = 1,
# In training backward pass, seqlens_k should be the same as seqlens_q # In training backward pass, seqlens_k should be the same as seqlens_q
seqlens_k, cu_seqlens_k, max_seqlen_k = seqlens_q, cu_seqlens_q, max_seqlen_q seqlens_k, cu_seqlens_k, max_seqlen_k = seqlens_q, cu_seqlens_q, max_seqlen_q
O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, O = attention(Q, K, V, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, groups, use_atomic)
max_seqlen_k, causal, groups, use_atomic)
O.backward(dO, retain_graph=True) O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None dK, K.grad = K.grad.clone(), None
...@@ -772,17 +707,15 @@ if __name__ == "__main__": ...@@ -772,17 +707,15 @@ if __name__ == "__main__":
print(f"Detected GPU compute capability: {arch}") print(f"Detected GPU compute capability: {arch}")
assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0" assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0"
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size') parser.add_argument("--batch", type=int, default=8, help="Batch size")
parser.add_argument('--h', type=int, default=32, help='Number of heads') parser.add_argument("--h", type=int, default=32, help="Number of heads")
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
parser.add_argument('--causal', action='store_true', help='Causal flag') parser.add_argument("--causal", action="store_true", help="Causal flag")
parser.add_argument('--groups', type=int, default=16, help='groups') parser.add_argument("--groups", type=int, default=16, help="groups")
parser.add_argument( parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV")
'--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV")
parser.add_argument(
'--use_split', action='store_true', default=False, help='Use split for dK/dV')
args = parser.parse_args() args = parser.parse_args()
# Can be set to True/False for testing # Can be set to True/False for testing
args.causal = True args.causal = True
...@@ -796,5 +729,4 @@ if __name__ == "__main__": ...@@ -796,5 +729,4 @@ if __name__ == "__main__":
# Default: use atomic # Default: use atomic
use_atomic = True use_atomic = True
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)
use_atomic)
...@@ -6,11 +6,13 @@ import argparse ...@@ -6,11 +6,13 @@ import argparse
@tilelang.jit( @tilelang.jit(
out_idx=[3, 4], pass_configs={ out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk] q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk]
...@@ -20,11 +22,11 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc ...@@ -20,11 +22,11 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
@T.prim_func @T.prim_func
def flash_fwd( def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore V: T.Tensor(v_shape, dtype), # type: ignore
Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
): ):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype) Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
...@@ -40,25 +42,21 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc ...@@ -40,25 +42,21 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = ( loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=1): for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
-T.infinity(acc_s.dtype))
else: else:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
...@@ -76,18 +74,20 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc ...@@ -76,18 +74,20 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim_v): for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd return flash_fwd
@tilelang.jit( @tilelang.jit(
out_idx=[2], pass_configs={ out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
...@@ -96,9 +96,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): ...@@ -96,9 +96,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
@T.prim_func @T.prim_func
def flash_bwd_prep( def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype) o = T.alloc_fragment([blk, blk], dtype)
...@@ -107,32 +107,24 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): ...@@ -107,32 +107,24 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
delta = T.alloc_fragment([blk], accum_dtype) delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc) T.clear(acc)
for k in range(T.ceildiv(dim_v, blk)): for k in range(T.ceildiv(dim_v, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk): for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j] acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1) T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep return flash_bwd_prep
@tilelang.jit(pass_configs={ @tilelang.jit(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, pass_configs={
}) tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
def flashattn_bwd(batch, }
heads, )
seq_len, def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
dim_qk, sm_scale = (1.0 / dim_qk) ** 0.5
dim_v, scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
is_causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=1):
sm_scale = (1.0 / dim_qk)**0.5
scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk] q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk]
...@@ -142,15 +134,15 @@ def flashattn_bwd(batch, ...@@ -142,15 +134,15 @@ def flashattn_bwd(batch,
@T.prim_func @T.prim_func
def flash_bwd( def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore dV: T.Tensor(v_shape, accum_dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype) K_shared = T.alloc_shared([block_M, dim_qk], dtype)
...@@ -171,45 +163,39 @@ def flashattn_bwd(batch, ...@@ -171,45 +163,39 @@ def flashattn_bwd(batch,
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)
T.annotate_layout({ T.annotate_layout(
K_shared: tilelang.layout.make_swizzled_layout(K_shared), {
dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}) dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
}
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) )
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv) T.clear(dv)
T.clear(dk) T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N) loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT) T.clear(qkT)
T.gemm( T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
0) T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT) T.clear(dsT)
T.gemm( T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
V_shared,
do,
dsT,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
wg_wait=-1)
T.wait_wgmma(1) T.wait_wgmma(1)
T.copy(qkT, qkT_cast) T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
...@@ -221,18 +207,17 @@ def flashattn_bwd(batch, ...@@ -221,18 +207,17 @@ def flashattn_bwd(batch,
T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
T.wait_wgmma(0) T.wait_wgmma(0)
T.copy(dq, dq_shared) T.copy(dq, dq_shared)
T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared) T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared)
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared) T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared)
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared) T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared)
return flash_bwd return flash_bwd
@torch.compile @torch.compile
class _attention(torch.autograd.Function): class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
BATCH, N_CTX, H, D_HEAD_QK = q.shape BATCH, N_CTX, H, D_HEAD_QK = q.shape
...@@ -250,7 +235,10 @@ class _attention(torch.autograd.Function): ...@@ -250,7 +235,10 @@ class _attention(torch.autograd.Function):
def backward(ctx, do): def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD_QK = q.shape BATCH, N_CTX, H, D_HEAD_QK = q.shape
HEAD_KV, D_HEAD_V, = v.shape[-2], v.shape[-1] (
HEAD_KV,
D_HEAD_V,
) = v.shape[-2], v.shape[-1]
groups = H // HEAD_KV groups = H // HEAD_KV
def maybe_contiguous(x): def maybe_contiguous(x):
...@@ -264,18 +252,7 @@ class _attention(torch.autograd.Function): ...@@ -264,18 +252,7 @@ class _attention(torch.autograd.Function):
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
delta = mod_prep(o, do) delta = mod_prep(o, do)
kernel = flashattn_bwd( kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups)
BATCH,
H,
N_CTX,
D_HEAD_QK,
D_HEAD_V,
ctx.causal,
block_M,
block_N,
threads=256,
num_stages=2,
groups=groups)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
...@@ -298,52 +275,36 @@ def ref_program(Q, K, V, is_causal, groups=1): ...@@ -298,52 +275,36 @@ def ref_program(Q, K, V, is_causal, groups=1):
# K: [B, T, HK, D_QK] # K: [B, T, HK, D_QK]
# V: [B, T, HV, D_V] # V: [B, T, HV, D_V]
# HQ = HKV * groups # HQ = HKV * groups
assert Q.size(2) == K.size( assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(
2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim_qk = Q.size(-1) dim_qk = Q.size(-1)
K = K.repeat_interleave(groups, dim=2) K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype)) scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
if is_causal: if is_causal:
seq_len = Q.size(1) seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0) mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf')) scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1) attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output return output
def main(BATCH: int = 1, def main(BATCH: int = 1, H: int = 32, N_CTX: int = 256, D_HEAD_QK: int = 192, D_HEAD_V: int = 128, groups: int = 16, causal: bool = False):
H: int = 32,
N_CTX: int = 256,
D_HEAD_QK: int = 192,
D_HEAD_V: int = 128,
groups: int = 16,
causal: bool = False):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v total_flops = 3 * flops_per_qk + 2 * flops_per_v
if causal: if causal:
total_flops *= 0.5 total_flops *= 0.5
Q = ( Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half,
device="cuda").normal_().requires_grad_())
head_kv = H // groups head_kv = H // groups
K = ( K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
device="cuda").normal_().requires_grad_()) dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
V = (
torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
dO = (
torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half,
device="cuda").normal_().requires_grad_())
O = attention(Q, K, V, causal, groups) O = attention(Q, K, V, causal, groups)
O.backward(dO, retain_graph=True) O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None dQ, Q.grad = Q.grad.clone(), None
...@@ -360,7 +321,7 @@ def main(BATCH: int = 1, ...@@ -360,7 +321,7 @@ def main(BATCH: int = 1,
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print('All checks passed.✅') print("All checks passed.✅")
def run(): def run():
O_ref.backward(dO, retain_graph=True) O_ref.backward(dO, retain_graph=True)
...@@ -380,13 +341,13 @@ def main(BATCH: int = 1, ...@@ -380,13 +341,13 @@ def main(BATCH: int = 1,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size') parser.add_argument("--batch", type=int, default=8, help="Batch size")
parser.add_argument('--h', type=int, default=32, help='Number of heads') parser.add_argument("--h", type=int, default=32, help="Number of heads")
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
parser.add_argument('--causal', action='store_true', help='Causal flag') parser.add_argument("--causal", action="store_true", help="Causal flag")
parser.add_argument('--groups', type=int, default=16, help='groups') parser.add_argument("--groups", type=int, default=16, help="groups")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal) main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal)
...@@ -9,7 +9,6 @@ from functools import partial ...@@ -9,7 +9,6 @@ from functools import partial
class FlashAttentionTuneSpace: class FlashAttentionTuneSpace:
def __init__( def __init__(
self, self,
block_sizes=(64, 128, 256), block_sizes=(64, 128, 256),
...@@ -40,7 +39,7 @@ def get_configs(user_config=None): ...@@ -40,7 +39,7 @@ def get_configs(user_config=None):
warp_M = block_M // warp_count warp_M = block_M // warp_count
warp_N = block_N // warp_count warp_N = block_N // warp_count
if (warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0): if warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0:
continue continue
shared_mem = 2 * config.dtype_bytes * config.dim * (block_M + block_N) shared_mem = 2 * config.dtype_bytes * config.dim * (block_M + block_N)
...@@ -48,31 +47,26 @@ def get_configs(user_config=None): ...@@ -48,31 +47,26 @@ def get_configs(user_config=None):
continue continue
for num_stages in config.num_stages_range: for num_stages in config.num_stages_range:
valid_configs.append({ valid_configs.append(
"block_M": block_M, {
"block_N": block_N, "block_M": block_M,
"num_stages": num_stages, "block_N": block_N,
"threads": threads, "num_stages": num_stages,
}) "threads": threads,
}
)
return valid_configs return valid_configs
@autotune(configs=get_configs(), warmup=10, rep=10) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit( @tilelang.jit(
out_idx=[3], pass_configs={ out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
def flashattn(batch, )
heads, def flashattn(batch, heads, seq_len, dim, is_causal, groups=1, block_M=64, block_N=64, num_stages=0, threads=128):
seq_len, scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
dim,
is_causal,
groups=1,
block_M=64,
block_N=64,
num_stages=0,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim] q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim] kv_shape = [batch, seq_len, head_kv, dim]
...@@ -90,15 +84,13 @@ def flashattn(batch, ...@@ -90,15 +84,13 @@ def flashattn(batch,
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
-T.infinity(acc_s.dtype))
else: else:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
...@@ -111,18 +103,18 @@ def flashattn(batch, ...@@ -111,18 +103,18 @@ def flashattn(batch,
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
def Softmax( def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype), scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype), scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype), logsum: T.FragmentBuffer([block_M], accum_dtype),
): ):
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -148,18 +140,18 @@ def flashattn(batch, ...@@ -148,18 +140,18 @@ def flashattn(batch,
@T.macro @T.macro
def Rescale( def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
): ):
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
@T.prim_func @T.prim_func
def main( def main(
Q: T.Tensor(q_shape, dtype), Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype), K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype), V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype), Output: T.Tensor(q_shape, dtype),
): ):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -175,25 +167,24 @@ def flashattn(batch, ...@@ -175,25 +167,24 @@ def flashattn(batch,
scores_sum = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = ( loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv( T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) )
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
return main return main
...@@ -203,50 +194,34 @@ def ref_program(Q, K, V, is_causal, groups=1): ...@@ -203,50 +194,34 @@ def ref_program(Q, K, V, is_causal, groups=1):
# K: [B, T, HK, D] # K: [B, T, HK, D]
# V: [B, T, HV, D] # V: [B, T, HV, D]
# HQ = HKV * groups # HQ = HKV * groups
assert Q.size(2) == K.size( assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(
2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim = Q.size(-1) dim = Q.size(-1)
K = K.repeat_interleave(groups, dim=2) K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal: if is_causal:
seq_len = Q.size(1) seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0) mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf')) scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1) attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output return output
def main(batch: int = 1, def main(
heads: int = 64, batch: int = 1, heads: int = 64, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 16, tune: bool = False
seq_len: int = 4096, ):
dim: int = 128,
is_causal: bool = False,
groups: int = 16,
tune: bool = False):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul total_flops = 2 * flops_per_matmul
if is_causal: if is_causal:
total_flops *= 0.5 total_flops *= 0.5
if (not tune): if not tune:
kernel = flashattn( kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=64, block_N=64, num_stages=2, threads=128)
batch,
heads,
seq_len,
dim,
is_causal,
groups=groups,
block_M=64,
block_N=64,
num_stages=2,
threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
...@@ -270,12 +245,12 @@ def main(batch: int = 1, ...@@ -270,12 +245,12 @@ def main(batch: int = 1,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size') parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument('--heads', type=int, default=64, help='heads') parser.add_argument("--heads", type=int, default=64, help="heads")
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument('--is_causal', action='store_true', help='causal') parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument('--tune', action='store_true', help='tune configs') parser.add_argument("--tune", action="store_true", help="tune configs")
parser.add_argument('--groups', type=int, default=16, help='groups') parser.add_argument("--groups", type=int, default=16, help="groups")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune) main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune)
...@@ -24,9 +24,11 @@ def get_configs(): ...@@ -24,9 +24,11 @@ def get_configs():
rep=10, rep=10,
) )
@tilelang.jit( @tilelang.jit(
out_idx=[3], pass_configs={ out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn( def flashattn(
batch, batch,
heads, heads,
...@@ -39,7 +41,7 @@ def flashattn( ...@@ -39,7 +41,7 @@ def flashattn(
num_stages=0, num_stages=0,
threads=128, threads=128,
): ):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim] q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim] kv_shape = [batch, seq_len, head_kv, dim]
...@@ -57,15 +59,13 @@ def flashattn( ...@@ -57,15 +59,13 @@ def flashattn(
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
-T.infinity(acc_s.dtype))
else: else:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
...@@ -78,18 +78,18 @@ def flashattn( ...@@ -78,18 +78,18 @@ def flashattn(
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
def Softmax( def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype), scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype), scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype), logsum: T.FragmentBuffer([block_M], accum_dtype),
): ):
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -115,18 +115,18 @@ def flashattn( ...@@ -115,18 +115,18 @@ def flashattn(
@T.macro @T.macro
def Rescale( def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
): ):
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
@T.prim_func @T.prim_func
def main( def main(
Q: T.Tensor(q_shape, dtype), Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype), K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype), V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype), Output: T.Tensor(q_shape, dtype),
): ):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -142,30 +142,30 @@ def flashattn( ...@@ -142,30 +142,30 @@ def flashattn(
scores_sum = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = ( loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv( T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) )
for k in T.Pipelined( for k in T.Pipelined(
loop_range, loop_range,
num_stages=num_stages, num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2], order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1], stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
return main return main
...@@ -175,23 +175,21 @@ def ref_program(Q, K, V, is_causal, groups=1): ...@@ -175,23 +175,21 @@ def ref_program(Q, K, V, is_causal, groups=1):
# K: [B, T, HK, D] # K: [B, T, HK, D]
# V: [B, T, HV, D] # V: [B, T, HV, D]
# HQ = HKV * groups # HQ = HKV * groups
assert Q.size(2) == K.size( assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(
2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim = Q.size(-1) dim = Q.size(-1)
K = K.repeat_interleave(groups, dim=2) K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal: if is_causal:
seq_len = Q.size(1) seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0) mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf')) scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1) attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output return output
...@@ -209,18 +207,8 @@ def main( ...@@ -209,18 +207,8 @@ def main(
if is_causal: if is_causal:
total_flops *= 0.5 total_flops *= 0.5
if (not tune): if not tune:
kernel = flashattn( kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256)
batch,
heads,
seq_len,
dim,
is_causal,
groups=groups,
block_M=128,
block_N=128,
num_stages=2,
threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
...@@ -244,12 +232,12 @@ def main( ...@@ -244,12 +232,12 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size') parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument('--heads', type=int, default=64, help='heads') parser.add_argument("--heads", type=int, default=64, help="heads")
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument('--is_causal', action='store_true', help='causal') parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument('--tune', action='store_true', help='tune configs') parser.add_argument("--tune", action="store_true", help="tune configs")
parser.add_argument('--groups', type=int, default=16, help='groups') parser.add_argument("--groups", type=int, default=16, help="groups")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune) main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups, args.tune)
...@@ -10,14 +10,14 @@ from varlen_utils import generate_random_padding_mask, generate_qkv ...@@ -10,14 +10,14 @@ from varlen_utils import generate_random_padding_mask, generate_qkv
def attention_ref( def attention_ref(
q, q,
k, k,
v, v,
query_padding_mask=None, query_padding_mask=None,
key_padding_mask=None, key_padding_mask=None,
causal=False, causal=False,
window_size=(-1, -1), window_size=(-1, -1),
upcast=True, upcast=True,
): ):
if causal: if causal:
window_size = (window_size[0], 0) window_size = (window_size[0], 0)
...@@ -26,7 +26,7 @@ def attention_ref( ...@@ -26,7 +26,7 @@ def attention_ref(
q, k, v = q.float(), k.float(), v.float() q, k, v = q.float(), k.float(), v.float()
b, T, Hq, D = q.shape b, T, Hq, D = q.shape
S = k.shape[1] S = k.shape[1]
scale = (1.0 / D)**0.5 scale = (1.0 / D) ** 0.5
k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2]) k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2]) v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2])
scores = torch.einsum("bthd,bshd->bhts", q, k) scores = torch.einsum("bthd,bshd->bhts", q, k)
...@@ -54,21 +54,13 @@ def attention_ref( ...@@ -54,21 +54,13 @@ def attention_ref(
@tilelang.jit( @tilelang.jit(
out_idx=[6], pass_configs={ out_idx=[6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
def flashattn(batch_size, )
groups, def flashattn(batch_size, groups, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128):
UQ, scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
UKV,
heads,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [UQ, heads, dim] q_shape = [UQ, heads, dim]
kv_shape = [UKV, head_kv, dim] kv_shape = [UKV, head_kv, dim]
...@@ -78,17 +70,15 @@ def flashattn(batch_size, ...@@ -78,17 +70,15 @@ def flashattn(batch_size,
@T.prim_func @T.prim_func
def main( def main(
Q_unpad: T.Tensor(q_shape, dtype), Q_unpad: T.Tensor(q_shape, dtype),
K_unpad: T.Tensor(kv_shape, dtype), K_unpad: T.Tensor(kv_shape, dtype),
V_unpad: T.Tensor(kv_shape, dtype), V_unpad: T.Tensor(kv_shape, dtype),
cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), cu_seqlens_q: T.Tensor([batch_size + 1], "int32"),
cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), cu_seqlens_k: T.Tensor([batch_size + 1], "int32"),
max_seqlen_q: T.int32, max_seqlen_q: T.int32,
Output_unpad: T.Tensor(o_shape, dtype), Output_unpad: T.Tensor(o_shape, dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz):
T.ceildiv(max_seqlen_q, block_M), heads, batch_size,
threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype)
...@@ -102,10 +92,12 @@ def flashattn(batch_size, ...@@ -102,10 +92,12 @@ def flashattn(batch_size,
scores_sum = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({ T.annotate_layout(
O_shared: tilelang.layout.make_swizzled_layout(O_shared), {
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), O_shared: tilelang.layout.make_swizzled_layout(O_shared),
}) Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
}
)
batch_idx = bz batch_idx = bz
head_idx = by head_idx = by
...@@ -119,36 +111,34 @@ def flashattn(batch_size, ...@@ -119,36 +111,34 @@ def flashattn(batch_size,
q_current_seqlen = q_end_idx - q_start_idx q_current_seqlen = q_end_idx - q_start_idx
kv_current_seqlen = k_end_idx - kv_start_idx kv_current_seqlen = k_end_idx - kv_start_idx
T.copy( T.copy(Q_unpad[q_start_idx + bx * block_M : q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared)
Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :],
Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = ( loop_range = (
T.min( T.min(T.ceildiv(q_current_seqlen + (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
T.ceildiv(q_current_seqlen + if is_causal
(bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N)) else T.ceildiv(kv_current_seqlen, block_N)
if is_causal else T.ceildiv(kv_current_seqlen, block_N)) )
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy( T.copy(K_unpad[kv_start_idx + k * block_N : kv_start_idx + (k + 1) * block_N, kv_head_idx, :], K_shared)
K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
kv_head_idx, :], K_shared)
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, acc_s[i, j] = T.if_then_else(
j] = T.if_then_else((bx * block_M + i < k * block_N + j) or (bx * block_M + i < k * block_N + j)
(bx * block_M + i >= q_current_seqlen or or (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen),
k * block_N + j >= kv_current_seqlen), -1e9, 0) -1e9,
0,
)
else: else:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or acc_s[i, j] = T.if_then_else(
k * block_N + j >= kv_current_seqlen), -1e9, (bx * block_M + i >= q_current_seqlen or k * block_N + j >= kv_current_seqlen), -1e9, 0
0) )
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...@@ -170,9 +160,7 @@ def flashattn(batch_size, ...@@ -170,9 +160,7 @@ def flashattn(batch_size,
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
T.copy( T.copy(V_unpad[kv_start_idx + k * block_N : kv_start_idx + (k + 1) * block_N, kv_head_idx, :], V_shared)
V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
kv_head_idx, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
...@@ -187,13 +175,9 @@ def flashattn(batch_size, ...@@ -187,13 +175,9 @@ def flashattn(batch_size,
return main return main
def main(batch: int = 1, def main(
heads: int = 64, batch: int = 1, heads: int = 64, q_seqlen: int = 2048, k_seqlen: int = 2048, dim: int = 128, groups: int = 16, is_causal: bool = False
q_seqlen: int = 2048, ):
k_seqlen: int = 2048,
dim: int = 128,
groups: int = 16,
is_causal: bool = False):
assert heads % groups == 0, "heads must be divisible by groups" assert heads % groups == 0, "heads must be divisible by groups"
flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim
...@@ -231,24 +215,12 @@ def main(batch: int = 1, ...@@ -231,24 +215,12 @@ def main(batch: int = 1,
output_pad_fn, output_pad_fn,
_, _,
_, _,
) = generate_qkv( ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
UQ = q_unpad.shape[0] UQ = q_unpad.shape[0]
UKV = k_unpad.shape[0] UKV = k_unpad.shape[0]
kernel = flashattn( kernel = flashattn(batch, groups, UQ, UKV, heads, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256)
batch,
groups,
UQ,
UKV,
heads,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256)
out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
out = output_pad_fn(out_unpad) out = output_pad_fn(out_unpad)
...@@ -263,23 +235,19 @@ def main(batch: int = 1, ...@@ -263,23 +235,19 @@ def main(batch: int = 1,
) )
torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
print("All checks passed.✅") print("All checks passed.✅")
latency = do_bench( latency = do_bench(lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q), _n_warmup=5, _n_repeat=5)
lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q),
_n_warmup=5,
_n_repeat=5)
print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size') parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument('--heads', type=int, default=64, help='query heads') parser.add_argument("--heads", type=int, default=64, help="query heads")
parser.add_argument('--groups', type=int, default=16, help='groups') parser.add_argument("--groups", type=int, default=16, help="groups")
parser.add_argument('--q_seqlen', type=int, default=2048, help='query sequence length') parser.add_argument("--q_seqlen", type=int, default=2048, help="query sequence length")
parser.add_argument('--k_seqlen', type=int, default=2048, help='key/value sequence length') parser.add_argument("--k_seqlen", type=int, default=2048, help="key/value sequence length")
parser.add_argument('--dim', type=int, default=128, help='head dim') parser.add_argument("--dim", type=int, default=128, help="head dim")
parser.add_argument('--is_causal', action='store_true', help='causal attention') parser.add_argument("--is_causal", action="store_true", help="causal attention")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, args.is_causal)
args.is_causal)
...@@ -7,22 +7,24 @@ import argparse ...@@ -7,22 +7,24 @@ import argparse
@tilelang.jit( @tilelang.jit(
out_idx=[3, 4], pass_configs={ out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
@T.prim_func @T.prim_func
def flash_fwd( def flash_fwd(
Q: T.Tensor(shape, dtype), # type: ignore Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore V: T.Tensor(shape, dtype), # type: ignore
Output: T.Tensor(shape, dtype), # type: ignore Output: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
): ):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -39,28 +41,24 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -39,28 +41,24 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
# T.copy(Q_shared, Q_local) # T.copy(Q_shared, Q_local)
# for i, j in T.Parallel(block_M, dim): # for i, j in T.Parallel(block_M, dim):
# Q_local[i, j] *= scale # Q_local[i, j] *= scale
loop_range = ( loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=1): for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
-T.infinity(acc_s.dtype))
else: else:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
...@@ -78,18 +76,20 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -78,18 +76,20 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd return flash_fwd
@tilelang.jit( @tilelang.jit(
out_idx=[2], pass_configs={ out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim): def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
...@@ -98,9 +98,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): ...@@ -98,9 +98,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
@T.prim_func @T.prim_func
def flash_bwd_prep( def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype) o = T.alloc_fragment([blk, blk], dtype)
...@@ -109,26 +109,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): ...@@ -109,26 +109,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
delta = T.alloc_fragment([blk], accum_dtype) delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc) T.clear(acc)
for k in range(T.ceildiv(dim, blk)): for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o) T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do) T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk): for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j] acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1) T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep return flash_bwd_prep
def make_dq_layout(dQ):
    """Return the reordered layout applied to the dQ accumulation buffer.

    The layout maps a logical index ``(b, h, l, d)`` to a 6-D index in which
    ``l`` and ``d`` are each split into blocks of 8 (``l // 8``, ``d // 8``)
    and the position inside a block is encoded by ``d % 2`` and
    ``4 * (l % 8) + (d % 8) // 2``.

    Args:
        dQ: buffer whose ``shape`` attribute is passed to ``T.Layout``; the
            kernels in this file declare it as [batch, heads, seq_len, dim].

    Returns:
        The ``T.Layout`` object built from ``dQ.shape`` and the index mapping.
    """

    # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
    def _reorder(b, h, l, d):
        return [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]

    return T.Layout(dQ.shape, _reorder)
@tilelang.jit( @tilelang.jit(
out_idx=[1], pass_configs={ out_idx=[1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim): def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
...@@ -137,40 +138,42 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim): ...@@ -137,40 +138,42 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
@T.prim_func @T.prim_func
def flash_bwd_post( def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore dQ_out: T.Tensor(shape, dtype), # type: ignore
): ):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)}) T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy( T.copy(
dQ[bz, by, bx * blk:(bx + 1) * blk, :], dQ[bz, by, bx * blk : (bx + 1) * blk, :],
dQ_out[bz, by, bx * blk:(bx + 1) * blk, :], dQ_out[bz, by, bx * blk : (bx + 1) * blk, :],
) )
return flash_bwd_post return flash_bwd_post
@tilelang.jit(pass_configs={ @tilelang.jit(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, pass_configs={
}) tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
sm_scale = (1.0 / dim)**0.5 sm_scale = (1.0 / dim) ** 0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
@T.prim_func @T.prim_func
def flash_bwd( def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, dtype), # type: ignore dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore dV: T.Tensor(shape, dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -194,38 +197,39 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -194,38 +197,39 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
dv_shared = T.alloc_shared([block_M, dim], dtype) dv_shared = T.alloc_shared([block_M, dim], dtype)
dk_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype)
T.annotate_layout({ T.annotate_layout(
dQ: make_dq_layout(dQ), {
K_shared: tilelang.layout.make_swizzled_layout(K_shared), dQ: make_dq_layout(dQ),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
}) dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
T.copy(K[bz, bx, by * block_M:(by + 1) * block_M, :], K_shared) }
T.copy(V[bz, bx, by * block_M:(by + 1) * block_M, :], V_shared) )
T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared)
T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared)
T.clear(dv) T.clear(dv)
T.clear(dk) T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N) loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=2): for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q)
T.clear(qkT) T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
0)
# We don't need to handle OOB positions for non-causal cases, # We don't need to handle OOB positions for non-causal cases,
# since OOB values won't affect other positions here. # since OOB values won't affect other positions here.
T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], do) T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], do)
T.clear(dsT) T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast) T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
...@@ -238,14 +242,13 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -238,14 +242,13 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j]) T.atomic_add(dQ[bz, bx, k * block_N + i, j], dq[i, j])
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :]) T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * block_M, :])
T.copy(dk_shared, dK[bz, bx, by * block_M:(by + 1) * block_M, :]) T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :])
return flash_bwd return flash_bwd
class _attention(torch.autograd.Function): class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, k, v, causal): def forward(ctx, q, k, v, causal):
BATCH, H, N_CTX, D_HEAD = q.shape BATCH, H, N_CTX, D_HEAD = q.shape
...@@ -287,15 +290,15 @@ attention = _attention.apply ...@@ -287,15 +290,15 @@ attention = _attention.apply
def ref_program(Q, K, V, is_causal):
    """Reference (non-fused) scaled dot-product attention in plain PyTorch.

    Args:
        Q, K, V: tensors indexed as (batch, head, sequence, head_dim), as fixed
            by the einsum subscripts ``bhqd`` / ``bhkd`` used below.
        is_causal: when truthy, query position ``q`` only attends to key
            positions ``k <= q``.

    Returns:
        The attention output, indexed ``bhqd`` like ``Q``.
    """
    dim = Q.size(-1)
    scores = torch.einsum("bhqd,bhkd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(2)
        positions = torch.arange(seq_len, device=scores.device)
        # future[q, k] is True where the key index lies strictly after the
        # query index; those scores are removed before the softmax.
        future = positions.unsqueeze(0) > positions.unsqueeze(1)
        scores = scores.masked_fill(future, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bhkd->bhqd", attention_weights, V)
...@@ -310,9 +313,7 @@ def main( ...@@ -310,9 +313,7 @@ def main(
total_flops = 5 * flops_per_matmul total_flops = 5 * flops_per_matmul
if causal: if causal:
total_flops *= 0.5 total_flops *= 0.5
Q = ( Q = torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_()
torch.empty(BATCH, H, N_CTX, D_HEAD, dtype=torch.half,
device="cuda").normal_().requires_grad_())
K = torch.empty_like(Q).normal_().requires_grad_() K = torch.empty_like(Q).normal_().requires_grad_()
V = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_()
dO = torch.randn_like(Q) dO = torch.randn_like(Q)
...@@ -353,10 +354,10 @@ def main( ...@@ -353,10 +354,10 @@ def main(
if __name__ == "__main__":

    def _str_to_bool(text):
        """Parse a command-line token into a bool.

        ``type=bool`` calls ``bool(text)``, which is True for every non-empty
        string, so ``--causal False`` used to switch causal mode on.
        """
        token = text.strip().lower()
        if token in ("1", "true", "t", "yes", "y", "on"):
            return True
        if token in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8, help="Batch size")
    parser.add_argument("--h", type=int, default=32, help="Number of heads")
    parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
    parser.add_argument("--d_head", type=int, default=64, help="Head dimension")
    # nargs="?" with const=True lets a bare ``--causal`` enable the flag, while
    # ``--causal True`` / ``--causal False`` keep working for existing callers.
    parser.add_argument("--causal", type=_str_to_bool, nargs="?", const=True, default=False, help="Causal flag")
    args = parser.parse_args()
    main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
...@@ -7,22 +7,24 @@ import argparse ...@@ -7,22 +7,24 @@ import argparse
@tilelang.jit( @tilelang.jit(
out_idx=[3, 4], pass_configs={ out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim] shape = [batch, seq_len, heads, dim]
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
@T.prim_func @T.prim_func
def flash_fwd( def flash_fwd(
Q: T.Tensor(shape, dtype), # type: ignore Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore V: T.Tensor(shape, dtype), # type: ignore
Output: T.Tensor(shape, dtype), # type: ignore Output: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
): ):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -38,25 +40,21 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -38,25 +40,21 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
scores_sum = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = ( loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=1): for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
-T.infinity(acc_s.dtype))
else: else:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
...@@ -74,18 +72,20 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -74,18 +72,20 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd return flash_fwd
@tilelang.jit( @tilelang.jit(
out_idx=[2], pass_configs={ out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim): def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
...@@ -94,9 +94,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): ...@@ -94,9 +94,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
@T.prim_func @T.prim_func
def flash_bwd_prep( def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype) o = T.alloc_fragment([blk, blk], dtype)
...@@ -105,26 +105,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): ...@@ -105,26 +105,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
delta = T.alloc_fragment([blk], accum_dtype) delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc) T.clear(acc)
for k in range(T.ceildiv(dim, blk)): for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk): for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j] acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1) T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep return flash_bwd_prep
def make_dq_layout(dQ):
    """Return the reordered layout applied to the dQ accumulation buffer.

    The layout maps a logical index ``(b, l, h, d)`` to a 6-D index in which
    ``l`` and ``d`` are each split into blocks of 8 (``l // 8``, ``d // 8``)
    and the position inside a block is encoded by ``d % 2`` and
    ``4 * (l % 8) + (d % 8) // 2``.

    Args:
        dQ: buffer whose ``shape`` attribute is passed to ``T.Layout``; the
            kernels in this file declare it as [batch, seq_len, heads, dim].

    Returns:
        The ``T.Layout`` object built from ``dQ.shape`` and the index mapping.
    """

    # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
    def _reorder(b, l, h, d):
        return [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]

    return T.Layout(dQ.shape, _reorder)
@tilelang.jit( @tilelang.jit(
out_idx=[1], pass_configs={ out_idx=[1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim): def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
...@@ -133,40 +134,42 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim): ...@@ -133,40 +134,42 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
@T.prim_func @T.prim_func
def flash_bwd_post( def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore dQ_out: T.Tensor(shape, dtype), # type: ignore
): ):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)}) T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy( T.copy(
dQ[bz, bx * blk:(bx + 1) * blk, by, :], dQ[bz, bx * blk : (bx + 1) * blk, by, :],
dQ_out[bz, bx * blk:(bx + 1) * blk, by, :], dQ_out[bz, bx * blk : (bx + 1) * blk, by, :],
) )
return flash_bwd_post return flash_bwd_post
@tilelang.jit(pass_configs={ @tilelang.jit(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, pass_configs={
}) tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
sm_scale = (1.0 / dim)**0.5 sm_scale = (1.0 / dim) ** 0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim] shape = [batch, seq_len, heads, dim]
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
@T.prim_func @T.prim_func
def flash_bwd( def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, dtype), # type: ignore dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore dV: T.Tensor(shape, dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -190,35 +193,36 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -190,35 +193,36 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
dv_shared = T.alloc_shared([block_M, dim], dtype) dv_shared = T.alloc_shared([block_M, dim], dtype)
dk_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype)
T.annotate_layout({ T.annotate_layout(
dQ: make_dq_layout(dQ), {
}) dQ: make_dq_layout(dQ),
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared) }
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared) )
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared)
T.clear(dv) T.clear(dv)
T.clear(dk) T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N) loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=2): for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT) T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
0)
# We don't need to handle OOB positions for non-causal cases, # We don't need to handle OOB positions for non-causal cases,
# since OOB values won't affect other positions here. # since OOB values won't affect other positions here.
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.clear(dsT) T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast) T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
...@@ -231,14 +235,13 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -231,14 +235,13 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :]) T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :])
T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :]) T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :])
return flash_bwd return flash_bwd
class _attention(torch.autograd.Function): class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, k, v, causal): def forward(ctx, q, k, v, causal):
BATCH, N_CTX, H, D_HEAD = q.shape BATCH, N_CTX, H, D_HEAD = q.shape
...@@ -280,15 +283,15 @@ attention = _attention.apply ...@@ -280,15 +283,15 @@ attention = _attention.apply
def ref_program(Q, K, V, is_causal):
    """Reference (non-fused) scaled dot-product attention in plain PyTorch.

    Args:
        Q, K, V: tensors indexed as (batch, sequence, head, head_dim), as fixed
            by the einsum subscripts ``bqhd`` / ``bkhd`` used below.
        is_causal: when truthy, query position ``q`` only attends to key
            positions ``k <= q``.

    Returns:
        The attention output, indexed ``bqhd`` like ``Q``.
    """
    dim = Q.size(-1)
    scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        positions = torch.arange(seq_len, device=scores.device)
        # future[q, k] is True where the key index lies strictly after the
        # query index; those scores are removed before the softmax.
        future = positions.unsqueeze(0) > positions.unsqueeze(1)
        scores = scores.masked_fill(future, float("-inf"))
    attention_weights = F.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
...@@ -303,9 +306,7 @@ def main( ...@@ -303,9 +306,7 @@ def main(
total_flops = 5 * flops_per_matmul total_flops = 5 * flops_per_matmul
if causal: if causal:
total_flops *= 0.5 total_flops *= 0.5
Q = ( Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_()
torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half,
device="cuda").normal_().requires_grad_())
K = torch.empty_like(Q).normal_().requires_grad_() K = torch.empty_like(Q).normal_().requires_grad_()
V = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_()
dO = torch.randn_like(Q) dO = torch.randn_like(Q)
...@@ -344,10 +345,10 @@ def main( ...@@ -344,10 +345,10 @@ def main(
if __name__ == "__main__":

    def _str_to_bool(text):
        """Parse a command-line token into a bool.

        ``type=bool`` calls ``bool(text)``, which is True for every non-empty
        string, so ``--causal False`` used to switch causal mode on.
        """
        token = text.strip().lower()
        if token in ("1", "true", "t", "yes", "y", "on"):
            return True
        if token in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8, help="Batch size")
    parser.add_argument("--h", type=int, default=32, help="Number of heads")
    parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
    parser.add_argument("--d_head", type=int, default=64, help="Head dimension")
    # nargs="?" with const=True lets a bare ``--causal`` enable the flag, while
    # ``--causal True`` / ``--causal False`` keep working for existing callers.
    parser.add_argument("--causal", type=_str_to_bool, nargs="?", const=True, default=False, help="Causal flag")
    args = parser.parse_args()
    main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
...@@ -7,22 +7,24 @@ import argparse ...@@ -7,22 +7,24 @@ import argparse
@tilelang.jit( @tilelang.jit(
out_idx=[3, 4], pass_configs={ out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim] shape = [batch, seq_len, heads, dim]
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
@T.prim_func @T.prim_func
def flash_fwd( def flash_fwd(
Q: T.Tensor(shape, dtype), # type: ignore Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore V: T.Tensor(shape, dtype), # type: ignore
Output: T.Tensor(shape, dtype), # type: ignore Output: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
): ):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -38,26 +40,22 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -38,26 +40,22 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = ( loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=1): for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
-T.infinity(acc_s.dtype))
else: else:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
...@@ -75,18 +73,20 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -75,18 +73,20 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd return flash_fwd
@tilelang.jit( @tilelang.jit(
out_idx=[2], pass_configs={ out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim): def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
...@@ -95,9 +95,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): ...@@ -95,9 +95,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
@T.prim_func @T.prim_func
def flash_bwd_prep( def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype) o = T.alloc_fragment([blk, blk], dtype)
...@@ -106,37 +106,39 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): ...@@ -106,37 +106,39 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
delta = T.alloc_fragment([blk], accum_dtype) delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc) T.clear(acc)
for k in range(T.ceildiv(dim, blk)): for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk): for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j] acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1) T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep return flash_bwd_prep
@tilelang.jit(pass_configs={ @tilelang.jit(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, pass_configs={
}) tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
sm_scale = (1.0 / dim)**0.5 sm_scale = (1.0 / dim) ** 0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim] shape = [batch, seq_len, heads, dim]
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
@T.prim_func @T.prim_func
def flash_bwd( def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, dtype), # type: ignore dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore dV: T.Tensor(shape, dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -161,49 +163,43 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -161,49 +163,43 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
dk_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype)
dq_shared = T.alloc_shared([block_N, dim], accum_dtype) dq_shared = T.alloc_shared([block_N, dim], accum_dtype)
T.annotate_layout({ T.annotate_layout(
K_shared: tilelang.layout.make_swizzled_layout(K_shared), {
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}) dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
}
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared) )
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared)
T.clear(dv) T.clear(dv)
T.clear(dk) T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N) loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=2): for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT) T.clear(qkT)
T.gemm( T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT) T.clear(dsT)
T.gemm( T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
V_shared,
do,
dsT,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
wg_wait=-1)
T.wait_wgmma(1) T.wait_wgmma(1)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
0)
# We don't need to handle OOB positions for non-causal cases, # We don't need to handle OOB positions for non-causal cases,
# since OOB values won't affect other positions here. # since OOB values won't affect other positions here.
T.wait_wgmma(0) T.wait_wgmma(0)
T.copy(qkT, qkT_cast) T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
...@@ -214,17 +210,16 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -214,17 +210,16 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1)
T.wait_wgmma(0) T.wait_wgmma(0)
T.copy(dq, dq_shared) T.copy(dq, dq_shared)
T.atomic_add(dQ[bz, k * block_N:(k + 1) * block_N, bx, :], dq_shared) T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared)
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :]) T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :])
T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :]) T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :])
return flash_bwd return flash_bwd
class _attention(torch.autograd.Function): class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, k, v, causal): def forward(ctx, q, k, v, causal):
BATCH, N_CTX, H, D_HEAD = q.shape BATCH, N_CTX, H, D_HEAD = q.shape
...@@ -266,15 +261,15 @@ attention = _attention.apply ...@@ -266,15 +261,15 @@ attention = _attention.apply
def ref_program(Q, K, V, is_causal): def ref_program(Q, K, V, is_causal):
dim = Q.size(-1) dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal: if is_causal:
seq_len = Q.size(1) seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0) mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf')) scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1) attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output return output
...@@ -289,9 +284,7 @@ def main( ...@@ -289,9 +284,7 @@ def main(
total_flops = 5 * flops_per_matmul total_flops = 5 * flops_per_matmul
if causal: if causal:
total_flops *= 0.5 total_flops *= 0.5
Q = ( Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_()
torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half,
device="cuda").normal_().requires_grad_())
K = torch.empty_like(Q).normal_().requires_grad_() K = torch.empty_like(Q).normal_().requires_grad_()
V = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_()
dO = torch.randn_like(Q) dO = torch.randn_like(Q)
...@@ -311,7 +304,7 @@ def main( ...@@ -311,7 +304,7 @@ def main(
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print('All checks passed.✅') print("All checks passed.✅")
def run(): def run():
O_ref.backward(dO, retain_graph=True) O_ref.backward(dO, retain_graph=True)
...@@ -329,10 +322,10 @@ def main( ...@@ -329,10 +322,10 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size') parser.add_argument("--batch", type=int, default=8, help="Batch size")
parser.add_argument('--h', type=int, default=32, help='Number of heads') parser.add_argument("--h", type=int, default=32, help="Number of heads")
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
parser.add_argument('--d_head', type=int, default=64, help='Head dimension') parser.add_argument("--d_head", type=int, default=64, help="Head dimension")
parser.add_argument('--causal', type=bool, default=False, help='Causal flag') parser.add_argument("--causal", type=bool, default=False, help="Causal flag")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)
...@@ -15,20 +15,13 @@ def get_configs(): ...@@ -15,20 +15,13 @@ def get_configs():
@autotune(configs=get_configs(), warmup=10, rep=10) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit( @tilelang.jit(
out_idx=[3], pass_configs={ out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
def flashattn(batch, )
heads, def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128):
seq_q, scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim] q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim] kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16" dtype = "float16"
...@@ -48,7 +41,7 @@ def flashattn(batch, ...@@ -48,7 +41,7 @@ def flashattn(batch,
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len q_idx = bx * block_M + i + past_len
...@@ -70,18 +63,18 @@ def flashattn(batch, ...@@ -70,18 +63,18 @@ def flashattn(batch,
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
def Softmax( def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype), scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype), scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype), logsum: T.FragmentBuffer([block_M], accum_dtype),
): ):
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -110,18 +103,18 @@ def flashattn(batch, ...@@ -110,18 +103,18 @@ def flashattn(batch,
@T.macro @T.macro
def Rescale( def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
): ):
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
@T.prim_func @T.prim_func
def main( def main(
Q: T.Tensor(q_shape, dtype), Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype), K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype), V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype), Output: T.Tensor(q_shape, dtype),
): ):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -137,43 +130,42 @@ def flashattn(batch, ...@@ -137,43 +130,42 @@ def flashattn(batch,
scores_sum = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = ( loop_range = (
T.min( T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
T.ceildiv(seq_kv, block_N), T.ceildiv( if is_causal
(bx + 1) * block_M + else T.ceildiv(seq_kv, block_N)
past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) )
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main return main
def ref_program(Q, K, V, is_causal): def ref_program(Q, K, V, is_causal):
dim = Q.size(-1) dim = Q.size(-1)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) scores = torch.einsum("bhqd,bhkd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal: if is_causal:
seq_q = Q.size(2) seq_q = Q.size(2)
seq_kv = K.size(2) seq_kv = K.size(2)
mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
mask = mask.unsqueeze(0).unsqueeze(0) mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf')) scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1) attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V)
return output return output
...@@ -191,18 +183,8 @@ def main( ...@@ -191,18 +183,8 @@ def main(
if is_causal: if is_causal:
total_flops *= 0.5 total_flops *= 0.5
if (not tune): if not tune:
kernel = flashattn( kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128)
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal) ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler() profiler = kernel.get_profiler()
...@@ -227,12 +209,12 @@ def main( ...@@ -227,12 +209,12 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size') parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument('--heads', type=int, default=1, help='heads') parser.add_argument("--heads", type=int, default=1, help="heads")
parser.add_argument('--seq_q', type=int, default=256, help='query sequence length') parser.add_argument("--seq_q", type=int, default=256, help="query sequence length")
parser.add_argument('--seq_kv', type=int, default=256, help='key/value sequence length') parser.add_argument("--seq_kv", type=int, default=256, help="key/value sequence length")
parser.add_argument('--dim', type=int, default=64, help='dim') parser.add_argument("--dim", type=int, default=64, help="dim")
parser.add_argument('--is_causal', action='store_true', help='causal', default=False) parser.add_argument("--is_causal", action="store_true", help="causal", default=False)
parser.add_argument('--tune', action='store_true', help='tune configs') parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
...@@ -15,20 +15,13 @@ def get_configs(): ...@@ -15,20 +15,13 @@ def get_configs():
@autotune(configs=get_configs(), warmup=10, rep=10) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit( @tilelang.jit(
out_idx=[3], pass_configs={ out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
def flashattn(batch, )
heads, def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256):
seq_q, scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
seq_kv,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim] q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim] kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16" dtype = "float16"
...@@ -48,7 +41,7 @@ def flashattn(batch, ...@@ -48,7 +41,7 @@ def flashattn(batch,
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len q_idx = bx * block_M + i + past_len
...@@ -70,18 +63,18 @@ def flashattn(batch, ...@@ -70,18 +63,18 @@ def flashattn(batch,
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
def Softmax( def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype), scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype), scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype), logsum: T.FragmentBuffer([block_M], accum_dtype),
): ):
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -108,18 +101,18 @@ def flashattn(batch, ...@@ -108,18 +101,18 @@ def flashattn(batch,
@T.macro @T.macro
def Rescale( def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
): ):
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
@T.prim_func @T.prim_func
def main( def main(
Q: T.Tensor(q_shape, dtype), Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype), K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype), V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype), Output: T.Tensor(q_shape, dtype),
): ):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -135,48 +128,48 @@ def flashattn(batch, ...@@ -135,48 +128,48 @@ def flashattn(batch,
scores_sum = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = ( loop_range = (
T.min( T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
T.ceildiv(seq_kv, block_N), T.ceildiv( if is_causal
(bx + 1) * block_M + else T.ceildiv(seq_kv, block_N)
past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) )
for k in T.Pipelined( for k in T.Pipelined(
loop_range, loop_range,
num_stages=num_stages, num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2], order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1], stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main return main
def ref_program(Q, K, V, is_causal): def ref_program(Q, K, V, is_causal):
dim = Q.size(-1) dim = Q.size(-1)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) scores = torch.einsum("bhqd,bhkd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal: if is_causal:
seq_q = Q.size(2) seq_q = Q.size(2)
seq_kv = K.size(2) seq_kv = K.size(2)
mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
mask = mask.unsqueeze(0).unsqueeze(0) mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf')) scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1) attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) output = torch.einsum("bhqk,bhkd->bhqd", attention_weights, V)
return output return output
...@@ -194,18 +187,8 @@ def main( ...@@ -194,18 +187,8 @@ def main(
if is_causal: if is_causal:
total_flops *= 0.5 total_flops *= 0.5
if (not tune): if not tune:
kernel = flashattn( kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256)
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal) ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler() profiler = kernel.get_profiler()
...@@ -230,12 +213,12 @@ def main( ...@@ -230,12 +213,12 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size') parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument('--heads', type=int, default=32, help='heads') parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument('--seq_q', type=int, default=4096, help='query sequence length') parser.add_argument("--seq_q", type=int, default=4096, help="query sequence length")
parser.add_argument('--seq_kv', type=int, default=4096, help='key/value sequence length') parser.add_argument("--seq_kv", type=int, default=4096, help="key/value sequence length")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument('--is_causal', action='store_true', help='causal') parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument('--tune', action='store_true', help='tune configs') parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
...@@ -15,19 +15,13 @@ def get_configs(): ...@@ -15,19 +15,13 @@ def get_configs():
@autotune(configs=get_configs(), warmup=10, rep=10) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit( @tilelang.jit(
out_idx=[3], pass_configs={ out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
def flashattn(batch, )
heads, def flashattn(batch, heads, seq_len, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128):
seq_len, scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim] shape = [batch, seq_len, heads, dim]
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
...@@ -43,16 +37,14 @@ def flashattn(batch, ...@@ -43,16 +37,14 @@ def flashattn(batch,
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
-T.infinity(acc_s.dtype))
else: else:
# We shall fill -inf for OOB positions # We shall fill -inf for OOB positions
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
...@@ -65,18 +57,18 @@ def flashattn(batch, ...@@ -65,18 +57,18 @@ def flashattn(batch,
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
def Softmax( def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype), scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype), scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype), logsum: T.FragmentBuffer([block_M], accum_dtype),
): ):
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -102,18 +94,18 @@ def flashattn(batch, ...@@ -102,18 +94,18 @@ def flashattn(batch,
@T.macro @T.macro
def Rescale( def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
): ):
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
@T.prim_func @T.prim_func
def main( def main(
Q: T.Tensor(shape, dtype), Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype), K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype), V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype), Output: T.Tensor(shape, dtype),
): ):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -129,40 +121,39 @@ def flashattn(batch, ...@@ -129,40 +121,39 @@ def flashattn(batch,
scores_sum = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = ( loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv( T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) )
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
return main return main
def ref_program(Q, K, V, is_causal): def ref_program(Q, K, V, is_causal):
dim = Q.size(-1) dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal: if is_causal:
seq_len = Q.size(1) seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0) mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf')) scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1) attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output return output
...@@ -179,17 +170,8 @@ def main( ...@@ -179,17 +170,8 @@ def main(
if is_causal: if is_causal:
total_flops *= 0.5 total_flops *= 0.5
if (not tune): if not tune:
kernel = flashattn( kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=1, threads=128)
batch,
heads,
seq_len,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=1,
threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal) ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler() profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
...@@ -213,11 +195,11 @@ def main( ...@@ -213,11 +195,11 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size') parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument('--heads', type=int, default=32, help='heads') parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument('--is_causal', action='store_true', help='causal') parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument('--tune', action='store_true', help='tune configs') parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune) main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune)
...@@ -15,19 +15,13 @@ def get_configs(): ...@@ -15,19 +15,13 @@ def get_configs():
@autotune(configs=get_configs(), warmup=10, rep=10) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit( @tilelang.jit(
out_idx=[3], pass_configs={ out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
def flashattn(batch, )
heads, def flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256):
seq_len, scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim] shape = [batch, seq_len, heads, dim]
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
...@@ -43,16 +37,14 @@ def flashattn(batch, ...@@ -43,16 +37,14 @@ def flashattn(batch,
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
-T.infinity(acc_s.dtype))
else: else:
# We shall fill -inf for OOB positions # We shall fill -inf for OOB positions
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
...@@ -65,18 +57,18 @@ def flashattn(batch, ...@@ -65,18 +57,18 @@ def flashattn(batch,
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
def Softmax( def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype), scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype), scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype), logsum: T.FragmentBuffer([block_M], accum_dtype),
): ):
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -102,18 +94,18 @@ def flashattn(batch, ...@@ -102,18 +94,18 @@ def flashattn(batch,
@T.macro @T.macro
def Rescale( def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
): ):
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
@T.prim_func @T.prim_func
def main( def main(
Q: T.Tensor(shape, dtype), Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype), K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype), V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype), Output: T.Tensor(shape, dtype),
): ):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -129,45 +121,45 @@ def flashattn(batch, ...@@ -129,45 +121,45 @@ def flashattn(batch,
scores_sum = T.alloc_fragment([block_M], accum_dtype) scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = ( loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv( T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) )
for k in T.Pipelined( for k in T.Pipelined(
loop_range, loop_range,
num_stages=num_stages, num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2], order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1], stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
return main return main
def ref_program(Q, K, V, is_causal): def ref_program(Q, K, V, is_causal):
dim = Q.size(-1) dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal: if is_causal:
seq_len = Q.size(1) seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0) mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf')) scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1) attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output return output
...@@ -184,17 +176,8 @@ def main( ...@@ -184,17 +176,8 @@ def main(
if is_causal: if is_causal:
total_flops *= 0.5 total_flops *= 0.5
if (not tune): if not tune:
kernel = flashattn( kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256)
batch,
heads,
seq_len,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal) ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
...@@ -218,11 +201,11 @@ def main( ...@@ -218,11 +201,11 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size') parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument('--heads', type=int, default=32, help='heads') parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument('--is_causal', action='store_true', help='causal') parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument('--tune', action='store_true', help='tune configs') parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune) main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.tune)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment