Unverified Commit 29051439 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
...@@ -11,14 +11,14 @@ from varlen_utils import generate_random_padding_mask, generate_qkv ...@@ -11,14 +11,14 @@ from varlen_utils import generate_random_padding_mask, generate_qkv
def attention_ref( def attention_ref(
q, q,
k, k,
v, v,
query_padding_mask=None, query_padding_mask=None,
key_padding_mask=None, key_padding_mask=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite window size window_size=(-1, -1), # -1 means infinite window size
upcast=True, upcast=True,
): ):
""" """
Arguments: Arguments:
...@@ -47,7 +47,7 @@ def attention_ref( ...@@ -47,7 +47,7 @@ def attention_ref(
if upcast: if upcast:
q, k, v = q.float(), k.float(), v.float() q, k, v = q.float(), k.float(), v.float()
dim = q.shape[-1] dim = q.shape[-1]
scale = (1.0 / dim)**0.5 # log2(e) scale = (1.0 / dim) ** 0.5 # log2(e)
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
scores = torch.einsum("bthd,bshd->bhts", q, k) scores = torch.einsum("bthd,bshd->bhts", q, k)
...@@ -68,20 +68,13 @@ def attention_ref( ...@@ -68,20 +68,13 @@ def attention_ref(
@tilelang.jit( @tilelang.jit(
out_idx=[6], pass_configs={ out_idx=[6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
def flashattn(batch_size, )
UQ, def flashattn(batch_size, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=32):
UKV, scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
heads,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=32):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [UQ, heads, dim] q_shape = [UQ, heads, dim]
k_shape = [UKV, heads, dim] k_shape = [UKV, heads, dim]
v_shape = [UKV, heads, dim] v_shape = [UKV, heads, dim]
...@@ -92,17 +85,15 @@ def flashattn(batch_size, ...@@ -92,17 +85,15 @@ def flashattn(batch_size,
@T.prim_func @T.prim_func
def main( def main(
Q_unpad: T.Tensor(q_shape, dtype), Q_unpad: T.Tensor(q_shape, dtype),
K_unpad: T.Tensor(k_shape, dtype), K_unpad: T.Tensor(k_shape, dtype),
V_unpad: T.Tensor(v_shape, dtype), V_unpad: T.Tensor(v_shape, dtype),
cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), cu_seqlens_q: T.Tensor([batch_size + 1], "int32"),
cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), cu_seqlens_k: T.Tensor([batch_size + 1], "int32"),
max_seqlen_q: T.int32, max_seqlen_q: T.int32,
Output_unpad: T.Tensor(o_shape, dtype), Output_unpad: T.Tensor(o_shape, dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz):
T.ceildiv(max_seqlen_q, block_M), heads, batch_size,
threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype, "shared") Q_shared = T.alloc_shared([block_M, dim], dtype, "shared")
K_shared = T.alloc_shared([block_N, dim], dtype, "shared") K_shared = T.alloc_shared([block_N, dim], dtype, "shared")
V_shared = T.alloc_shared([block_N, dim], dtype, "shared") V_shared = T.alloc_shared([block_N, dim], dtype, "shared")
...@@ -151,15 +142,17 @@ def flashattn(batch_size, ...@@ -151,15 +142,17 @@ def flashattn(batch_size,
K_shared[i, d] = 0 K_shared[i, d] = 0
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and acc_s[i, j] = T.if_then_else(
(bx * block_M + i >= q_current_seqlen or (bx * block_M + i >= k * block_N + j)
k * block_N + j >= k_current_seqlen), and (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen),
-T.infinity(acc_s.dtype), 0) -T.infinity(acc_s.dtype),
0,
)
else: else:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or acc_s[i, j] = T.if_then_else(
k * block_N + j >= k_current_seqlen), (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), -T.infinity(acc_s.dtype), 0
-T.infinity(acc_s.dtype), 0) )
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...@@ -244,8 +237,7 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): ...@@ -244,8 +237,7 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128):
output_pad_fn, output_pad_fn,
dq_pad_fn, dq_pad_fn,
dk_pad_fn, dk_pad_fn,
) = generate_qkv( ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
UQ = q_unpad.shape[0] # unpadded query length UQ = q_unpad.shape[0] # unpadded query length
UK = k_unpad.shape[0] # unpadded key length UK = k_unpad.shape[0] # unpadded key length
...@@ -287,10 +279,10 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): ...@@ -287,10 +279,10 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size') parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument('--heads', type=int, default=64, help='heads') parser.add_argument("--heads", type=int, default=64, help="heads")
parser.add_argument('--seq_len', type=int, default=2048, help='sequence length') parser.add_argument("--seq_len", type=int, default=2048, help="sequence length")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim", type=int, default=128, help="dim")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim) main(args.batch, args.heads, args.seq_len, args.dim)
...@@ -62,14 +62,12 @@ def test_example_mha_bwd_wgmma_pipelined(): ...@@ -62,14 +62,12 @@ def test_example_mha_bwd_wgmma_pipelined():
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0) @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_gqa_fwd_bshd_wgmma_pipelined(): def test_example_gqa_fwd_bshd_wgmma_pipelined():
example_gqa_fwd_bshd_wgmma_pipelined.main( example_gqa_fwd_bshd_wgmma_pipelined.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)
batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
def test_example_gqa_fwd_bshd(): def test_example_gqa_fwd_bshd():
example_gqa_fwd_bshd.main( example_gqa_fwd_bshd.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)
batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
......
...@@ -9,22 +9,14 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): ...@@ -9,22 +9,14 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
if mode == "full": if mode == "full":
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
elif mode == "random": elif mode == "random":
lengths = torch.randint( lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
elif mode == "third": elif mode == "third":
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
padding_mask = ( padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths)
return padding_mask return padding_mask
def generate_qkv(q, def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False):
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False):
""" """
Arguments: Arguments:
q: (batch_size, seqlen_q, nheads, d) q: (batch_size, seqlen_q, nheads, d)
...@@ -39,15 +31,12 @@ def generate_qkv(q, ...@@ -39,15 +31,12 @@ def generate_qkv(q,
if query_padding_mask is not None: if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q)
)
else: else:
q_unpad = rearrange(q, "b s h d -> (b s) h d") q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange( cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device)
0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device)
max_seqlen_q = seqlen_q max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange( output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size)
output_unpad, "(b s) h d -> b s h d", b=batch_size)
if key_padding_mask is not None: if key_padding_mask is not None:
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
...@@ -55,8 +44,7 @@ def generate_qkv(q, ...@@ -55,8 +44,7 @@ def generate_qkv(q,
else: else:
k_unpad = rearrange(k, "b s h d -> (b s) h d") k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d") v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange( cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device)
0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device)
max_seqlen_k = seqlen_k max_seqlen_k = seqlen_k
if qkvpacked: if qkvpacked:
...@@ -67,8 +55,7 @@ def generate_qkv(q, ...@@ -67,8 +55,7 @@ def generate_qkv(q,
if query_padding_mask is not None: if query_padding_mask is not None:
dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
else: else:
dqkv_pad_fn = lambda dqkv_unpad: rearrange( dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
return ( return (
qkv_unpad.detach().requires_grad_(), qkv_unpad.detach().requires_grad_(),
cu_seqlens_q, cu_seqlens_q,
...@@ -84,8 +71,7 @@ def generate_qkv(q, ...@@ -84,8 +71,7 @@ def generate_qkv(q,
if key_padding_mask is not None: if key_padding_mask is not None:
dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
else: else:
dkv_pad_fn = lambda dkv_unpad: rearrange( dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
return ( return (
q_unpad.detach().requires_grad_(), q_unpad.detach().requires_grad_(),
kv_unpad.detach().requires_grad_(), kv_unpad.detach().requires_grad_(),
......
...@@ -20,13 +20,7 @@ def get_configs(): ...@@ -20,13 +20,7 @@ def get_configs():
threads = [128] threads = [128]
_configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
configs = [{ configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs]
'block_N': c[0],
'block_H': c[1],
'num_split': c[2],
'num_stages': c[3],
'threads': c[4]
} for c in _configs]
return configs return configs
...@@ -48,17 +42,13 @@ def get_heuristic_config() -> Tuple[Dict, int]: ...@@ -48,17 +42,13 @@ def get_heuristic_config() -> Tuple[Dict, int]:
# TODO(lei): fix warp specialized and tma lower pass # TODO(lei): fix warp specialized and tma lower pass
def get_pass_configs(): def get_pass_configs():
return { return {tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
}
@autotune(configs=get_configs(), warmup=10, rep=10) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[6], pass_configs=get_pass_configs()) @tilelang.jit(out_idx=[6], pass_configs=get_pass_configs())
def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, threads):
threads): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, dim] shape_q = [batch, heads, dim]
shape_k = [batch, seqlen_kv, groups, dim] shape_k = [batch, seqlen_kv, groups, dim]
shape_v = [batch, seqlen_kv, groups, dim] shape_v = [batch, seqlen_kv, groups, dim]
...@@ -73,11 +63,11 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -73,11 +63,11 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
@T.macro @T.macro
def flash_attn( def flash_attn(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
Output: T.Tensor([batch, heads, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype),
): ):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype) Q_shared = T.alloc_shared([block_H, dim], dtype)
...@@ -98,20 +88,19 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -98,20 +88,19 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
hid = by hid = by
cur_kv_head = hid // (kv_group_num // valid_block_H) cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N) loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared) T.copy(K[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_shared)
T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local) T.copy(mask[bid, k * block_N : (k + 1) * block_N, cur_kv_head], mask_local)
T.clear(acc_s) T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], -T.infinity(accum_dtype))
-T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
...@@ -127,23 +116,23 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -127,23 +116,23 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
T.copy(acc_s, acc_s_cast) T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
T.copy(V[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], V_shared) T.copy(V[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(acc_o[:valid_block_H, :], O_shared) T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
@T.macro @T.macro
def flash_attn_split( def flash_attn_split(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype), Output_partial: T.Tensor(part_shape, dtype),
): ):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype) Q_shared = T.alloc_shared([block_H, dim], dtype)
...@@ -165,7 +154,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -165,7 +154,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
sid = bz sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H) cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -174,19 +163,26 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -174,19 +163,26 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy( T.copy(
K[bid, (seqlen_kv // num_split) * sid + K[
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, bid,
cur_kv_head, :], K_shared) (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head,
:,
],
K_shared,
)
T.copy( T.copy(
mask[bid, (seqlen_kv // num_split) * sid + mask[
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, bid,
cur_kv_head], mask_local) (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head,
],
mask_local,
)
T.clear(acc_s) T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
acc_s[i, acc_s[i, j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j], -T.infinity(accum_dtype))
j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split),
acc_s[i, j], -T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
...@@ -203,9 +199,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -203,9 +199,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
T.copy( T.copy(
V[bid, (seqlen_kv // num_split) * sid + V[
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, bid,
cur_kv_head, :], V_shared) (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head,
:,
],
V_shared,
)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
...@@ -216,14 +217,13 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -216,14 +217,13 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
if i < valid_block_H: if i < valid_block_H:
glse[bid, hid * valid_block_H + i, sid] = logsum[i] glse[bid, hid * valid_block_H + i, sid] = logsum[i]
T.copy(acc_o[:valid_block_H, :], O_shared) T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output_partial[bid, hid * valid_block_H:(hid + 1) * valid_block_H, T.copy(O_shared, Output_partial[bid, hid * valid_block_H : (hid + 1) * valid_block_H, sid, :])
sid, :])
@T.macro @T.macro
def combine( def combine(
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype), Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype), Output: T.Tensor(shape_o, dtype),
): ):
with T.Kernel(heads, batch, threads=128) as (by, bz): with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim], dtype) po_local = T.alloc_fragment([dim], dtype)
...@@ -233,12 +233,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -233,12 +233,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
lse_max_local = T.alloc_fragment([128], accum_dtype) lse_max_local = T.alloc_fragment([128], accum_dtype)
scale_local = T.alloc_fragment([128], accum_dtype) scale_local = T.alloc_fragment([128], accum_dtype)
T.annotate_layout({ T.annotate_layout(
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), {
lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i), lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
# lse_local: (local_id, thread_id) lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)), # lse_local: (local_id, thread_id)
}) lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)),
}
)
T.clear(lse_logsum_local) T.clear(lse_logsum_local)
T.clear(o_accum_local) T.clear(o_accum_local)
...@@ -263,26 +265,26 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -263,26 +265,26 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
@T.prim_func @T.prim_func
def flashattn_gqa_decode_split( def flashattn_gqa_decode_split(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype), Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype), Output: T.Tensor(shape_o, dtype),
): ):
flash_attn_split(Q, K, V, mask, glse, Output_partial) flash_attn_split(Q, K, V, mask, glse, Output_partial)
combine(glse, Output_partial, Output) combine(glse, Output_partial, Output)
@T.prim_func @T.prim_func
def flashattn_gqa_decode_no_split( def flashattn_gqa_decode_no_split(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype), Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype), Output: T.Tensor(shape_o, dtype),
): ):
flash_attn(Q, K, V, mask, Output) flash_attn(Q, K, V, mask, Output)
...@@ -305,27 +307,21 @@ def ref_program(query, key, value, mask, glse, Output_partial): ...@@ -305,27 +307,21 @@ def ref_program(query, key, value, mask, glse, Output_partial):
dim = query.shape[-1] dim = query.shape[-1]
num_head_groups = query.shape[1] // key.shape[2] num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5 scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] key = rearrange(key, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
value = rearrange(value, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] value = rearrange(value, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
query = rearrange( query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
scores = einsum( scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
if mask is not None: if mask is not None:
mask = rearrange(mask, 'b s h -> b h s') mask = rearrange(mask, "b s h -> b h s")
mask = mask.unsqueeze(1) mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask == 0, float('-inf')) scores = scores.masked_fill(mask == 0, float("-inf"))
attention = F.softmax( attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, value, out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out return out
...@@ -339,16 +335,12 @@ def flash_split_ref(Q, K, V, mask): ...@@ -339,16 +335,12 @@ def flash_split_ref(Q, K, V, mask):
seqlen_kv = K.size(1) seqlen_kv = K.size(1)
num_head_groups = nheads // groups num_head_groups = nheads // groups
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
acc_s = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float) acc_s = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float)
acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float16)
device="cuda",
dtype=torch.float16)
acc_o = torch.empty((batch, num_head_groups, groups, dim), device="cuda", dtype=torch.float) acc_o = torch.empty((batch, num_head_groups, groups, dim), device="cuda", dtype=torch.float)
scores_max = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) scores_max = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
scores_max_prev = torch.empty((batch, num_head_groups, groups), scores_max_prev = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
device="cuda",
dtype=torch.float)
scores_scale = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) scores_scale = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
scores_sum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) scores_sum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
logsum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) logsum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
...@@ -356,25 +348,25 @@ def flash_split_ref(Q, K, V, mask): ...@@ -356,25 +348,25 @@ def flash_split_ref(Q, K, V, mask):
glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float) glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float)
Q_ = Q * scale Q_ = Q * scale
Q_ = rearrange(Q_, 'b (h g) d -> b g h d', g=num_head_groups) Q_ = rearrange(Q_, "b (h g) d -> b g h d", g=num_head_groups)
for ks in range(num_split): for ks in range(num_split):
acc_o.fill_(0) acc_o.fill_(0)
logsum.fill_(0) logsum.fill_(0)
scores_max.fill_(float('-inf')) scores_max.fill_(float("-inf"))
scores_max_prev.fill_(float('-inf')) scores_max_prev.fill_(float("-inf"))
for i in range(int((seqlen_kv // num_split) / block_N)): for i in range(int((seqlen_kv // num_split) / block_N)):
acc_s.fill_(0) acc_s.fill_(0)
acc_s = torch.einsum('bghd,bkhd->bghk', Q_, acc_s = torch.einsum(
K[:, (seqlen_kv // num_split) * ks + "bghd,bkhd->bghk",
i * block_N:(seqlen_kv // num_split) * ks + Q_,
(i + 1) * block_N, :, :]) # [batch, nheads, block_N] K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
) # [batch, nheads, block_N]
if mask is not None: if mask is not None:
mask_local = mask[:, (seqlen_kv // num_split) * ks + mask_local = mask[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :]
i * block_N:(seqlen_kv // num_split) * ks + (i + 1) * block_N, :] mask_local = rearrange(mask_local, "b s h -> b h s")
mask_local = rearrange(mask_local, 'b s h -> b h s')
mask_local = mask_local.unsqueeze(1) mask_local = mask_local.unsqueeze(1)
acc_s = acc_s.masked_fill(mask_local == 0, float('-inf')) acc_s = acc_s.masked_fill(mask_local == 0, float("-inf"))
scores_max_prev = scores_max scores_max_prev = scores_max
scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads] scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads]
scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads] scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads]
...@@ -382,15 +374,16 @@ def flash_split_ref(Q, K, V, mask): ...@@ -382,15 +374,16 @@ def flash_split_ref(Q, K, V, mask):
acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N] acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N]
acc_o += torch.einsum( acc_o += torch.einsum(
'bghk,bkhd->bghd', acc_s_cast, "bghk,bkhd->bghd",
V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + acc_s_cast,
(i + 1) * block_N, :, :]) V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
)
scores_sum = acc_s.sum(dim=-1, keepdim=False) scores_sum = acc_s.sum(dim=-1, keepdim=False)
logsum = logsum * scores_scale + scores_sum logsum = logsum * scores_scale + scores_sum
acc_o_out = rearrange(acc_o, 'b g h d->b (h g) d') acc_o_out = rearrange(acc_o, "b g h d->b (h g) d")
logsum_out = rearrange(logsum, 'b g h->b (h g)') logsum_out = rearrange(logsum, "b g h->b (h g)")
acc_o_out /= logsum_out[:, :, None] acc_o_out /= logsum_out[:, :, None]
logsum_out = torch.log2(logsum_out) + rearrange(scores_max, 'b g h->b (h g)') logsum_out = torch.log2(logsum_out) + rearrange(scores_max, "b g h->b (h g)")
gacc_o[ks, :, :, :] = acc_o_out gacc_o[ks, :, :, :] = acc_o_out
glogsum[ks, :, :] = logsum_out glogsum[ks, :, :] = logsum_out
...@@ -426,7 +419,7 @@ def calc_sim(x, y, name="tensor"): ...@@ -426,7 +419,7 @@ def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double() x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum() denominator = (x * x + y * y).sum()
if denominator == 0: if denominator == 0:
print_red_warning(f'{name} all zero') print_red_warning(f"{name} all zero")
return 1 return 1
sim = 2 * (x * y).sum() / denominator sim = 2 * (x * y).sum() / denominator
return sim return sim
...@@ -434,28 +427,23 @@ def calc_sim(x, y, name="tensor"): ...@@ -434,28 +427,23 @@ def calc_sim(x, y, name="tensor"):
def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True): def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True):
sim = calc_sim(x, y, name) sim = calc_sim(x, y, name)
diff = 1. - sim diff = 1.0 - sim
if not (0 <= diff <= eps): if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff}') print_red_warning(f"{name} Error: {diff}")
if assert_: if assert_:
raise AssertionError(f'{name} Error: {diff}') raise AssertionError(f"{name} Error: {diff}")
else: else:
if print_: if print_:
print(f'passed: {name} diff={diff}') print(f"passed: {name} diff={diff}")
def main(batch: int = 1, def main(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192, dim: int = 128, tune: bool = False):
heads: int = 32,
groups: int = 8,
kv_seqlen: int = 8192,
dim: int = 128,
tune: bool = False):
batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim
qk_flops = 2 * batch * heads * kv_seqlen * dim qk_flops = 2 * batch * heads * kv_seqlen * dim
pv_flops = 2 * batch * heads * kv_seqlen * dim pv_flops = 2 * batch * heads * kv_seqlen * dim
total_flops = qk_flops + pv_flops total_flops = qk_flops + pv_flops
if (not tune): if not tune:
config, sm_version = get_heuristic_config() config, sm_version = get_heuristic_config()
kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config) kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
...@@ -497,11 +485,11 @@ def main(batch: int = 1, ...@@ -497,11 +485,11 @@ def main(batch: int = 1,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size') parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument('--heads', type=int, default=32, help='heads') parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument('--groups', type=int, default=8, help='groups') parser.add_argument("--groups", type=int, default=8, help="groups")
parser.add_argument('--kv_seqlen', type=int, default=8192, help='kv sequence length') parser.add_argument("--kv_seqlen", type=int, default=8192, help="kv sequence length")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument('--tune', action='store_true', help='tune configs') parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune) main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune)
...@@ -19,8 +19,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: ...@@ -19,8 +19,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1: if n_rep == 1:
return hidden_states return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
...@@ -74,14 +73,9 @@ def _fwd_inner( ...@@ -74,14 +73,9 @@ def _fwd_inner(
return m_i, l_i, acc return m_i, l_i, acc
@triton.autotune( @triton.autotune(
configs=[ configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [4, 8] for num_stages in [2, 4]],
triton.Config({}, num_warps=num_warps, num_stages=num_stages) key=["gqa_group_size", "BLOCK_N", "BLOCK_D", "BLOCK_H"],
for num_warps in [4, 8]\
for num_stages in [2, 4]\
],
key=['gqa_group_size', 'BLOCK_N', 'BLOCK_D', 'BLOCK_H'],
) )
@triton.jit @triton.jit
def _fwd_kernel_varlen( def _fwd_kernel_varlen(
...@@ -107,13 +101,12 @@ def _fwd_kernel_varlen( ...@@ -107,13 +101,12 @@ def _fwd_kernel_varlen(
stride_od, stride_od,
stride_sb, stride_sb,
stride_sh, stride_sh,
stride_sn, #bmask shape [b, q_h, seq/BLOCK_N] stride_sn, # bmask shape [b, q_h, seq/BLOCK_N]
gqa_group_size: tl.constexpr, gqa_group_size: tl.constexpr,
BLOCK_H: tl.constexpr, BLOCK_H: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr, BLOCK_D: tl.constexpr,
): ):
off_z = tl.program_id(0) off_z = tl.program_id(0)
off_h_for_kv = tl.program_id(1) off_h_for_kv = tl.program_id(1)
off_h_q = off_h_for_kv * gqa_group_size off_h_q = off_h_for_kv * gqa_group_size
...@@ -134,8 +127,7 @@ def _fwd_kernel_varlen( ...@@ -134,8 +127,7 @@ def _fwd_kernel_varlen(
S_ptrs = S + off_z * stride_sb + off_h_q * stride_sh S_ptrs = S + off_z * stride_sb + off_h_q * stride_sh
mask_h = offs_h < gqa_group_size mask_h = offs_h < gqa_group_size
q = tl.load( q = tl.load(Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None])
Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None])
if s_aux is not None: if s_aux is not None:
sink = tl.load(s_aux + off_h_q + offs_h, mask=mask_h).to(tl.float32) sink = tl.load(s_aux + off_h_q + offs_h, mask=mask_h).to(tl.float32)
...@@ -189,14 +181,12 @@ def _fwd_kernel_varlen( ...@@ -189,14 +181,12 @@ def _fwd_kernel_varlen(
acc = acc.to(O.dtype.element_ty) acc = acc.to(O.dtype.element_ty)
tl.store( tl.store(O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od, acc, mask=mask_h[:, None])
O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od,
acc,
mask=mask_h[:, None])
def get_configs(): def get_configs():
import itertools import itertools
block_N = [64, 128] block_N = [64, 128]
block_H = [64] block_H = [64]
num_split = [1] num_split = [1]
...@@ -204,31 +194,16 @@ def get_configs(): ...@@ -204,31 +194,16 @@ def get_configs():
threads = [128] threads = [128]
_configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
configs = [{ configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs]
'block_N': c[0],
'block_H': c[1],
'num_split': c[2],
'num_stages': c[3],
'threads': c[4]
} for c in _configs]
return configs return configs
@autotune(configs=get_configs(), warmup=10, rep=10) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding") @tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding")
def flashattn(batch, def flashattn(
heads, batch, heads, k_heads, max_seqlen_kv, total_seqlen_k, dim, has_sink, block_N=128, block_H=64, num_split=1, num_stages=1, threads=128
k_heads, ):
max_seqlen_kv, scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
total_seqlen_k,
dim,
has_sink,
block_N=128,
block_H=64,
num_split=1,
num_stages=1,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, dim] shape_q = [batch, heads, dim]
shape_k = [total_seqlen_k, k_heads, dim] shape_k = [total_seqlen_k, k_heads, dim]
shape_v = [total_seqlen_k, k_heads, dim] shape_v = [total_seqlen_k, k_heads, dim]
...@@ -243,13 +218,13 @@ def flashattn(batch, ...@@ -243,13 +218,13 @@ def flashattn(batch,
@T.macro @T.macro
def flash_attn( def flash_attn(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
cu_seqlens_k: T.Tensor([batch + 1], "int32"), cu_seqlens_k: T.Tensor([batch + 1], "int32"),
s_aux: T.Tensor([heads], "float32"), s_aux: T.Tensor([heads], "float32"),
Output: T.Tensor([batch, heads, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype),
S: T.Tensor(shape_s, dtype), S: T.Tensor(shape_s, dtype),
): ):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype) Q_shared = T.alloc_shared([block_H, dim], dtype)
...@@ -268,13 +243,15 @@ def flashattn(batch, ...@@ -268,13 +243,15 @@ def flashattn(batch,
# S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype) # S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype)
s_aux_shared = T.alloc_shared([block_H], "float32") s_aux_shared = T.alloc_shared([block_H], "float32")
T.annotate_layout({ T.annotate_layout(
# Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), {
# K_shared: tilelang.layout.make_swizzled_layout(K_shared), # Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
# V_shared: tilelang.layout.make_swizzled_layout(V_shared), # K_shared: tilelang.layout.make_swizzled_layout(K_shared),
# O_shared: tilelang.layout.make_swizzled_layout(O_shared), # V_shared: tilelang.layout.make_swizzled_layout(V_shared),
# S_shared: tilelang.layout.make_swizzled_layout(S_shared), # O_shared: tilelang.layout.make_swizzled_layout(O_shared),
}) # S_shared: tilelang.layout.make_swizzled_layout(S_shared),
}
)
bid = bx bid = bx
hid = by hid = by
...@@ -284,7 +261,7 @@ def flashattn(batch, ...@@ -284,7 +261,7 @@ def flashattn(batch,
cur_end_k = cu_seqlens_k[bid + 1] cur_end_k = cu_seqlens_k[bid + 1]
cur_seqlen_k = cur_end_k - cur_start_k cur_seqlen_k = cur_end_k - cur_start_k
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -292,15 +269,13 @@ def flashattn(batch, ...@@ -292,15 +269,13 @@ def flashattn(batch,
# loop_range = T.ceildiv((seqlen_kv // num_split), block_N) # loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K[cur_start_k + k * block_N:cur_start_k + (k + 1) * block_N, cur_kv_head, :], T.copy(K[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], K_shared)
K_shared)
T.clear(acc_s) T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
# acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j], # acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j],
# -T.infinity(accum_dtype)) # -T.infinity(accum_dtype))
acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype))
-T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
...@@ -320,12 +295,11 @@ def flashattn(batch, ...@@ -320,12 +295,11 @@ def flashattn(batch,
T.copy(acc_s, acc_s_cast) T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
T.copy(V[cur_start_k + k * block_N:cur_start_k + (k + 1) * block_N, cur_kv_head, :], T.copy(V[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], V_shared)
V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_sink: if has_sink:
T.copy(s_aux[hid * valid_block_H:hid * valid_block_H + block_H], s_aux_shared) T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared)
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
logsum[i] += s_aux_shared[i] logsum[i] += s_aux_shared[i]
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
...@@ -338,20 +312,19 @@ def flashattn(batch, ...@@ -338,20 +312,19 @@ def flashattn(batch,
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(acc_o[:valid_block_H, :], O_shared) T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
# T.copy(S_fragment, S_shared) # T.copy(S_fragment, S_shared)
T.copy(S_shared[:valid_block_H, :], S[bid, T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
hid * valid_block_H:(hid + 1) * valid_block_H, :])
@T.prim_func @T.prim_func
def flashattn_gqa_decode_no_split( def flashattn_gqa_decode_no_split(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
cu_seqlens_k: T.Tensor([batch + 1], "int32"), cu_seqlens_k: T.Tensor([batch + 1], "int32"),
s_aux: T.Tensor([heads], "float32"), s_aux: T.Tensor([heads], "float32"),
Output: T.Tensor(shape_o, dtype), Output: T.Tensor(shape_o, dtype),
S: T.Tensor(shape_s, dtype), S: T.Tensor(shape_s, dtype),
): ):
flash_attn(Q, K, V, cu_seqlens_k, s_aux, Output, S) flash_attn(Q, K, V, cu_seqlens_k, s_aux, Output, S)
...@@ -388,9 +361,7 @@ def flash_attn_with_attn_pool_decode_tilelang( ...@@ -388,9 +361,7 @@ def flash_attn_with_attn_pool_decode_tilelang(
gqa_group_size = q_h // k_h gqa_group_size = q_h // k_h
O_tl = torch.zeros_like(Q) O_tl = torch.zeros_like(Q)
S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device)
dtype=Q.dtype,
device=Q.device)
O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux) O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux)
if use_per_kv_head_sparse_index: if use_per_kv_head_sparse_index:
...@@ -433,9 +404,7 @@ def flash_attn_with_attn_pool_decode( ...@@ -433,9 +404,7 @@ def flash_attn_with_attn_pool_decode(
BLOCK_H = 64 BLOCK_H = 64
O = torch.zeros_like(Q) O = torch.zeros_like(Q)
S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), dtype=Q.dtype, device=Q.device)
dtype=Q.dtype,
device=Q.device)
def grid(META): def grid(META):
return (batch, k_h) return (batch, k_h)
...@@ -483,15 +452,15 @@ def test_equal_seqlen_decode_main(args): ...@@ -483,15 +452,15 @@ def test_equal_seqlen_decode_main(args):
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
# For decode, query is just 1 token per batch # For decode, query is just 1 token per batch
q = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size) softmax_scale = 1.0 / math.sqrt(head_size)
# Generate sink values if needed # Generate sink values if needed
sink = None sink = None
if args.test_sink: if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}") print(f"Using sink attention with sink values: {sink}")
# Convert to varlen format for K, V # Convert to varlen format for K, V
...@@ -499,8 +468,7 @@ def test_equal_seqlen_decode_main(args): ...@@ -499,8 +468,7 @@ def test_equal_seqlen_decode_main(args):
v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size) v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size)
# Generate cumulative sequence lengths # Generate cumulative sequence lengths
cu_seqlens_k = torch.arange( cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32)
0, (batch_size + 1) * k_seqlen, k_seqlen, device='cuda', dtype=torch.int32)
max_seqlen_k = k_seqlen max_seqlen_k = k_seqlen
print(f"q shape: {q.shape}") print(f"q shape: {q.shape}")
...@@ -510,8 +478,7 @@ def test_equal_seqlen_decode_main(args): ...@@ -510,8 +478,7 @@ def test_equal_seqlen_decode_main(args):
num_tokens, q_h, head_size = q.shape num_tokens, q_h, head_size = q.shape
batch = cu_seqlens_k.size(0) - 1 batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1) k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink)
args.test_sink)
# Test our decode kernel # Test our decode kernel
O_triton, S_triton = flash_attn_with_attn_pool_decode( O_triton, S_triton = flash_attn_with_attn_pool_decode(
...@@ -524,7 +491,8 @@ def test_equal_seqlen_decode_main(args): ...@@ -524,7 +491,8 @@ def test_equal_seqlen_decode_main(args):
args.num_split, args.num_split,
softmax_scale, softmax_scale,
s_aux=sink, s_aux=sink,
block_size=block_size) block_size=block_size,
)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q, q,
k_varlen, k_varlen,
...@@ -539,9 +507,7 @@ def test_equal_seqlen_decode_main(args): ...@@ -539,9 +507,7 @@ def test_equal_seqlen_decode_main(args):
tl_kernel=tl_kernel, tl_kernel=tl_kernel,
) )
for i in range(batch_size): for i in range(batch_size):
S_tilelang[i, :, S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0
math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) /
block_size):] = 0
# Compute torch reference # Compute torch reference
q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size] q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size]
...@@ -550,14 +516,12 @@ def test_equal_seqlen_decode_main(args): ...@@ -550,14 +516,12 @@ def test_equal_seqlen_decode_main(args):
if sink is None: if sink is None:
# Standard scaled dot-product attention # Standard scaled dot-product attention
logits = torch.matmul(q_expanded, k_repeat.transpose( logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
attn_weights = torch.softmax(logits, dim=-1) attn_weights = torch.softmax(logits, dim=-1)
O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size] O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size]
else: else:
# s_aux attention # s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose( logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values logits_max = torch.max(logits, dim=-1, keepdim=True).values
...@@ -566,15 +530,15 @@ def test_equal_seqlen_decode_main(args): ...@@ -566,15 +530,15 @@ def test_equal_seqlen_decode_main(args):
unnormalized_scores = torch.exp(logits - logits_or_sinks_max) unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
attn_weights = unnormalized_scores / normalizer attn_weights = unnormalized_scores / normalizer
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size]
v_repeat).squeeze(2) # [batch, q_heads, head_size]
# Compute attention score pooling # Compute attention score pooling
attn_score_pooled = torch.max_pool2d( attn_score_pooled = torch.max_pool2d(
attn_weights.squeeze(2), # [b, q_heads, k_seqlen] attn_weights.squeeze(2), # [b, q_heads, k_seqlen]
kernel_size=(q_heads, block_size), kernel_size=(q_heads, block_size),
stride=(q_heads, block_size), stride=(q_heads, block_size),
ceil_mode=True).to(torch.float16) ceil_mode=True,
).to(torch.float16)
print("S_tilelang", S_tilelang) print("S_tilelang", S_tilelang)
print("attn_score_pooled", attn_score_pooled) print("attn_score_pooled", attn_score_pooled)
...@@ -588,15 +552,10 @@ def test_equal_seqlen_decode_main(args): ...@@ -588,15 +552,10 @@ def test_equal_seqlen_decode_main(args):
print(f"Max difference in S: {max_diff_s.item()}") print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}") print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}") print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}")
assert torch.allclose( assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose( assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
assert torch.allclose(
O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
assert torch.allclose(
S_tilelang, attn_score_pooled, atol=1e-2,
rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
print("✅ All tests passed!") print("✅ All tests passed!")
...@@ -616,7 +575,7 @@ def test_varlen_decode_main(args): ...@@ -616,7 +575,7 @@ def test_varlen_decode_main(args):
# Generate sink values if needed # Generate sink values if needed
sink = None sink = None
if args.test_sink: if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}") print(f"Using sink attention with sink values: {sink}")
# Generate variable length k sequences # Generate variable length k sequences
...@@ -624,7 +583,7 @@ def test_varlen_decode_main(args): ...@@ -624,7 +583,7 @@ def test_varlen_decode_main(args):
print(f"k_seqlens: {k_seqlens}") print(f"k_seqlens: {k_seqlens}")
# Generate cumulative sequence lengths for k # Generate cumulative sequence lengths for k
cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
total_k_tokens = 0 total_k_tokens = 0
for i in range(batch_size): for i in range(batch_size):
cu_seqlens_k[i] = total_k_tokens cu_seqlens_k[i] = total_k_tokens
...@@ -634,9 +593,9 @@ def test_varlen_decode_main(args): ...@@ -634,9 +593,9 @@ def test_varlen_decode_main(args):
print(f"cu_seqlens_k: {cu_seqlens_k}") print(f"cu_seqlens_k: {cu_seqlens_k}")
# Generate tensors - Q is [batch_size, q_heads, head_size] for decode # Generate tensors - Q is [batch_size, q_heads, head_size] for decode
q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size) softmax_scale = 1.0 / math.sqrt(head_size)
max_seqlen_k = int(k_seqlens.max()) max_seqlen_k = int(k_seqlens.max())
...@@ -649,8 +608,7 @@ def test_varlen_decode_main(args): ...@@ -649,8 +608,7 @@ def test_varlen_decode_main(args):
num_tokens, q_h, head_size = q_decode.shape num_tokens, q_h, head_size = q_decode.shape
batch = cu_seqlens_k.size(0) - 1 batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1) k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink)
args.test_sink)
# Test our decode kernel # Test our decode kernel
O_triton, S_triton = flash_attn_with_attn_pool_decode( O_triton, S_triton = flash_attn_with_attn_pool_decode(
...@@ -663,7 +621,8 @@ def test_varlen_decode_main(args): ...@@ -663,7 +621,8 @@ def test_varlen_decode_main(args):
args.num_split, args.num_split,
softmax_scale, softmax_scale,
s_aux=sink, s_aux=sink,
block_size=block_size) block_size=block_size,
)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q_decode, q_decode,
k_varlen, k_varlen,
...@@ -678,9 +637,7 @@ def test_varlen_decode_main(args): ...@@ -678,9 +637,7 @@ def test_varlen_decode_main(args):
tl_kernel=tl_kernel, tl_kernel=tl_kernel,
) )
for i in range(batch_size): for i in range(batch_size):
S_tilelang[i, :, S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0
math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) /
block_size):] = 0
# Create torch reference - pad tensors for comparison # Create torch reference - pad tensors for comparison
k_padded_list = [] k_padded_list = []
...@@ -694,8 +651,8 @@ def test_varlen_decode_main(args): ...@@ -694,8 +651,8 @@ def test_varlen_decode_main(args):
k_end = cu_seqlens_k[i + 1] k_end = cu_seqlens_k[i + 1]
# Pad to max_seqlen_k # Pad to max_seqlen_k
k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
k_padded[:actual_k_len] = k_varlen[k_start:k_end] k_padded[:actual_k_len] = k_varlen[k_start:k_end]
v_padded[:actual_k_len] = v_varlen[k_start:k_end] v_padded[:actual_k_len] = v_varlen[k_start:k_end]
...@@ -704,10 +661,8 @@ def test_varlen_decode_main(args): ...@@ -704,10 +661,8 @@ def test_varlen_decode_main(args):
v_padded_list.append(v_padded) v_padded_list.append(v_padded)
# Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
k_padded_batched = torch.stack( k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
v_padded_batched = torch.stack(
v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
# Expand q to match kv heads: [b, q_heads, 1, head_size] # Expand q to match kv heads: [b, q_heads, 1, head_size]
q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size]
...@@ -717,20 +672,17 @@ def test_varlen_decode_main(args): ...@@ -717,20 +672,17 @@ def test_varlen_decode_main(args):
print(f"v_padded_batched shape: {v_padded_batched.shape}") print(f"v_padded_batched shape: {v_padded_batched.shape}")
# Compute torch reference # Compute torch reference
k_repeat = repeat_kv(k_padded_batched, k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
v_repeat = repeat_kv(v_padded_batched,
q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
if sink is None: if sink is None:
# Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
attn_score = torch.matmul(q_expanded, k_repeat.transpose( attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
# Apply sequence length masking # Apply sequence length masking
for i in range(batch_size): for i in range(batch_size):
actual_k_len = k_seqlens[i] actual_k_len = k_seqlens[i]
attn_score[i, :, :, actual_k_len:] = float('-inf') attn_score[i, :, :, actual_k_len:] = float("-inf")
attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen]
...@@ -743,13 +695,12 @@ def test_varlen_decode_main(args): ...@@ -743,13 +695,12 @@ def test_varlen_decode_main(args):
O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size]
else: else:
# s_aux attention # s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose( logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
# Apply sequence length masking # Apply sequence length masking
for i in range(batch_size): for i in range(batch_size):
actual_k_len = k_seqlens[i] actual_k_len = k_seqlens[i]
logits[i, :, :, actual_k_len:] = float('-inf') logits[i, :, :, actual_k_len:] = float("-inf")
sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values logits_max = torch.max(logits, dim=-1, keepdim=True).values
...@@ -765,8 +716,7 @@ def test_varlen_decode_main(args): ...@@ -765,8 +716,7 @@ def test_varlen_decode_main(args):
attn_weights[i, :, :, actual_k_len:] = 0.0 attn_weights[i, :, :, actual_k_len:] = 0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size]
v_repeat) # [b, q_heads, 1, head_size]
O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] O_torch = O_torch.squeeze(2) # [b, q_heads, head_size]
...@@ -775,7 +725,8 @@ def test_varlen_decode_main(args): ...@@ -775,7 +725,8 @@ def test_varlen_decode_main(args):
attn_weights.squeeze(2), # [b, q_heads, max_seqlen] attn_weights.squeeze(2), # [b, q_heads, max_seqlen]
kernel_size=(q_heads, block_size), kernel_size=(q_heads, block_size),
stride=(q_heads, block_size), stride=(q_heads, block_size),
ceil_mode=True).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] ceil_mode=True,
).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)]
print(f"O_triton shape: {O_triton.shape}") print(f"O_triton shape: {O_triton.shape}")
print(f"O_tilelang shape: {O_tilelang.shape}") print(f"O_tilelang shape: {O_tilelang.shape}")
...@@ -791,22 +742,16 @@ def test_varlen_decode_main(args): ...@@ -791,22 +742,16 @@ def test_varlen_decode_main(args):
print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}")
max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
max_diff_s_tl = torch.max( max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
torch.abs(S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
print(f"Max difference in S: {max_diff_s.item()}") print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}")
assert torch.allclose( assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose( assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}"
S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), (
assert torch.allclose( f"Score mismatch: {max_diff_s_tl.item()}"
O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" )
assert torch.allclose(
S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)],
attn_score_pooled,
atol=1e-2,
rtol=1e-2), f"Score mismatch: {max_diff_s_tl.item()}"
print("✅ All tests passed!") print("✅ All tests passed!")
...@@ -865,7 +810,7 @@ def speed_benchmark_decode_comparison(args): ...@@ -865,7 +810,7 @@ def speed_benchmark_decode_comparison(args):
k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int) k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int)
# Generate cumulative sequence lengths for k # Generate cumulative sequence lengths for k
cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
total_k_tokens = 0 total_k_tokens = 0
for i in range(batch_size): for i in range(batch_size):
cu_seqlens_k[i] = total_k_tokens cu_seqlens_k[i] = total_k_tokens
...@@ -873,9 +818,9 @@ def speed_benchmark_decode_comparison(args): ...@@ -873,9 +818,9 @@ def speed_benchmark_decode_comparison(args):
cu_seqlens_k[batch_size] = total_k_tokens cu_seqlens_k[batch_size] = total_k_tokens
# Generate tensors # Generate tensors
q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size) softmax_scale = 1.0 / math.sqrt(head_size)
max_seqlen_k = int(k_seqlens.max()) max_seqlen_k = int(k_seqlens.max())
...@@ -883,7 +828,7 @@ def speed_benchmark_decode_comparison(args): ...@@ -883,7 +828,7 @@ def speed_benchmark_decode_comparison(args):
# Generate sink values if needed # Generate sink values if needed
sink = None sink = None
if args.test_sink: if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(" Using sink attention with sink values") print(" Using sink attention with sink values")
print("Setup complete:") print("Setup complete:")
...@@ -896,8 +841,7 @@ def speed_benchmark_decode_comparison(args): ...@@ -896,8 +841,7 @@ def speed_benchmark_decode_comparison(args):
num_tokens, q_h, head_size = q_decode.shape num_tokens, q_h, head_size = q_decode.shape
batch = cu_seqlens_k.size(0) - 1 batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1) k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink)
args.test_sink)
# Benchmark # Benchmark
print("⚡ Benchmarking Tilelang kernel (100 iterations)...") print("⚡ Benchmarking Tilelang kernel (100 iterations)...")
...@@ -920,36 +864,41 @@ def speed_benchmark_decode_comparison(args): ...@@ -920,36 +864,41 @@ def speed_benchmark_decode_comparison(args):
# Benchmark # Benchmark
print("⚡ Benchmarking Triton kernel (100 iterations)...") print("⚡ Benchmarking Triton kernel (100 iterations)...")
triton_time = do_bench(flash_attn_with_attn_pool_decode, q_decode, k_varlen, v_varlen, triton_time = do_bench(
cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale, sink, flash_attn_with_attn_pool_decode,
block_size) q_decode,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
args.k_seqlen,
1,
softmax_scale,
sink,
block_size,
)
print(f"Average decode kernel time Triton: {triton_time:.3f} ms") print(f"Average decode kernel time Triton: {triton_time:.3f} ms")
print(f"Speedup: {(triton_time / tilelang_time):.3f}") print(f"Speedup: {(triton_time / tilelang_time):.3f}")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Flash Attention Decode with Attention Pooling') parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling")
parser.add_argument('--batch_size', type=int, default=1, help='Batch size') parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument('--q_heads', type=int, default=32, help='Number of query heads') parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads")
parser.add_argument('--kv_heads', type=int, default=8, help='Number of key-value heads') parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads")
parser.add_argument('--k_seqlen', type=int, default=8192, help='Key sequence length') parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length")
parser.add_argument( parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension")
'--head_size', type=int, default=128, choices=[64, 128, 256], help='Head dimension') parser.add_argument("--block_size", type=int, default=64, help="Block size for computation")
parser.add_argument('--block_size', type=int, default=64, help='Block size for computation') parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"], help="Data type")
parser.add_argument( parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths")
'--dtype', type=str, default='bfloat16', choices=['float16', 'bfloat16'], help='Data type') parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism")
parser.add_argument( parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark")
'--test_varlen', action='store_true', help='Test with truly variable sequence lengths') parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits")
parser.add_argument(
'--test_sink', action='store_true', help='Test with sink attention mechanism')
parser.add_argument('--benchmark', action='store_true', help='Run speed benchmark')
parser.add_argument(
'--num_split', type=int, default=1, choices=[1, 16], help='Number of splits')
args = parser.parse_args() args = parser.parse_args()
args.test_sink = True args.test_sink = True
args.test_varlen = False args.test_varlen = False
args.dtype = 'float16' args.dtype = "float16"
args.num_split = 1 args.num_split = 1
if args.benchmark: if args.benchmark:
......
...@@ -10,6 +10,7 @@ torch.manual_seed(0) ...@@ -10,6 +10,7 @@ torch.manual_seed(0)
def get_configs(): def get_configs():
import itertools import itertools
block_N = [64, 128] block_N = [64, 128]
block_H = [64] block_H = [64]
num_split = [1] num_split = [1]
...@@ -17,32 +18,28 @@ def get_configs(): ...@@ -17,32 +18,28 @@ def get_configs():
threads = [128] threads = [128]
_configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
configs = [{ configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs]
'block_N': c[0],
'block_H': c[1],
'num_split': c[2],
'num_stages': c[3],
'threads': c[4]
} for c in _configs]
return configs return configs
# @autotune(configs=get_configs(), warmup=10, rep=10) # @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding") @tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding")
def flashattn(batch, def flashattn(
heads, batch,
k_heads, heads,
max_seqlen_kv, k_heads,
total_seqlen_k, max_seqlen_kv,
dim, total_seqlen_k,
has_sink, dim,
page_block_size, has_sink,
block_N=128, page_block_size,
block_H=64, block_N=128,
num_split=1, block_H=64,
num_stages=1, num_split=1,
threads=128): num_stages=1,
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) threads=128,
):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, dim] shape_q = [batch, heads, dim]
shape_k = [total_seqlen_k, k_heads, dim] shape_k = [total_seqlen_k, k_heads, dim]
shape_v = [total_seqlen_k, k_heads, dim] shape_v = [total_seqlen_k, k_heads, dim]
...@@ -51,21 +48,23 @@ def flashattn(batch, ...@@ -51,21 +48,23 @@ def flashattn(batch,
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
kv_group_num = heads // k_heads kv_group_num = heads // k_heads
assert page_block_size >= block_N and page_block_size % block_N == 0, "page_block_size must be larger than block_N and a multiple of block_N" assert page_block_size >= block_N and page_block_size % block_N == 0, (
"page_block_size must be larger than block_N and a multiple of block_N"
)
valid_block_H = min(block_H, kv_group_num) valid_block_H = min(block_H, kv_group_num)
# TODO: check if max_seqlen_kv is correct for varlen case # TODO: check if max_seqlen_kv is correct for varlen case
@T.macro @T.macro
def flash_attn( def flash_attn(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
cu_seqlens_k: T.Tensor([batch + 1], "int32"), cu_seqlens_k: T.Tensor([batch + 1], "int32"),
s_aux: T.Tensor([heads], "float32"), s_aux: T.Tensor([heads], "float32"),
BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / block_N)], "int32"), BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / block_N)], "int32"),
Output: T.Tensor([batch, heads, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype),
S: T.Tensor(shape_s, dtype), S: T.Tensor(shape_s, dtype),
): ):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype) Q_shared = T.alloc_shared([block_H, dim], dtype)
...@@ -91,7 +90,7 @@ def flashattn(batch, ...@@ -91,7 +90,7 @@ def flashattn(batch,
cur_end_k = cu_seqlens_k[bid + 1] cur_end_k = cu_seqlens_k[bid + 1]
cur_seqlen_k = cur_end_k - cur_start_k cur_seqlen_k = cur_end_k - cur_start_k
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -99,15 +98,12 @@ def flashattn(batch, ...@@ -99,15 +98,12 @@ def flashattn(batch,
# loop_range = T.ceildiv((seqlen_kv // num_split), block_N) # loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + ( k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size
k * block_N) % page_block_size T.copy(K[cur_start_k + k_start : cur_start_k + k_start + block_N, cur_kv_head, :], K_shared)
T.copy(K[cur_start_k + k_start:cur_start_k + k_start + block_N, cur_kv_head, :],
K_shared)
T.clear(acc_s) T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype))
-T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
...@@ -127,14 +123,12 @@ def flashattn(batch, ...@@ -127,14 +123,12 @@ def flashattn(batch,
T.copy(acc_s, acc_s_cast) T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + ( v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size
k * block_N) % page_block_size T.copy(V[cur_start_k + v_start : cur_start_k + v_start + block_N, cur_kv_head, :], V_shared)
T.copy(V[cur_start_k + v_start:cur_start_k + v_start + block_N, cur_kv_head, :],
V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_sink: if has_sink:
T.copy(s_aux[hid * valid_block_H:hid * valid_block_H + block_H], s_aux_shared) T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared)
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
logsum[i] += s_aux_shared[i] logsum[i] += s_aux_shared[i]
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
...@@ -144,20 +138,19 @@ def flashattn(batch, ...@@ -144,20 +138,19 @@ def flashattn(batch,
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(acc_o[:valid_block_H, :], O_shared) T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
T.copy(S_shared[:valid_block_H, :], S[bid, T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
hid * valid_block_H:(hid + 1) * valid_block_H, :])
@T.prim_func @T.prim_func
def flashattn_gqa_decode_no_split( def flashattn_gqa_decode_no_split(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
cu_seqlens_k: T.Tensor([batch + 1], "int32"), cu_seqlens_k: T.Tensor([batch + 1], "int32"),
s_aux: T.Tensor([heads], "float32"), s_aux: T.Tensor([heads], "float32"),
BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], "int32"), BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], "int32"),
Output: T.Tensor(shape_o, dtype), Output: T.Tensor(shape_o, dtype),
S: T.Tensor(shape_s, dtype), S: T.Tensor(shape_s, dtype),
): ):
flash_attn(Q, K, V, cu_seqlens_k, s_aux, BLOCK_TABLE, Output, S) flash_attn(Q, K, V, cu_seqlens_k, s_aux, BLOCK_TABLE, Output, S)
...@@ -195,9 +188,7 @@ def flash_attn_with_attn_pool_decode_tilelang( ...@@ -195,9 +188,7 @@ def flash_attn_with_attn_pool_decode_tilelang(
gqa_group_size = q_h // k_h gqa_group_size = q_h // k_h
O_tl = torch.zeros_like(Q) O_tl = torch.zeros_like(Q)
S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device)
dtype=Q.dtype,
device=Q.device)
O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux, block_table) O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux, block_table)
if use_per_kv_head_sparse_index: if use_per_kv_head_sparse_index:
...@@ -223,15 +214,15 @@ def test_equal_seqlen_decode_main(args): ...@@ -223,15 +214,15 @@ def test_equal_seqlen_decode_main(args):
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
# For decode, query is just 1 token per batch # For decode, query is just 1 token per batch
q = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size) softmax_scale = 1.0 / math.sqrt(head_size)
# Generate sink values if needed # Generate sink values if needed
sink = None sink = None
if args.test_sink: if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}") print(f"Using sink attention with sink values: {sink}")
# Convert to varlen format for K, V # Convert to varlen format for K, V
...@@ -239,8 +230,7 @@ def test_equal_seqlen_decode_main(args): ...@@ -239,8 +230,7 @@ def test_equal_seqlen_decode_main(args):
v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous() v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous()
# Generate cumulative sequence lengths # Generate cumulative sequence lengths
cu_seqlens_k = torch.arange( cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32)
0, (batch_size + 1) * k_seqlen, k_seqlen, device='cuda', dtype=torch.int32)
max_seqlen_k = k_seqlen max_seqlen_k = k_seqlen
print(f"q shape: {q.shape}") print(f"q shape: {q.shape}")
...@@ -250,11 +240,9 @@ def test_equal_seqlen_decode_main(args): ...@@ -250,11 +240,9 @@ def test_equal_seqlen_decode_main(args):
num_tokens, q_h, head_size = q.shape num_tokens, q_h, head_size = q.shape
batch = cu_seqlens_k.size(0) - 1 batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1) k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
args.test_sink, page_block_size)
block_table = torch.zeros( block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32)
block_cnt = 0 block_cnt = 0
for i in range(batch): for i in range(batch):
cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
...@@ -274,7 +262,8 @@ def test_equal_seqlen_decode_main(args): ...@@ -274,7 +262,8 @@ def test_equal_seqlen_decode_main(args):
args.num_split, args.num_split,
softmax_scale, softmax_scale,
s_aux=sink, s_aux=sink,
block_size=block_size) block_size=block_size,
)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q, q,
k_varlen, k_varlen,
...@@ -290,9 +279,7 @@ def test_equal_seqlen_decode_main(args): ...@@ -290,9 +279,7 @@ def test_equal_seqlen_decode_main(args):
block_table=block_table, block_table=block_table,
) )
for i in range(batch_size): for i in range(batch_size):
S_tilelang[i, :, S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0
math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) /
block_size):] = 0
# Compute torch reference # Compute torch reference
q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size] q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size]
...@@ -301,14 +288,12 @@ def test_equal_seqlen_decode_main(args): ...@@ -301,14 +288,12 @@ def test_equal_seqlen_decode_main(args):
if sink is None: if sink is None:
# Standard scaled dot-product attention # Standard scaled dot-product attention
logits = torch.matmul(q_expanded, k_repeat.transpose( logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
attn_weights = torch.softmax(logits, dim=-1) attn_weights = torch.softmax(logits, dim=-1)
O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size] O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size]
else: else:
# s_aux attention # s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose( logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values logits_max = torch.max(logits, dim=-1, keepdim=True).values
...@@ -317,15 +302,15 @@ def test_equal_seqlen_decode_main(args): ...@@ -317,15 +302,15 @@ def test_equal_seqlen_decode_main(args):
unnormalized_scores = torch.exp(logits - logits_or_sinks_max) unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
attn_weights = unnormalized_scores / normalizer attn_weights = unnormalized_scores / normalizer
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size]
v_repeat).squeeze(2) # [batch, q_heads, head_size]
# Compute attention score pooling # Compute attention score pooling
attn_score_pooled = torch.max_pool2d( attn_score_pooled = torch.max_pool2d(
attn_weights.squeeze(2), # [b, q_heads, k_seqlen] attn_weights.squeeze(2), # [b, q_heads, k_seqlen]
kernel_size=(q_heads, block_size), kernel_size=(q_heads, block_size),
stride=(q_heads, block_size), stride=(q_heads, block_size),
ceil_mode=True).to(torch.float16) ceil_mode=True,
).to(torch.float16)
print("S_tilelang", S_tilelang) print("S_tilelang", S_tilelang)
print("attn_score_pooled", attn_score_pooled) print("attn_score_pooled", attn_score_pooled)
...@@ -339,15 +324,10 @@ def test_equal_seqlen_decode_main(args): ...@@ -339,15 +324,10 @@ def test_equal_seqlen_decode_main(args):
print(f"Max difference in S: {max_diff_s.item()}") print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}") print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}") print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}")
assert torch.allclose( assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose( assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
assert torch.allclose(
O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
assert torch.allclose(
S_tilelang, attn_score_pooled, atol=1e-2,
rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
print("✅ All tests passed!") print("✅ All tests passed!")
...@@ -368,7 +348,7 @@ def test_varlen_decode_main(args): ...@@ -368,7 +348,7 @@ def test_varlen_decode_main(args):
# Generate sink values if needed # Generate sink values if needed
sink = None sink = None
if args.test_sink: if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}") print(f"Using sink attention with sink values: {sink}")
# Generate variable length k sequences # Generate variable length k sequences
...@@ -376,7 +356,7 @@ def test_varlen_decode_main(args): ...@@ -376,7 +356,7 @@ def test_varlen_decode_main(args):
print(f"k_seqlens: {k_seqlens}") print(f"k_seqlens: {k_seqlens}")
# Generate cumulative sequence lengths for k # Generate cumulative sequence lengths for k
cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
total_k_tokens = 0 total_k_tokens = 0
for i in range(batch_size): for i in range(batch_size):
cu_seqlens_k[i] = total_k_tokens cu_seqlens_k[i] = total_k_tokens
...@@ -386,9 +366,9 @@ def test_varlen_decode_main(args): ...@@ -386,9 +366,9 @@ def test_varlen_decode_main(args):
print(f"cu_seqlens_k: {cu_seqlens_k}") print(f"cu_seqlens_k: {cu_seqlens_k}")
# Generate tensors - Q is [batch_size, q_heads, head_size] for decode # Generate tensors - Q is [batch_size, q_heads, head_size] for decode
q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size) softmax_scale = 1.0 / math.sqrt(head_size)
max_seqlen_k = int(k_seqlens.max()) max_seqlen_k = int(k_seqlens.max())
...@@ -401,11 +381,9 @@ def test_varlen_decode_main(args): ...@@ -401,11 +381,9 @@ def test_varlen_decode_main(args):
num_tokens, q_h, head_size = q_decode.shape num_tokens, q_h, head_size = q_decode.shape
batch = cu_seqlens_k.size(0) - 1 batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1) k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
args.test_sink, page_block_size)
block_table = torch.zeros( block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32)
block_cnt = 0 block_cnt = 0
for i in range(batch): for i in range(batch):
cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
...@@ -425,7 +403,8 @@ def test_varlen_decode_main(args): ...@@ -425,7 +403,8 @@ def test_varlen_decode_main(args):
args.num_split, args.num_split,
softmax_scale, softmax_scale,
s_aux=sink, s_aux=sink,
block_size=block_size) block_size=block_size,
)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q_decode, q_decode,
k_varlen, k_varlen,
...@@ -441,9 +420,7 @@ def test_varlen_decode_main(args): ...@@ -441,9 +420,7 @@ def test_varlen_decode_main(args):
block_table=block_table, block_table=block_table,
) )
for i in range(batch_size): for i in range(batch_size):
S_tilelang[i, :, S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0
math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) /
block_size):] = 0
# Create torch reference - pad tensors for comparison # Create torch reference - pad tensors for comparison
k_padded_list = [] k_padded_list = []
...@@ -457,8 +434,8 @@ def test_varlen_decode_main(args): ...@@ -457,8 +434,8 @@ def test_varlen_decode_main(args):
k_end = cu_seqlens_k[i + 1] k_end = cu_seqlens_k[i + 1]
# Pad to max_seqlen_k # Pad to max_seqlen_k
k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
k_padded[:actual_k_len] = k_varlen[k_start:k_end] k_padded[:actual_k_len] = k_varlen[k_start:k_end]
v_padded[:actual_k_len] = v_varlen[k_start:k_end] v_padded[:actual_k_len] = v_varlen[k_start:k_end]
...@@ -467,10 +444,8 @@ def test_varlen_decode_main(args): ...@@ -467,10 +444,8 @@ def test_varlen_decode_main(args):
v_padded_list.append(v_padded) v_padded_list.append(v_padded)
# Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
k_padded_batched = torch.stack( k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
v_padded_batched = torch.stack(
v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
# Expand q to match kv heads: [b, q_heads, 1, head_size] # Expand q to match kv heads: [b, q_heads, 1, head_size]
q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size]
...@@ -480,20 +455,17 @@ def test_varlen_decode_main(args): ...@@ -480,20 +455,17 @@ def test_varlen_decode_main(args):
print(f"v_padded_batched shape: {v_padded_batched.shape}") print(f"v_padded_batched shape: {v_padded_batched.shape}")
# Compute torch reference # Compute torch reference
k_repeat = repeat_kv(k_padded_batched, k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
v_repeat = repeat_kv(v_padded_batched,
q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
if sink is None: if sink is None:
# Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
attn_score = torch.matmul(q_expanded, k_repeat.transpose( attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
# Apply sequence length masking # Apply sequence length masking
for i in range(batch_size): for i in range(batch_size):
actual_k_len = k_seqlens[i] actual_k_len = k_seqlens[i]
attn_score[i, :, :, actual_k_len:] = float('-inf') attn_score[i, :, :, actual_k_len:] = float("-inf")
attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen]
...@@ -506,13 +478,12 @@ def test_varlen_decode_main(args): ...@@ -506,13 +478,12 @@ def test_varlen_decode_main(args):
O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size]
else: else:
# s_aux attention # s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose( logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
# Apply sequence length masking # Apply sequence length masking
for i in range(batch_size): for i in range(batch_size):
actual_k_len = k_seqlens[i] actual_k_len = k_seqlens[i]
logits[i, :, :, actual_k_len:] = float('-inf') logits[i, :, :, actual_k_len:] = float("-inf")
sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values logits_max = torch.max(logits, dim=-1, keepdim=True).values
...@@ -528,8 +499,7 @@ def test_varlen_decode_main(args): ...@@ -528,8 +499,7 @@ def test_varlen_decode_main(args):
attn_weights[i, :, :, actual_k_len:] = 0.0 attn_weights[i, :, :, actual_k_len:] = 0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size]
v_repeat) # [b, q_heads, 1, head_size]
O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] O_torch = O_torch.squeeze(2) # [b, q_heads, head_size]
...@@ -538,7 +508,8 @@ def test_varlen_decode_main(args): ...@@ -538,7 +508,8 @@ def test_varlen_decode_main(args):
attn_weights.squeeze(2), # [b, q_heads, max_seqlen] attn_weights.squeeze(2), # [b, q_heads, max_seqlen]
kernel_size=(q_heads, block_size), kernel_size=(q_heads, block_size),
stride=(q_heads, block_size), stride=(q_heads, block_size),
ceil_mode=True).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] ceil_mode=True,
).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)]
print(f"O_triton shape: {O_triton.shape}") print(f"O_triton shape: {O_triton.shape}")
print(f"O_tilelang shape: {O_tilelang.shape}") print(f"O_tilelang shape: {O_tilelang.shape}")
...@@ -554,22 +525,16 @@ def test_varlen_decode_main(args): ...@@ -554,22 +525,16 @@ def test_varlen_decode_main(args):
print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}")
max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
max_diff_s_tl = torch.max( max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
torch.abs(S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
print(f"Max difference in S: {max_diff_s.item()}") print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}")
assert torch.allclose( assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose( assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}"
S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), (
assert torch.allclose( f"Score mismatch: {max_diff_s_tl.item()}"
O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" )
assert torch.allclose(
S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)],
attn_score_pooled,
atol=1e-2,
rtol=1e-2), f"Score mismatch: {max_diff_s_tl.item()}"
print("✅ All tests passed!") print("✅ All tests passed!")
...@@ -605,7 +570,7 @@ def speed_benchmark_decode_comparison(args): ...@@ -605,7 +570,7 @@ def speed_benchmark_decode_comparison(args):
k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int) k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int)
# Generate cumulative sequence lengths for k # Generate cumulative sequence lengths for k
cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
total_k_tokens = 0 total_k_tokens = 0
for i in range(batch_size): for i in range(batch_size):
cu_seqlens_k[i] = total_k_tokens cu_seqlens_k[i] = total_k_tokens
...@@ -613,9 +578,9 @@ def speed_benchmark_decode_comparison(args): ...@@ -613,9 +578,9 @@ def speed_benchmark_decode_comparison(args):
cu_seqlens_k[batch_size] = total_k_tokens cu_seqlens_k[batch_size] = total_k_tokens
# Generate tensors # Generate tensors
q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size) softmax_scale = 1.0 / math.sqrt(head_size)
max_seqlen_k = int(k_seqlens.max()) max_seqlen_k = int(k_seqlens.max())
...@@ -623,7 +588,7 @@ def speed_benchmark_decode_comparison(args): ...@@ -623,7 +588,7 @@ def speed_benchmark_decode_comparison(args):
# Generate sink values if needed # Generate sink values if needed
sink = None sink = None
if args.test_sink: if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(" Using sink attention with sink values") print(" Using sink attention with sink values")
print("Setup complete:") print("Setup complete:")
...@@ -636,11 +601,9 @@ def speed_benchmark_decode_comparison(args): ...@@ -636,11 +601,9 @@ def speed_benchmark_decode_comparison(args):
num_tokens, q_h, head_size = q_decode.shape num_tokens, q_h, head_size = q_decode.shape
batch = cu_seqlens_k.size(0) - 1 batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1) k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
args.test_sink, page_block_size)
block_table = torch.zeros( block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32)
block_cnt = 0 block_cnt = 0
for i in range(batch): for i in range(batch):
cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item() cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
...@@ -671,36 +634,41 @@ def speed_benchmark_decode_comparison(args): ...@@ -671,36 +634,41 @@ def speed_benchmark_decode_comparison(args):
# Benchmark # Benchmark
print("⚡ Benchmarking Triton kernel (100 iterations)...") print("⚡ Benchmarking Triton kernel (100 iterations)...")
triton_time = do_bench(flash_attn_with_attn_pool_decode, q_decode, k_varlen, v_varlen, triton_time = do_bench(
cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale, sink, flash_attn_with_attn_pool_decode,
block_size) q_decode,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
args.k_seqlen,
1,
softmax_scale,
sink,
block_size,
)
print(f"Average decode kernel time Triton: {triton_time:.3f} ms") print(f"Average decode kernel time Triton: {triton_time:.3f} ms")
print(f"Speedup: {(triton_time / tilelang_time):.3f}") print(f"Speedup: {(triton_time / tilelang_time):.3f}")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Flash Attention Decode with Attention Pooling') parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling")
parser.add_argument('--batch_size', type=int, default=1, help='Batch size') parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument('--q_heads', type=int, default=32, help='Number of query heads') parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads")
parser.add_argument('--kv_heads', type=int, default=8, help='Number of key-value heads') parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads")
parser.add_argument('--k_seqlen', type=int, default=8192, help='Key sequence length') parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length")
parser.add_argument( parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension")
'--head_size', type=int, default=128, choices=[64, 128, 256], help='Head dimension') parser.add_argument("--block_size", type=int, default=128, help="Block size for computation")
parser.add_argument('--block_size', type=int, default=128, help='Block size for computation') parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"], help="Data type")
parser.add_argument( parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths")
'--dtype', type=str, default='bfloat16', choices=['float16', 'bfloat16'], help='Data type') parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism")
parser.add_argument( parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark")
'--test_varlen', action='store_true', help='Test with truly variable sequence lengths') parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits")
parser.add_argument( parser.add_argument("--page_block_size", type=int, default=128, help="Page block size")
'--test_sink', action='store_true', help='Test with sink attention mechanism')
parser.add_argument('--benchmark', action='store_true', help='Run speed benchmark')
parser.add_argument(
'--num_split', type=int, default=1, choices=[1, 16], help='Number of splits')
parser.add_argument('--page_block_size', type=int, default=128, help='Page block size')
args = parser.parse_args() args = parser.parse_args()
args.test_sink = True args.test_sink = True
args.test_varlen = True args.test_varlen = True
args.dtype = 'float16' args.dtype = "float16"
args.num_split = 1 args.num_split = 1
if args.benchmark: if args.benchmark:
......
...@@ -10,7 +10,7 @@ num_split = 4 ...@@ -10,7 +10,7 @@ num_split = 4
@tilelang.jit(out_idx=[5]) @tilelang.jit(out_idx=[5])
def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N): def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape_q = [batch, seqlen_q, heads, dim] shape_q = [batch, seqlen_q, heads, dim]
shape_kv = [batch, seqlen_kv, heads, dim] shape_kv = [batch, seqlen_kv, heads, dim]
part_shape = [batch, seqlen_q, heads, num_split, dim] part_shape = [batch, seqlen_q, heads, num_split, dim]
...@@ -29,14 +29,11 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -29,14 +29,11 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
bid: T.int32, bid: T.int32,
sid: T.int32, sid: T.int32,
): ):
T.copy( T.copy(K[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], K_shared)
K[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid +
(k + 1) * block_N, hid, :], K_shared)
# TODO: Handle causal split case # TODO: Handle causal split case
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
-T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...@@ -52,20 +49,18 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -52,20 +49,18 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
bid: T.int32, bid: T.int32,
sid: T.int32, sid: T.int32,
): ):
T.copy( T.copy(V[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], V_shared)
V[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid +
(k + 1) * block_N, hid, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
def Softmax( def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype), scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype), scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype), logsum: T.FragmentBuffer([block_M], accum_dtype),
): ):
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -91,23 +86,21 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -91,23 +86,21 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
@T.macro @T.macro
def Rescale( def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
): ):
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
@T.macro @T.macro
def flash_attn_split( def flash_attn_split(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_kv, dtype), K: T.Tensor(shape_kv, dtype),
V: T.Tensor(shape_kv, dtype), V: T.Tensor(shape_kv, dtype),
glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
Output_partial: T.Tensor(part_shape, dtype), Output_partial: T.Tensor(part_shape, dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128) as (bx, by, bz):
T.ceildiv(seqlen_q, block_M), heads * batch, num_split,
threads=128) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype)
...@@ -128,39 +121,36 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -128,39 +121,36 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
# NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently # NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently
# disable relevant tma copy and use SIMT as fallback for now # disable relevant tma copy and use SIMT as fallback for now
T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared, disable_tma=True) T.copy(Q[bid, mid * block_M : (mid + 1) * block_M, hid, :], Q_shared, disable_tma=True)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
# TODO: Handle causal split case # TODO: Handle causal split case
loop_range = ( loop_range = (
T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv( T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv((mid + 1) * block_M, block_N))
(mid + 1) * block_M, block_N)) if is_causal else T.ceildiv( if is_causal
(seqlen_kv // num_split), block_N)) else T.ceildiv((seqlen_kv // num_split), block_N)
)
for k in T.Pipelined(loop_range, num_stages=2): for k in T.Pipelined(loop_range, num_stages=2):
MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid) MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid) MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bid, hid, sid, mid * block_M:(mid + 1) * block_M]) T.copy(logsum, glse[bid, hid, sid, mid * block_M : (mid + 1) * block_M])
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy( T.copy(O_shared, Output_partial[bid, mid * block_M : (mid + 1) * block_M, hid, sid, :], disable_tma=True)
O_shared,
Output_partial[bid, mid * block_M:(mid + 1) * block_M, hid, sid, :],
disable_tma=True)
@T.macro @T.macro
def combine( def combine(
glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
Output_partial: T.Tensor(part_shape, dtype), Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_q, dtype), Output: T.Tensor(shape_q, dtype),
): ):
with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz): with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz):
po_local = T.alloc_fragment([block_M, dim], dtype) po_local = T.alloc_fragment([block_M, dim], dtype)
...@@ -173,20 +163,25 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -173,20 +163,25 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
lse_max_local = T.alloc_fragment([block_M], accum_dtype) lse_max_local = T.alloc_fragment([block_M], accum_dtype)
scale_local = T.alloc_fragment([block_M], accum_dtype) scale_local = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({ T.annotate_layout(
o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i), {
o_shared: tilelang.layout.make_swizzled_layout(o_shared), o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i),
po_shared: tilelang.layout.make_swizzled_layout(po_shared), o_shared: tilelang.layout.make_swizzled_layout(o_shared),
}) po_shared: tilelang.layout.make_swizzled_layout(po_shared),
}
)
T.clear(lse_logsum_local) T.clear(lse_logsum_local)
T.clear(o_accum_local) T.clear(o_accum_local)
T.copy(glse[ T.copy(
bz, glse[
by, bz,
:, by,
bx * block_M:(bx + 1) * block_M, :,
], lse_local) bx * block_M : (bx + 1) * block_M,
],
lse_local,
)
T.reduce_max(lse_local, lse_max_local, dim=0, clear=False) T.reduce_max(lse_local, lse_max_local, dim=0, clear=False)
for k in T.Pipelined(num_split): for k in T.Pipelined(num_split):
T.copy(lse_local[k, :], lse_local_split) T.copy(lse_local[k, :], lse_local_split)
...@@ -195,10 +190,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -195,10 +190,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i] lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i]
for k in T.Pipelined(num_split, num_stages=2): for k in T.Pipelined(num_split, num_stages=2):
T.copy( T.copy(Output_partial[bz, bx * block_M : (bx + 1) * block_M, by, k, :], po_shared, disable_tma=True)
Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :],
po_shared,
disable_tma=True)
T.copy(po_shared, po_local) T.copy(po_shared, po_local)
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
lse_local_split[i] = lse_local[k, i] lse_local_split[i] = lse_local[k, i]
...@@ -207,16 +199,16 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -207,16 +199,16 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
o_accum_local[i, j] += po_local[i, j] * scale_local[i] o_accum_local[i, j] += po_local[i, j] * scale_local[i]
T.copy(o_accum_local, o_shared) T.copy(o_accum_local, o_shared)
T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :], disable_tma=True) T.copy(o_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :], disable_tma=True)
@T.prim_func @T.prim_func
def flashattn_mha_inference( def flashattn_mha_inference(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_kv, dtype), K: T.Tensor(shape_kv, dtype),
V: T.Tensor(shape_kv, dtype), V: T.Tensor(shape_kv, dtype),
glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim] Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim]
Output: T.Tensor(shape_q, dtype), Output: T.Tensor(shape_q, dtype),
): ):
flash_attn_split(Q, K, V, glse, Output_partial) flash_attn_split(Q, K, V, glse, Output_partial)
combine(glse, Output_partial, Output) combine(glse, Output_partial, Output)
...@@ -227,10 +219,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -227,10 +219,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
def ref_program(Q, K, V, glse, Output_partial, causal): def ref_program(Q, K, V, glse, Output_partial, causal):
assert causal is False assert causal is False
dim = Q.size(-1) dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
attention_weights = F.softmax(scores, dim=-1) attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output return output
...@@ -258,7 +250,7 @@ def flash_split_ref(Q, K, V, causal): ...@@ -258,7 +250,7 @@ def flash_split_ref(Q, K, V, causal):
block_N = 128 block_N = 128
seqlen_kv = K.size(1) seqlen_kv = K.size(1)
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float) acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float)
acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16) acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16)
acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float) acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float)
...@@ -275,14 +267,15 @@ def flash_split_ref(Q, K, V, causal): ...@@ -275,14 +267,15 @@ def flash_split_ref(Q, K, V, causal):
for ks in range(num_split): for ks in range(num_split):
acc_o.fill_(0) acc_o.fill_(0)
logsum.fill_(0) logsum.fill_(0)
scores_max.fill_(float('-inf')) scores_max.fill_(float("-inf"))
scores_max_prev.fill_(float('-inf')) scores_max_prev.fill_(float("-inf"))
for i in range(int((seqlen_kv // num_split) / block_N)): for i in range(int((seqlen_kv // num_split) / block_N)):
acc_s.fill_(0) acc_s.fill_(0)
acc_s = torch.einsum('bqhd,bkhd->bhqk', Q_, acc_s = torch.einsum(
K[:, (seqlen_kv // num_split) * ks + "bqhd,bkhd->bhqk",
i * block_N:(seqlen_kv // num_split) * ks + Q_,
(i + 1) * block_N, :, :]) # [batch, seqlen, nheads, block_N] K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
) # [batch, seqlen, nheads, block_N]
scores_max_prev = scores_max scores_max_prev = scores_max
scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM] scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM]
scores_scale = torch.exp2(scores_max_prev - scores_max) scores_scale = torch.exp2(scores_max_prev - scores_max)
...@@ -290,9 +283,10 @@ def flash_split_ref(Q, K, V, causal): ...@@ -290,9 +283,10 @@ def flash_split_ref(Q, K, V, causal):
acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
acc_s_cast = acc_s.to(torch.float16) acc_s_cast = acc_s.to(torch.float16)
acc_o += torch.einsum( acc_o += torch.einsum(
'bhqk,bkhd->bqhd', acc_s_cast, "bhqk,bkhd->bqhd",
V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + acc_s_cast,
(i + 1) * block_N, :, :]) V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
)
scores_sum = acc_s.sum(dim=-1, keepdim=False) scores_sum = acc_s.sum(dim=-1, keepdim=False)
logsum = logsum * scores_scale + scores_sum logsum = logsum * scores_scale + scores_sum
acc_o /= logsum[:, :, :, None].transpose(1, 2) acc_o /= logsum[:, :, :, None].transpose(1, 2)
...@@ -300,8 +294,7 @@ def flash_split_ref(Q, K, V, causal): ...@@ -300,8 +294,7 @@ def flash_split_ref(Q, K, V, causal):
gacc_o[ks, :, :, :, :] = acc_o gacc_o[ks, :, :, :, :] = acc_o
glogsum[ks, :, :, :] = logsum glogsum[ks, :, :, :] = logsum
return glogsum.to(torch.float16).permute(1, 2, 0, return glogsum.to(torch.float16).permute(1, 2, 0, 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)
3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)
def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False): def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False):
......
...@@ -9,17 +9,18 @@ from example_fusedmoe_torch import * ...@@ -9,17 +9,18 @@ from example_fusedmoe_torch import *
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) @tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def moe_forward_tilelang_shared(d_hidden, def moe_forward_tilelang_shared(
d_expert, d_hidden,
n_shared_experts, d_expert,
dtype, n_shared_experts,
num_tokens, dtype,
block_token=128, num_tokens,
block_dhidden=128, block_token=128,
block_dexpert=128, block_dhidden=128,
threads=256, block_dexpert=128,
num_stages=1): threads=256,
num_stages=1,
):
scale = 1.44269504 # log2(e) scale = 1.44269504 # log2(e)
# Parameters # Parameters
...@@ -36,17 +37,15 @@ def moe_forward_tilelang_shared(d_hidden, ...@@ -36,17 +37,15 @@ def moe_forward_tilelang_shared(d_hidden,
@T.prim_func @T.prim_func
def kernel_shared( def kernel_shared(
input: T.Tensor(input_shape, dtype), # type: ignore input: T.Tensor(input_shape, dtype), # type: ignore
shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore
shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore
shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore
up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore
output: T.Tensor(input_shape, dtype), # type: ignore output: T.Tensor(input_shape, dtype), # type: ignore
): ):
# Step 1: Compute gate and up logits # Step 1: Compute gate and up logits
with T.Kernel( with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by):
T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert),
threads=threads) as (bx, by):
# Split the block to shared experts and routed experts # Split the block to shared experts and routed experts
input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype) input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype)
W_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype) W_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
...@@ -70,16 +69,13 @@ def moe_forward_tilelang_shared(d_hidden, ...@@ -70,16 +69,13 @@ def moe_forward_tilelang_shared(d_hidden,
# Fuse with SiLU and element-wise product # Fuse with SiLU and element-wise product
for i, j in T.Parallel(block_token, block_dexpert): for i, j in T.Parallel(block_token, block_dexpert):
gate_logits_local[i, j] = gate_logits_local[i, j] * ( gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j] up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]
T.copy(up_logits_local, up_logits[bx * block_token, by * block_dexpert]) T.copy(up_logits_local, up_logits[bx * block_token, by * block_dexpert])
# Step 2: Compute down logits # Step 2: Compute down logits
with T.Kernel( with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by):
T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden),
threads=threads) as (bx, by):
up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype) up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype)
W_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype) W_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype)
output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_type) output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_type)
...@@ -98,20 +94,21 @@ def moe_forward_tilelang_shared(d_hidden, ...@@ -98,20 +94,21 @@ def moe_forward_tilelang_shared(d_hidden,
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) @tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def moe_forward_tilelang_routed(d_hidden, def moe_forward_tilelang_routed(
d_expert, d_hidden,
n_routed_experts, d_expert,
dtype, n_routed_experts,
group_sum, dtype,
group_count, group_sum,
block_token=128, group_count,
block_dhidden=128, block_token=128,
block_dexpert=128, block_dhidden=128,
threads=256, block_dexpert=128,
num_stages=1, threads=256,
k_pack=1, num_stages=1,
coalesced_width=None): k_pack=1,
coalesced_width=None,
):
scale = 1.44269504 # log2(e) scale = 1.44269504 # log2(e)
# Parameters # Parameters
...@@ -132,22 +129,22 @@ def moe_forward_tilelang_routed(d_hidden, ...@@ -132,22 +129,22 @@ def moe_forward_tilelang_routed(d_hidden,
routed_expert_gate_shape = (n_routed_experts, dexpert, dhidden) routed_expert_gate_shape = (n_routed_experts, dexpert, dhidden)
routed_expert_up_shape = (n_routed_experts, dexpert, dhidden) routed_expert_up_shape = (n_routed_experts, dexpert, dhidden)
routed_expert_down_shape = (n_routed_experts, dhidden, dexpert) routed_expert_down_shape = (n_routed_experts, dhidden, dexpert)
routed_expert_weights_shape = (group_sum) routed_expert_weights_shape = group_sum
group_sizes_shape = (n_routed_experts) group_sizes_shape = n_routed_experts
@T.prim_func @T.prim_func
def kernel( def kernel(
input: T.Tensor(input_shape, dtype), # type: ignore input: T.Tensor(input_shape, dtype), # type: ignore
routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore
routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore
routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore
routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore
group_sizes: T.Tensor(group_sizes_shape, "int32"), # type: ignore group_sizes: T.Tensor(group_sizes_shape, "int32"), # type: ignore
group_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore group_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore
group_padded_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore group_padded_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore
group_idx_for_bx: T.Tensor((M,), "int32"), # type: ignore group_idx_for_bx: T.Tensor((M,), "int32"), # type: ignore
up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore
output: T.Tensor(input_shape, dtype), # type: ignore output: T.Tensor(input_shape, dtype), # type: ignore
): ):
# Step 1: Compute gate and up logits # Step 1: Compute gate and up logits
with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by): with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by):
...@@ -168,48 +165,37 @@ def moe_forward_tilelang_routed(d_hidden, ...@@ -168,48 +165,37 @@ def moe_forward_tilelang_routed(d_hidden,
cur_group_idx[0] = group_idx_for_bx[bx] cur_group_idx[0] = group_idx_for_bx[bx]
cur_group_size[0] = group_sizes[cur_group_idx[0]] cur_group_size[0] = group_sizes[cur_group_idx[0]]
m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[ m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]]
cur_group_idx[0]] actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]])))
actual_rows = T.max(
0,
T.min(block_token, cur_group_size[0] -
(m_start_padded - group_padded_offsets[cur_group_idx[0]])))
T.clear(gate_logits_local) T.clear(gate_logits_local)
T.clear(up_logits_local) T.clear(up_logits_local)
for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages): for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages):
T.copy( T.copy(
input[m_start:m_start + block_token, k * block_dhidden:(k + 1) * block_dhidden], input[m_start : m_start + block_token, k * block_dhidden : (k + 1) * block_dhidden],
input_shared, input_shared,
coalesced_width=coalesced_width) coalesced_width=coalesced_width,
)
T.copy( T.copy(
routed_expert_gate[cur_group_idx[0], routed_expert_gate[
by * block_dexpert:(by + 1) * block_dexpert, cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden
k * block_dhidden:(k + 1) * block_dhidden], ],
routed_expert_gate_shared,
coalesced_width=coalesced_width)
T.gemm(
input_shared,
routed_expert_gate_shared, routed_expert_gate_shared,
gate_logits_local, coalesced_width=coalesced_width,
k_pack=k_pack, )
transpose_B=True) T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, k_pack=k_pack, transpose_B=True)
T.copy( T.copy(
routed_expert_up[cur_group_idx[0], by * block_dexpert:(by + 1) * block_dexpert, routed_expert_up[
k * block_dhidden:(k + 1) * block_dhidden], cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden
],
routed_expert_up_shared, routed_expert_up_shared,
coalesced_width=coalesced_width) coalesced_width=coalesced_width,
T.gemm( )
input_shared, T.gemm(input_shared, routed_expert_up_shared, up_logits_local, k_pack=k_pack, transpose_B=True)
routed_expert_up_shared,
up_logits_local,
k_pack=k_pack,
transpose_B=True)
for i, j in T.Parallel(block_token, block_dexpert): for i, j in T.Parallel(block_token, block_dexpert):
gate_logits_local[i, j] = gate_logits_local[i, j] * ( gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j] up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]
for i, j in T.Parallel(block_token, block_dexpert): for i, j in T.Parallel(block_token, block_dexpert):
...@@ -232,50 +218,35 @@ def moe_forward_tilelang_routed(d_hidden, ...@@ -232,50 +218,35 @@ def moe_forward_tilelang_routed(d_hidden,
cur_group_idx[0] = group_idx_for_bx[bx] cur_group_idx[0] = group_idx_for_bx[bx]
cur_group_size[0] = group_sizes[cur_group_idx[0]] cur_group_size[0] = group_sizes[cur_group_idx[0]]
m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[ m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]]
cur_group_idx[0]] actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]])))
actual_rows = T.max(
0,
T.min(block_token, cur_group_size[0] -
(m_start_padded - group_padded_offsets[cur_group_idx[0]])))
T.clear(output_local) T.clear(output_local)
for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages): for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages):
T.copy( T.copy(
up_logits[m_start:m_start + block_token, up_logits[m_start : m_start + block_token, k * block_dexpert : (k + 1) * block_dexpert],
k * block_dexpert:(k + 1) * block_dexpert],
up_logits_shared, up_logits_shared,
coalesced_width=coalesced_width) coalesced_width=coalesced_width,
)
T.copy( T.copy(
routed_expert_down[cur_group_idx[0], routed_expert_down[
by * block_dhidden:(by + 1) * block_dhidden, cur_group_idx[0], by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert
k * block_dexpert:(k + 1) * block_dexpert], ],
routed_expert_down_shared,
coalesced_width=coalesced_width)
T.gemm(
up_logits_shared,
routed_expert_down_shared, routed_expert_down_shared,
output_local, coalesced_width=coalesced_width,
k_pack=k_pack, )
transpose_B=True) T.gemm(up_logits_shared, routed_expert_down_shared, output_local, k_pack=k_pack, transpose_B=True)
for i, j in T.Parallel(block_token, block_dhidden): for i, j in T.Parallel(block_token, block_dhidden):
if i < actual_rows: if i < actual_rows:
output[m_start + i, by * block_dhidden + output[m_start + i, by * block_dhidden + j] = output_local[i, j] * routed_expert_weights[m_start + i]
j] = output_local[i, j] * routed_expert_weights[m_start + i]
return kernel return kernel
class Expert(nn.Module): class Expert(nn.Module):
def __init__(self, config: Dict, gate: torch.Tensor, up: torch.Tensor, down: torch.Tensor, d_expert: Optional[int] = None):
def __init__(self,
config: Dict,
gate: torch.Tensor,
up: torch.Tensor,
down: torch.Tensor,
d_expert: Optional[int] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.act_fn = nn.SiLU() self.act_fn = nn.SiLU()
...@@ -294,14 +265,13 @@ class Expert(nn.Module): ...@@ -294,14 +265,13 @@ class Expert(nn.Module):
class MoEGate(nn.Module): class MoEGate(nn.Module):
def __init__(self, config: Dict, weights: Dict): def __init__(self, config: Dict, weights: Dict):
super().__init__() super().__init__()
self.top_k: int = config["n_experts_per_token"] self.top_k: int = config["n_experts_per_token"]
self.num_experts: int = config["n_routed_experts"] self.num_experts: int = config["n_routed_experts"]
self.d_hidden: int = config["d_hidden"] self.d_hidden: int = config["d_hidden"]
self.W_g_weight = weights['router.weight'].t() self.W_g_weight = weights["router.weight"].t()
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
logits = x @ self.W_g_weight logits = x @ self.W_g_weight
...@@ -312,76 +282,69 @@ class MoEGate(nn.Module): ...@@ -312,76 +282,69 @@ class MoEGate(nn.Module):
class MoE(nn.Module): class MoE(nn.Module):
def __init__(
def __init__(self, self, config: Dict, shared_kernel: tilelang.JITKernel, routed_kernel: tilelang.JITKernel, weights: Dict, padding_M: int = 128
config: Dict, ):
shared_kernel: tilelang.JITKernel,
routed_kernel: tilelang.JITKernel,
weights: Dict,
padding_M: int = 128):
super().__init__() super().__init__()
self.config = config self.config = config
self.shared_kernel = shared_kernel self.shared_kernel = shared_kernel
self.routed_kernel = routed_kernel self.routed_kernel = routed_kernel
self.padding_M = padding_M self.padding_M = padding_M
self.experts = nn.ModuleList([ self.experts = nn.ModuleList(
Expert( [
config, Expert(
gate=weights[f'experts.{i}.0.weight'], config,
up=weights[f'experts.{i}.1.weight'], gate=weights[f"experts.{i}.0.weight"],
down=weights[f'experts.{i}.2.weight']) for i in range(config["n_routed_experts"]) up=weights[f"experts.{i}.1.weight"],
]) down=weights[f"experts.{i}.2.weight"],
)
for i in range(config["n_routed_experts"])
]
)
self.device = torch.device("cuda") self.device = torch.device("cuda")
self.gating_network = MoEGate(config, weights).to(self.device) self.gating_network = MoEGate(config, weights).to(self.device)
shared_expert_dim = config["d_expert"] * config["n_shared_experts"] shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
self.shared_expert = Expert( self.shared_expert = Expert(
config=config, config=config,
gate=weights['shared_experts.0.weight'], gate=weights["shared_experts.0.weight"],
up=weights['shared_experts.1.weight'], up=weights["shared_experts.1.weight"],
down=weights['shared_experts.2.weight'], down=weights["shared_experts.2.weight"],
d_expert=shared_expert_dim).to(self.device) d_expert=shared_expert_dim,
).to(self.device)
self.expert_cache = torch.zeros( self.expert_cache = torch.zeros(
(config["batch_size"] * config["seq_len"], config["d_hidden"]), (config["batch_size"] * config["seq_len"], config["d_hidden"]), dtype=torch.float16, device=self.device
dtype=torch.float16, )
device=self.device) self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], dim=0)
self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], dim=0)
dim=0) self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts], dim=0)
self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts],
dim=0)
self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts],
dim=0)
self.stacked_expert_tokens = torch.empty( self.stacked_expert_tokens = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]),
self.config["d_hidden"]),
dtype=torch.float16, dtype=torch.float16,
device=self.device) device=self.device,
)
self.stacked_expert_weights = torch.empty( self.stacked_expert_weights = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.float16, device=self.device
dtype=torch.float16, )
device=self.device)
self.stacked_expert_tokens_idxs = torch.empty( self.stacked_expert_tokens_idxs = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.int64, device=self.device
dtype=torch.int64, )
device=self.device)
self.up_logits_shared = torch.empty( self.up_logits_shared = torch.empty(
(config["batch_size"] * config["seq_len"], self.config["d_expert"]), (config["batch_size"] * config["seq_len"], self.config["d_expert"]), dtype=torch.float16, device=self.device
dtype=torch.float16, )
device=self.device)
self.expert_output_shared = torch.empty( self.expert_output_shared = torch.empty(
(config["batch_size"] * config["seq_len"], self.config["d_hidden"]), (config["batch_size"] * config["seq_len"], self.config["d_hidden"]), dtype=torch.float16, device=self.device
dtype=torch.float16, )
device=self.device)
self.up_logits_routed = torch.empty( self.up_logits_routed = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_expert"]),
self.config["d_expert"]),
dtype=torch.float16, dtype=torch.float16,
device=self.device) device=self.device,
)
self.expert_output_routed = torch.empty( self.expert_output_routed = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]),
self.config["d_hidden"]),
dtype=torch.float16, dtype=torch.float16,
device=self.device) device=self.device,
)
@torch.no_grad() @torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
...@@ -413,22 +376,20 @@ class MoE(nn.Module): ...@@ -413,22 +376,20 @@ class MoE(nn.Module):
self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens
self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs
self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[ self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[idxs[start_idx:end_idx]]
idxs[start_idx:end_idx]]
group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device) group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device)
group_offset = torch.tensor( group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=self.device)
tokens_per_expert - counts, dtype=torch.int32, device=self.device)
group_padded_offsets = [0 for _ in range(len(group_sizes))] group_padded_offsets = [0 for _ in range(len(group_sizes))]
for i in range(1, len(group_sizes)): for i in range(1, len(group_sizes)):
group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil( group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / self.padding_M) * self.padding_M
(counts[i - 1] + 1) / self.padding_M) * self.padding_M
block_token = 128 block_token = 128
M = math.ceil( M = (
self.config["batch_size"] * self.config["seq_len"] * math.ceil(self.config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token)
self.config["n_experts_per_token"] / block_token) + self.config["n_routed_experts"] + self.config["n_routed_experts"]
)
group_idx_for_bx = [0 for _ in range(M)] group_idx_for_bx = [0 for _ in range(M)]
for bx in range(M): for bx in range(M):
...@@ -437,8 +398,7 @@ class MoE(nn.Module): ...@@ -437,8 +398,7 @@ class MoE(nn.Module):
if m_start_padded >= group_padded_offsets[i]: if m_start_padded >= group_padded_offsets[i]:
group_idx_for_bx[bx] = i group_idx_for_bx[bx] = i
group_padded_offsets = torch.tensor( group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=self.device)
group_padded_offsets, dtype=torch.int32, device=self.device)
group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device) group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device)
# Multi-stream execution # Multi-stream execution
...@@ -448,11 +408,19 @@ class MoE(nn.Module): ...@@ -448,11 +408,19 @@ class MoE(nn.Module):
with torch.cuda.stream(routed_stream): with torch.cuda.stream(routed_stream):
# Tilelang version: Grouped GEMM # Tilelang version: Grouped GEMM
self.routed_kernel(self.stacked_expert_tokens, self.stacked_expert_w_gate, self.routed_kernel(
self.stacked_expert_w_up, self.stacked_expert_w_down, self.stacked_expert_tokens,
self.stacked_expert_weights, group_sizes, group_offset, self.stacked_expert_w_gate,
group_padded_offsets, group_idx_for_bx, self.up_logits_routed, self.stacked_expert_w_up,
self.expert_output_routed) self.stacked_expert_w_down,
self.stacked_expert_weights,
group_sizes,
group_offset,
group_padded_offsets,
group_idx_for_bx,
self.up_logits_routed,
self.expert_output_routed,
)
# Scatter reduce # Scatter reduce
self.expert_cache = torch.scatter_reduce( self.expert_cache = torch.scatter_reduce(
...@@ -460,14 +428,19 @@ class MoE(nn.Module): ...@@ -460,14 +428,19 @@ class MoE(nn.Module):
0, 0,
self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]), self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]),
self.expert_output_routed, self.expert_output_routed,
reduce='sum') reduce="sum",
)
routed_output = self.expert_cache.view(*orig_shape) routed_output = self.expert_cache.view(*orig_shape)
with torch.cuda.stream(shared_stream): with torch.cuda.stream(shared_stream):
self.shared_kernel(
self.shared_kernel(x_flat, self.shared_expert.W_gate_weight, x_flat,
self.shared_expert.W_up_weight, self.shared_expert.W_down_weight, self.shared_expert.W_gate_weight,
self.up_logits_shared, self.expert_output_shared) self.shared_expert.W_up_weight,
self.shared_expert.W_down_weight,
self.up_logits_shared,
self.expert_output_shared,
)
shared_output = self.expert_output_shared.view(*orig_shape) shared_output = self.expert_output_shared.view(*orig_shape)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -498,7 +471,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: ...@@ -498,7 +471,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
config["d_expert"], config["d_expert"],
config["n_shared_experts"], config["n_shared_experts"],
dtype=dtype_str, dtype=dtype_str,
num_tokens=config["batch_size"] * config["seq_len"]) num_tokens=config["batch_size"] * config["seq_len"],
)
routed_kernel = moe_forward_tilelang_routed( routed_kernel = moe_forward_tilelang_routed(
config["d_hidden"], config["d_hidden"],
config["d_expert"], config["d_expert"],
...@@ -512,7 +486,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: ...@@ -512,7 +486,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
threads=256, threads=256,
num_stages=1, num_stages=1,
k_pack=1, k_pack=1,
coalesced_width=2) coalesced_width=2,
)
moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128) moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128)
...@@ -521,13 +496,7 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: ...@@ -521,13 +496,7 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
return output return output
def main(d_hidden=7168, def main(d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=8192):
d_expert=2048,
n_routed_experts=8,
n_shared_experts=1,
n_experts_per_token=4,
batch_size=1,
seq_len=8192):
config = { config = {
"dhidden": d_hidden, "dhidden": d_hidden,
"dexpert": d_expert, "dexpert": d_expert,
...@@ -536,7 +505,7 @@ def main(d_hidden=7168, ...@@ -536,7 +505,7 @@ def main(d_hidden=7168,
"nexpertspertoken": n_experts_per_token, "nexpertspertoken": n_experts_per_token,
"bs": batch_size, "bs": batch_size,
"seqlen": seq_len, "seqlen": seq_len,
"seed": 81394 "seed": 81394,
} }
data = generate_input(**config) data = generate_input(**config)
......
...@@ -6,7 +6,6 @@ from typing import Dict, Tuple, Optional ...@@ -6,7 +6,6 @@ from typing import Dict, Tuple, Optional
# Reference code in PyTorch # Reference code in PyTorch
class ExpertTorch(nn.Module): class ExpertTorch(nn.Module):
def __init__(self, config: Dict, d_expert: Optional[int] = None): def __init__(self, config: Dict, d_expert: Optional[int] = None):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -25,7 +24,6 @@ class ExpertTorch(nn.Module): ...@@ -25,7 +24,6 @@ class ExpertTorch(nn.Module):
class MoEGateTorch(nn.Module): class MoEGateTorch(nn.Module):
def __init__(self, config: Dict): def __init__(self, config: Dict):
super().__init__() super().__init__()
self.top_k: int = config["n_experts_per_token"] self.top_k: int = config["n_experts_per_token"]
...@@ -43,12 +41,10 @@ class MoEGateTorch(nn.Module): ...@@ -43,12 +41,10 @@ class MoEGateTorch(nn.Module):
class MoETorch(nn.Module): class MoETorch(nn.Module):
def __init__(self, config: Dict): def __init__(self, config: Dict):
super().__init__() super().__init__()
self.config = config self.config = config
self.experts = nn.ModuleList( self.experts = nn.ModuleList([ExpertTorch(config) for _ in range(config["n_routed_experts"])])
[ExpertTorch(config) for _ in range(config["n_routed_experts"])])
self.gating_network = MoEGateTorch(config) self.gating_network = MoEGateTorch(config)
shared_expert_dim = config["d_expert"] * config["n_shared_experts"] shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
self.shared_expert = ExpertTorch(config=config, d_expert=shared_expert_dim) self.shared_expert = ExpertTorch(config=config, d_expert=shared_expert_dim)
...@@ -67,8 +63,7 @@ class MoETorch(nn.Module): ...@@ -67,8 +63,7 @@ class MoETorch(nn.Module):
return routed_output + shared_output return routed_output + shared_output
@torch.no_grad() @torch.no_grad()
def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, flat_expert_weights: torch.Tensor) -> torch.Tensor:
flat_expert_weights: torch.Tensor) -> torch.Tensor:
expert_cache = torch.zeros_like(x) expert_cache = torch.zeros_like(x)
# test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"])) # test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
# test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"])) # test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
...@@ -91,8 +86,7 @@ class MoETorch(nn.Module): ...@@ -91,8 +86,7 @@ class MoETorch(nn.Module):
expert_out = expert(expert_tokens) expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_reduce_( expert_cache.scatter_reduce_(0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum")
0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
return expert_cache return expert_cache
...@@ -116,21 +110,21 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: ...@@ -116,21 +110,21 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
moe = MoETorch(config) moe = MoETorch(config)
# Fill in the given weights of the model # Fill in the given weights of the model
moe.gating_network.W_g.weight = nn.Parameter(weights['router.weight']) moe.gating_network.W_g.weight = nn.Parameter(weights["router.weight"])
for i in range(num_experts): for i in range(num_experts):
gate_proj_weight = weights[f'experts.{i}.0.weight'] gate_proj_weight = weights[f"experts.{i}.0.weight"]
up_proj_weight = weights[f'experts.{i}.1.weight'] up_proj_weight = weights[f"experts.{i}.1.weight"]
down_proj_weight = weights[f'experts.{i}.2.weight'] down_proj_weight = weights[f"experts.{i}.2.weight"]
# Transpose weights to match expected shape for nn.Linear # Transpose weights to match expected shape for nn.Linear
moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t()) moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t())
moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t()) moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t())
moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t()) moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t())
moe.shared_expert.W_gate.weight = nn.Parameter(weights['shared_experts.0.weight'].t()) moe.shared_expert.W_gate.weight = nn.Parameter(weights["shared_experts.0.weight"].t())
moe.shared_expert.W_up.weight = nn.Parameter(weights['shared_experts.1.weight'].t()) moe.shared_expert.W_up.weight = nn.Parameter(weights["shared_experts.1.weight"].t())
moe.shared_expert.W_down.weight = nn.Parameter(weights['shared_experts.2.weight'].t()) moe.shared_expert.W_down.weight = nn.Parameter(weights["shared_experts.2.weight"].t())
output = moe(input_tensor) output = moe(input_tensor)
...@@ -140,10 +134,9 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: ...@@ -140,10 +134,9 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
# Input generation for the reference code # Input generation for the reference code
def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, def generate_input(
nexpertspertoken: int, bs: int, seqlen: int, dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, nexpertspertoken: int, bs: int, seqlen: int, seed: int
seed: int) -> Tuple[torch.Tensor, Dict, Dict]: ) -> Tuple[torch.Tensor, Dict, Dict]:
# Really dumb but for now _ isn't parsing correctly. # Really dumb but for now _ isn't parsing correctly.
d_hidden = dhidden d_hidden = dhidden
d_expert = dexpert d_expert = dexpert
...@@ -163,50 +156,40 @@ def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexper ...@@ -163,50 +156,40 @@ def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexper
"seq_len": seq_len, "seq_len": seq_len,
} }
gen = torch.Generator(device='cuda') gen = torch.Generator(device="cuda")
gen.manual_seed(seed) gen.manual_seed(seed)
num_experts = n_routed_experts num_experts = n_routed_experts
expert_dim = d_expert expert_dim = d_expert
weights = {} weights = {}
input_tensor = torch.randn((batch_size, seq_len, d_hidden), input_tensor = torch.randn((batch_size, seq_len, d_hidden), device="cuda", dtype=torch.float16, generator=gen).contiguous()
device='cuda',
dtype=torch.float16,
generator=gen).contiguous()
# Initialize router weights # Initialize router weights
weights['router.weight'] = torch.randn( weights["router.weight"] = torch.randn((num_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(d_hidden)
(num_experts, d_hidden), device="cuda", dtype=torch.float16,
generator=gen) / math.sqrt(d_hidden)
for i in range(num_experts): for i in range(num_experts):
weights[f'experts.{i}.0.weight'] = torch.randn( weights[f"experts.{i}.0.weight"] = torch.randn(
(d_hidden, expert_dim), device='cuda', dtype=torch.float16, (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen
generator=gen) / math.sqrt(expert_dim) ) / math.sqrt(expert_dim)
weights[f'experts.{i}.1.weight'] = torch.randn( weights[f"experts.{i}.1.weight"] = torch.randn(
(d_hidden, expert_dim), device='cuda', dtype=torch.float16, (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen
generator=gen) / math.sqrt(expert_dim) ) / math.sqrt(expert_dim)
weights[f'experts.{i}.2.weight'] = torch.randn( weights[f"experts.{i}.2.weight"] = torch.randn(
(expert_dim, d_hidden), device='cuda', dtype=torch.float16, (expert_dim, d_hidden), device="cuda", dtype=torch.float16, generator=gen
generator=gen) / math.sqrt(d_hidden) ) / math.sqrt(d_hidden)
weights['shared_experts.0.weight'] = torch.randn( weights["shared_experts.0.weight"] = torch.randn(
(d_hidden, expert_dim * n_shared_experts), (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen
device='cuda', ) / math.sqrt(expert_dim * n_shared_experts)
dtype=torch.float16, weights["shared_experts.1.weight"] = torch.randn(
generator=gen) / math.sqrt(expert_dim * n_shared_experts) (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen
weights['shared_experts.1.weight'] = torch.randn( ) / math.sqrt(expert_dim * n_shared_experts)
(d_hidden, expert_dim * n_shared_experts), weights["shared_experts.2.weight"] = torch.randn(
device='cuda', (expert_dim * n_shared_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen
dtype=torch.float16, ) / math.sqrt(d_hidden)
generator=gen) / math.sqrt(expert_dim * n_shared_experts)
weights['shared_experts.2.weight'] = torch.randn((expert_dim * n_shared_experts, d_hidden),
device='cuda',
dtype=torch.float16,
generator=gen) / math.sqrt(d_hidden)
return (input_tensor, weights, config) return (input_tensor, weights, config)
......
...@@ -4,13 +4,8 @@ import example_fusedmoe_tilelang ...@@ -4,13 +4,8 @@ import example_fusedmoe_tilelang
def test_example_fusedmoe_tilelang(): def test_example_fusedmoe_tilelang():
example_fusedmoe_tilelang.main( example_fusedmoe_tilelang.main(
d_hidden=1024, d_hidden=1024, d_expert=256, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=1024
d_expert=256, )
n_routed_experts=8,
n_shared_experts=1,
n_experts_per_token=4,
batch_size=1,
seq_len=1024)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -12,6 +12,7 @@ print(tilelang.__file__, flush=True) ...@@ -12,6 +12,7 @@ print(tilelang.__file__, flush=True)
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
try: try:
import fla import fla
print(fla.__file__, flush=True) print(fla.__file__, flush=True)
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu
except ImportError: except ImportError:
...@@ -49,6 +50,7 @@ def prepare_input( ...@@ -49,6 +50,7 @@ def prepare_input(
G = F.logsigmoid(G) G = F.logsigmoid(G)
try: try:
from fla.ops.utils.cumsum import chunk_local_cumsum from fla.ops.utils.cumsum import chunk_local_cumsum
G = chunk_local_cumsum(G, chunk_size) G = chunk_local_cumsum(G, chunk_size)
except ImportError: except ImportError:
print("fla not found, skip cumsum") print("fla not found, skip cumsum")
...@@ -125,8 +127,11 @@ def torch_chunk_gated_delta_rule_bwd_dhu( ...@@ -125,8 +127,11 @@ def torch_chunk_gated_delta_rule_bwd_dhu(
DV = dv.shape[-1] DV = dv.shape[-1]
block_S = 64 block_S = 64
BS = S // block_S BS = S // block_S
dh, dh0, dv2 = torch.empty((B, BS, H, DK, DV), dtype=output_dtype), torch.empty( dh, dh0, dv2 = (
(B, H, DK, DV), dtype=state_dtype), torch.empty((B, S, H, DV), dtype=output_dtype) torch.empty((B, BS, H, DK, DV), dtype=output_dtype),
torch.empty((B, H, DK, DV), dtype=state_dtype),
torch.empty((B, S, H, DV), dtype=output_dtype),
)
dh_tmp = torch.empty((B, H, DK, DV), dtype=accum_dtype) dh_tmp = torch.empty((B, H, DK, DV), dtype=accum_dtype)
dv_tmp = torch.empty((B, S, H, DV), dtype=accum_dtype) dv_tmp = torch.empty((B, S, H, DV), dtype=accum_dtype)
Q_tmp = torch.empty((B, S, H, DK), dtype=accum_dtype) Q_tmp = torch.empty((B, S, H, DK), dtype=accum_dtype)
...@@ -138,34 +143,30 @@ def torch_chunk_gated_delta_rule_bwd_dhu( ...@@ -138,34 +143,30 @@ def torch_chunk_gated_delta_rule_bwd_dhu(
for i_s in range(BS - 1, -1, -1): for i_s in range(BS - 1, -1, -1):
dh[:, i_s, :, :, :] = dh_tmp dh[:, i_s, :, :, :] = dh_tmp
dv_tmp = torch.matmul(K[:, i_s * block_S:(i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), dv_tmp = torch.matmul(K[:, i_s * block_S : (i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), dh_tmp.to(K.dtype)).permute(0, 2, 1, 3)
dh_tmp.to(K.dtype)).permute(0, 2, 1, 3)
if use_g: if use_g:
for i_bh in range(B * H): for i_bh in range(B * H):
i_b, i_h = i_bh // H, i_bh % H i_b, i_h = i_bh // H, i_bh % H
for i_s2 in range(block_S): for i_s2 in range(block_S):
if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h] <= 0:
i_h] <= 0: dv_tmp[i_b, i_s2, i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h])
dv_tmp[i_b, i_s2,
i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] -
G[i_b, i_s * block_S + i_s2, i_h])
else: else:
dv_tmp[i_b, i_s2, i_h, :] = 0 dv_tmp[i_b, i_s2, i_h, :] = 0
dv_tmp += dv[:, i_s * block_S:(i_s + 1) * block_S, :, :] dv_tmp += dv[:, i_s * block_S : (i_s + 1) * block_S, :, :]
dv2[:, i_s * block_S:(i_s + 1) * block_S, :, :] = dv_tmp dv2[:, i_s * block_S : (i_s + 1) * block_S, :, :] = dv_tmp
if use_g: if use_g:
G_last = G[:, i_s * block_S + block_S - 1, :] G_last = G[:, i_s * block_S + block_S - 1, :]
for i_bh in range(B * H): for i_bh in range(B * H):
i_b, i_h = i_bh // H, i_bh % H i_b, i_h = i_bh // H, i_bh % H
dh_tmp[i_b, i_h, :, :] *= torch.exp(G_last[i_b, i_h]) dh_tmp[i_b, i_h, :, :] *= torch.exp(G_last[i_b, i_h])
Q_tmp = Q[:, i_s * block_S:(i_s + 1) * block_S, :, :] Q_tmp = Q[:, i_s * block_S : (i_s + 1) * block_S, :, :]
for i_s2 in range(block_S): for i_s2 in range(block_S):
for i_k in range(DK): for i_k in range(DK):
Q_tmp[:, i_s2, :, i_k] *= torch.exp(G[:, i_s * block_S + i_s2, :]) Q_tmp[:, i_s2, :, i_k] *= torch.exp(G[:, i_s * block_S + i_s2, :])
Q_tmp *= scale Q_tmp *= scale
W_tmp = W[:, i_s * block_S:(i_s + 1) * block_S, :, :] W_tmp = W[:, i_s * block_S : (i_s + 1) * block_S, :, :]
dO_tmp = dO[:, i_s * block_S:(i_s + 1) * block_S, :, :] dO_tmp = dO[:, i_s * block_S : (i_s + 1) * block_S, :, :]
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3)) dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3))
...@@ -223,19 +224,19 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( ...@@ -223,19 +224,19 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
@T.prim_func @T.prim_func
def kernel( def kernel(
# Input # Input
Q: T.Tensor(Q_shape, dtype=input_dtype), Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype), K: T.Tensor(K_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype), W: T.Tensor(W_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype), G: T.Tensor(G_shape, dtype=gate_dtype),
h0: T.Tensor(h0_shape, dtype=input_dtype), h0: T.Tensor(h0_shape, dtype=input_dtype),
dht: T.Tensor(dht_shape, dtype=input_dtype), dht: T.Tensor(dht_shape, dtype=input_dtype),
dO: T.Tensor(dO_shape, dtype=input_dtype), dO: T.Tensor(dO_shape, dtype=input_dtype),
dv: T.Tensor(dv_shape, dtype=input_dtype), dv: T.Tensor(dv_shape, dtype=input_dtype),
# Output # Output
dh: T.Tensor(dh_shape, dtype=output_dtype), dh: T.Tensor(dh_shape, dtype=output_dtype),
dh0: T.Tensor(dh0_shape, dtype=state_dtype), dh0: T.Tensor(dh0_shape, dtype=state_dtype),
dv2: T.Tensor(dv2_shape, dtype=output_dtype), dv2: T.Tensor(dv2_shape, dtype=output_dtype),
): ):
with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
...@@ -269,20 +270,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( ...@@ -269,20 +270,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
T.use_swizzle(10) T.use_swizzle(10)
T.annotate_layout({ T.annotate_layout(
b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared), {
b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32), b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32),
dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t), dO_shared: tilelang.layout.make_swizzled_layout(dO_shared),
K_shared: tilelang.layout.make_swizzled_layout(K_shared), dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t),
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32), Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
W_shared: tilelang.layout.make_swizzled_layout(W_shared), Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32),
}) W_shared: tilelang.layout.make_swizzled_layout(W_shared),
}
)
if use_final_state_gradient: if use_final_state_gradient:
T.copy(dht[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_dh_shared) T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared)
T.copy(b_dh_shared, b_dh_fragment) T.copy(b_dh_shared, b_dh_fragment)
else: else:
T.clear(b_dh_fragment) T.clear(b_dh_fragment)
...@@ -293,17 +296,14 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( ...@@ -293,17 +296,14 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
# Store the updated dh # Store the updated dh
T.copy(b_dh_fragment, b_dh_shared) T.copy(b_dh_fragment, b_dh_shared)
T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
# Update dv # Update dv
T.copy(K[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], K_shared) T.copy(K[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], K_shared)
T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True) T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True)
if use_g: if use_g:
T.copy( T.copy(G[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh], G_shared, disable_tma=True)
G[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh],
G_shared,
disable_tma=True)
T.copy(G_shared, G_fragment) T.copy(G_shared, G_fragment)
G_last_local[0] = G_shared[block_S - 1] G_last_local[0] = G_shared[block_S - 1]
G_last_local_exp[0] = T.exp(G_last_local[0]) G_last_local_exp[0] = T.exp(G_last_local[0])
...@@ -313,27 +313,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( ...@@ -313,27 +313,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
# with T.If(G_last_local[0] - G_shared[i_s2] <= 0): # with T.If(G_last_local[0] - G_shared[i_s2] <= 0):
with T.If(G_last_local[0] - G_fragment[i_s2] <= 0): with T.If(G_last_local[0] - G_fragment[i_s2] <= 0):
with T.Then(): with T.Then():
dv_fragment[i_s2, dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2]
i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2]
with T.Else(): with T.Else():
dv_fragment[i_s2, i_v] = 0 dv_fragment[i_s2, i_v] = 0
T.copy( T.copy(dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared)
dv[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV], dv_shared)
T.copy(dv_shared, dv_fragment_2) T.copy(dv_shared, dv_fragment_2)
for i_s2, i_v in T.Parallel(block_S, block_DV): for i_s2, i_v in T.Parallel(block_S, block_DV):
dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v] dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v]
# Store the updated dv # Store the updated dv
T.copy(dv_fragment, dv_shared) T.copy(dv_fragment, dv_shared)
T.copy( T.copy(dv_shared, dv2[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])
dv_shared, dv2[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV])
# Update dh # Update dh
T.copy(Q[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], Q_shared) T.copy(Q[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], Q_shared)
T.copy(W[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], W_shared) T.copy(W[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], W_shared)
T.clear(Q_fragment) T.clear(Q_fragment)
if use_g: if use_g:
...@@ -353,9 +348,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( ...@@ -353,9 +348,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
for i_s2, i_k in T.Parallel(block_S, DK): for i_s2, i_k in T.Parallel(block_S, DK):
Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k] Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k]
T.copy( T.copy(dO[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dO_shared)
dO[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV], dO_shared)
T.copy(dO_shared, dO_fragment) T.copy(dO_shared, dO_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV): for i_s2, i_v in T.Parallel(block_S, block_DV):
dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v] dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v]
...@@ -369,7 +362,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( ...@@ -369,7 +362,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v] b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v]
if use_initial_state: if use_initial_state:
T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
return kernel return kernel
...@@ -444,44 +437,61 @@ def run_test( ...@@ -444,44 +437,61 @@ def run_test(
num_stages=0, num_stages=0,
use_torch=False, use_torch=False,
): ):
Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size, Q, K, W, G, h0, dht, dO, dv = prepare_input(
getattr(torch, input_dtype), B,
getattr(torch, output_dtype), S,
getattr(torch, accum_dtype), H,
getattr(torch, gate_dtype), DK,
getattr(torch, state_dtype)) DV,
dh_ref, dh0_ref, dv2_ref = prepare_output(B, S, H, DK, DV, chunk_size, chunk_size,
getattr(torch, output_dtype), getattr(torch, input_dtype),
getattr(torch, gate_dtype), getattr(torch, output_dtype),
getattr(torch, state_dtype)) getattr(torch, accum_dtype),
dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, getattr(torch, gate_dtype),
getattr(torch, output_dtype), getattr(torch, state_dtype),
getattr(torch, gate_dtype), )
getattr(torch, state_dtype)) dh_ref, dh0_ref, dv2_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
# fla ref # fla ref
print("fla running...", flush=True) print("fla running...", flush=True)
if use_g: if use_g:
dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale)
scale)
else: else:
G = G.fill_(0) G = G.fill_(0)
dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale)
scale)
# tilelang # tilelang
print("tilelang running...", flush=True) print("tilelang running...", flush=True)
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype, kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(
accum_dtype, gate_dtype, state_dtype, B,
chunk_size, scale, use_g, use_initial_state, S,
use_final_state_gradient, block_DV, threads, H,
num_stages) DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g,
use_initial_state,
use_final_state_gradient,
block_DV,
threads,
num_stages,
)
# kernel = tilelang.compile(program) # kernel = tilelang.compile(program)
print(kernel.get_kernel_source()) print(kernel.get_kernel_source())
dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv)
fla_time = do_bench( fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv) tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv)
print(f"fla time: {fla_time} ms") print(f"fla time: {fla_time} ms")
...@@ -496,19 +506,47 @@ def run_test( ...@@ -496,19 +506,47 @@ def run_test(
print("torch running...", flush=True) print("torch running...", flush=True)
if use_g: if use_g:
dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu( dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
Q, K, W, G, h0, dht, dO, dv, scale, use_g, use_initial_state, Q,
use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype), K,
getattr(torch, accum_dtype), getattr(torch, W,
gate_dtype), getattr(torch, state_dtype)) G,
h0,
dht,
dO,
dv,
scale,
use_g,
use_initial_state,
use_final_state_gradient,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dh_ref_torch = dh_ref_torch.cuda() dh_ref_torch = dh_ref_torch.cuda()
dh0_ref_torch = dh0_ref_torch.cuda() dh0_ref_torch = dh0_ref_torch.cuda()
dv2_ref_torch = dv2_ref_torch.cuda() dv2_ref_torch = dv2_ref_torch.cuda()
else: else:
dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu( dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
Q, K, W, None, h0, dht, dO, dv, scale, use_g, use_initial_state, Q,
use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype), K,
getattr(torch, accum_dtype), getattr(torch, W,
gate_dtype), getattr(torch, state_dtype)) None,
h0,
dht,
dO,
dv,
scale,
use_g,
use_initial_state,
use_final_state_gradient,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dh_ref_torch = dh_ref_torch.cuda() dh_ref_torch = dh_ref_torch.cuda()
dh0_ref_torch = dh0_ref_torch.cuda() dh0_ref_torch = dh0_ref_torch.cuda()
dv2_ref_torch = dv2_ref_torch.cuda() dv2_ref_torch = dv2_ref_torch.cuda()
......
...@@ -10,6 +10,7 @@ from tilelang.autotuner import autotune ...@@ -10,6 +10,7 @@ from tilelang.autotuner import autotune
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
try: try:
import fla import fla
print(fla.__file__) print(fla.__file__)
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h
except ImportError: except ImportError:
...@@ -56,6 +57,7 @@ def prepare_input( ...@@ -56,6 +57,7 @@ def prepare_input(
G = F.logsigmoid(G) G = F.logsigmoid(G)
try: try:
from fla.ops.utils.cumsum import chunk_local_cumsum from fla.ops.utils.cumsum import chunk_local_cumsum
G = chunk_local_cumsum(G, chunk_size) G = chunk_local_cumsum(G, chunk_size)
except ImportError: except ImportError:
print("fla not found, skip cumsum") print("fla not found, skip cumsum")
...@@ -83,18 +85,14 @@ def prepare_output( ...@@ -83,18 +85,14 @@ def prepare_output(
def get_configs(): def get_configs():
import itertools import itertools
block_DK = [32, 64, 128] block_DK = [32, 64, 128]
block_DV = [32, 64, 128] block_DV = [32, 64, 128]
threads = [128, 256] threads = [128, 256]
num_stages = [1, 2, 3] num_stages = [1, 2, 3]
_configs = list(itertools.product(block_DK, block_DV, threads, num_stages)) _configs = list(itertools.product(block_DK, block_DV, threads, num_stages))
configs = [{ configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs]
'block_DK': c[0],
'block_DV': c[1],
'threads': c[2],
'num_stages': c[3]
} for c in _configs]
return configs return configs
...@@ -137,14 +135,14 @@ def tilelang_chunk_gated_delta_rule_fwd_h( ...@@ -137,14 +135,14 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
@T.prim_func @T.prim_func
def kernel( def kernel(
K: T.Tensor(K_shape, dtype=input_dtype), K: T.Tensor(K_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype), W: T.Tensor(W_shape, dtype=input_dtype),
U: T.Tensor(U_shape, dtype=input_dtype), U: T.Tensor(U_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype), G: T.Tensor(G_shape, dtype=gate_dtype),
initial_state: T.Tensor(initial_state_shape, dtype=input_dtype), initial_state: T.Tensor(initial_state_shape, dtype=input_dtype),
h: T.Tensor(h_shape, dtype=output_dtype), h: T.Tensor(h_shape, dtype=output_dtype),
final_state: T.Tensor(final_state_shape, dtype=state_dtype), final_state: T.Tensor(final_state_shape, dtype=state_dtype),
V_new: T.Tensor(V_shape, dtype=output_dtype), V_new: T.Tensor(V_shape, dtype=output_dtype),
): ):
with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
...@@ -162,35 +160,35 @@ def tilelang_chunk_gated_delta_rule_fwd_h( ...@@ -162,35 +160,35 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype) G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype)
G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype) G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype)
T.annotate_layout({ T.annotate_layout(
b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared), {
U_shared: tilelang.layout.make_swizzled_layout(U_shared), b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared),
W_shared: tilelang.layout.make_swizzled_layout(W_shared), U_shared: tilelang.layout.make_swizzled_layout(U_shared),
V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared), W_shared: tilelang.layout.make_swizzled_layout(W_shared),
K_shared: tilelang.layout.make_swizzled_layout(K_shared), V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared),
G_shared: tilelang.layout.make_swizzled_layout(G_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
}) G_shared: tilelang.layout.make_swizzled_layout(G_shared),
}
)
T.use_swizzle(10) T.use_swizzle(10)
if use_initial_state: if use_initial_state:
T.copy(initial_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_h_shared) T.copy(initial_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_h_shared)
T.copy(b_h_shared, b_h_fragment) T.copy(b_h_shared, b_h_fragment)
else: else:
T.clear(b_h_fragment) T.clear(b_h_fragment)
for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages): for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages):
# Store previous result to the hidden tensor, like the epilogue # Store previous result to the hidden tensor, like the epilogue
T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
# Recurrence # Recurrence
T.copy(W[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], W_shared) T.copy(W[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], W_shared)
T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True) T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True)
# U - W * S # U - W * S
T.copy( T.copy(U[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], U_shared)
U[bb, i_s * block_S:(i_s + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV],
U_shared)
T.copy(U_shared, U_fragment) T.copy(U_shared, U_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV): for i_s2, i_v in T.Parallel(block_S, block_DV):
V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v] V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v]
...@@ -198,11 +196,9 @@ def tilelang_chunk_gated_delta_rule_fwd_h( ...@@ -198,11 +196,9 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
# Save V_new # Save V_new
if save_new_value: if save_new_value:
T.copy(V_new_fragment, dst=V_new_shared) T.copy(V_new_fragment, dst=V_new_shared)
T.copy( T.copy(V_new_shared, V_new[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])
V_new_shared, V_new[bb, i_s * block_S:(i_s + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV])
T.copy(K[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], K_shared) T.copy(K[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], K_shared)
# use_g # use_g
if use_g: if use_g:
G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh] G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh]
...@@ -213,7 +209,8 @@ def tilelang_chunk_gated_delta_rule_fwd_h( ...@@ -213,7 +209,8 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0): with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0):
with T.Then(): with T.Then():
V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp2( V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp2(
(G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695) (G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695
)
with T.Else(): with T.Else():
V_new_fragment[i_s2, i_v] = 0 V_new_fragment[i_s2, i_v] = 0
G_last_local[0] = T.exp2(G_last_local[0] * 1.442695) G_last_local[0] = T.exp2(G_last_local[0] * 1.442695)
...@@ -228,7 +225,7 @@ def tilelang_chunk_gated_delta_rule_fwd_h( ...@@ -228,7 +225,7 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
# Save final state # Save final state
if store_final_state: if store_final_state:
T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
return kernel return kernel
...@@ -279,17 +276,24 @@ def run_test( ...@@ -279,17 +276,24 @@ def run_test(
threads=128, threads=128,
num_stages=0, num_stages=0,
): ):
K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size, K, W, U, G, initial_state = prepare_input(
getattr(torch, input_dtype), B,
getattr(torch, output_dtype), S,
getattr(torch, accum_dtype), H,
getattr(torch, gate_dtype)) DK,
h_ref, final_state_ref, V_new_ref = prepare_output(B, S, H, DK, DV, chunk_size, DV,
getattr(torch, output_dtype), chunk_size,
getattr(torch, state_dtype)) getattr(torch, input_dtype),
h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype),
getattr(torch, state_dtype)) getattr(torch, gate_dtype),
)
h_ref, final_state_ref, V_new_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype)
)
h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype)
)
# fla ref # fla ref
h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h( h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(
...@@ -300,13 +304,27 @@ def run_test( ...@@ -300,13 +304,27 @@ def run_test(
initial_state=initial_state, initial_state=initial_state,
output_final_state=store_final_state, output_final_state=store_final_state,
chunk_size=chunk_size, chunk_size=chunk_size,
save_new_value=save_new_value) save_new_value=save_new_value,
)
# tilelang # tilelang
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype, kernel = tilelang_chunk_gated_delta_rule_fwd_h(
accum_dtype, gate_dtype, state_dtype, chunk_size, B,
use_g, use_initial_state, store_final_state, S,
save_new_value) H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
use_g,
use_initial_state,
store_final_state,
save_new_value,
)
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state)
# (zhengju) If you want to print the generated cuda code, you can uncomment the following line # (zhengju) If you want to print the generated cuda code, you can uncomment the following line
# print("CUDA Code:\n", kernel.get_kernel_source()) # print("CUDA Code:\n", kernel.get_kernel_source())
...@@ -320,19 +338,15 @@ def run_test( ...@@ -320,19 +338,15 @@ def run_test(
initial_state=initial_state, initial_state=initial_state,
output_final_state=store_final_state, output_final_state=store_final_state,
chunk_size=chunk_size, chunk_size=chunk_size,
save_new_value=save_new_value) save_new_value=save_new_value,
)
tilelang_time = do_bench(kernel, K, W, U, G, initial_state) tilelang_time = do_bench(kernel, K, W, U, G, initial_state)
# check correctness # check correctness
try: try:
h_ref_fp32 = h_ref.to(torch.float32) h_ref_fp32 = h_ref.to(torch.float32)
h_tilelang_fp32 = h_tilelang.to(torch.float32) h_tilelang_fp32 = h_tilelang.to(torch.float32)
assert_similar( assert_similar(h_ref_fp32, h_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd h", raise_assert=False)
h_ref_fp32,
h_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd h",
raise_assert=False)
print("tilelang chunk gated delta rule fwd h passed √") print("tilelang chunk gated delta rule fwd h passed √")
except Exception as e: except Exception as e:
print("tilelang chunk gated delta rule fwd h failed ✗") print("tilelang chunk gated delta rule fwd h failed ✗")
...@@ -346,7 +360,8 @@ def run_test( ...@@ -346,7 +360,8 @@ def run_test(
final_state_tilelang_fp32, final_state_tilelang_fp32,
eps=1e-5, eps=1e-5,
name="tilelang chunk gated delta rule fwd final_state", name="tilelang chunk gated delta rule fwd final_state",
raise_assert=False) raise_assert=False,
)
print("tilelang chunk gated delta rule fwd final_state passed √") print("tilelang chunk gated delta rule fwd final_state passed √")
except Exception as e: except Exception as e:
print("tilelang chunk gated delta rule fwd final_state failed ✗") print("tilelang chunk gated delta rule fwd final_state failed ✗")
...@@ -355,12 +370,7 @@ def run_test( ...@@ -355,12 +370,7 @@ def run_test(
try: try:
V_new_ref_fp32 = V_new_ref.to(torch.float32) V_new_ref_fp32 = V_new_ref.to(torch.float32)
V_new_tilelang_fp32 = V_new_tilelang.to(torch.float32) V_new_tilelang_fp32 = V_new_tilelang.to(torch.float32)
assert_similar( assert_similar(V_new_ref_fp32, V_new_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd V_new", raise_assert=False)
V_new_ref_fp32,
V_new_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd V_new",
raise_assert=False)
print("tilelang chunk gated delta rule fwd V_new passed √") print("tilelang chunk gated delta rule fwd V_new passed √")
except Exception as e: except Exception as e:
print("tilelang chunk gated delta rule fwd V_new failed ✗") print("tilelang chunk gated delta rule fwd V_new failed ✗")
......
...@@ -9,6 +9,7 @@ import sys # noqa: F401 ...@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
try: try:
import fla import fla
print(fla.__file__) print(fla.__file__)
from fla.ops.common.chunk_o import chunk_fwd_o from fla.ops.common.chunk_o import chunk_fwd_o
except ImportError: except ImportError:
...@@ -87,16 +88,14 @@ def tilelang_chunk_fwd_o( ...@@ -87,16 +88,14 @@ def tilelang_chunk_fwd_o(
@T.prim_func @T.prim_func
def kernel( def kernel(
Q: T.Tensor(Q_shape, dtype=input_dtype), Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype), K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype), V: T.Tensor(V_shape, dtype=input_dtype),
HIDDEN: T.Tensor(H_shape, dtype=input_dtype), HIDDEN: T.Tensor(H_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype), G: T.Tensor(G_shape, dtype=gate_dtype),
O: T.Tensor(O_shape, dtype=output_dtype), O: T.Tensor(O_shape, dtype=output_dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H, threads=threads) as (bv, bs, bbh):
T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H,
threads=threads) as (bv, bs, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
...@@ -109,28 +108,24 @@ def tilelang_chunk_fwd_o( ...@@ -109,28 +108,24 @@ def tilelang_chunk_fwd_o(
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared") G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared")
G_diff_local = T.alloc_fragment((block_S, block_S), dtype=gate_dtype) G_diff_local = T.alloc_fragment((block_S, block_S), dtype=gate_dtype)
T.annotate_layout({ T.annotate_layout(
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), {
K_shared: tilelang.layout.make_swizzled_layout(K_shared), Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
V_shared: tilelang.layout.make_swizzled_layout(V_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
H_shared: tilelang.layout.make_swizzled_layout(H_shared), V_shared: tilelang.layout.make_swizzled_layout(V_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared), H_shared: tilelang.layout.make_swizzled_layout(H_shared),
O_shared: tilelang.layout.make_swizzled_layout(O_shared), A_shared: tilelang.layout.make_swizzled_layout(A_shared),
}) O_shared: tilelang.layout.make_swizzled_layout(O_shared),
}
)
T.clear(A_fragment) T.clear(A_fragment)
T.clear(O_fragment) T.clear(O_fragment)
T.disable_warp_group_reg_alloc() T.disable_warp_group_reg_alloc()
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy( T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], Q_shared)
Q[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
Q_shared) T.copy(HIDDEN[bb, bs, bh, i_k * block_DK : (i_k + 1) * block_DK, bv * block_DV : (bv + 1) * block_DV], H_shared)
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
T.copy(
HIDDEN[bb, bs, bh, i_k * block_DK:(i_k + 1) * block_DK,
bv * block_DV:(bv + 1) * block_DV], H_shared)
T.gemm(Q_shared, H_shared, O_fragment) T.gemm(Q_shared, H_shared, O_fragment)
T.gemm(Q_shared, K_shared, A_fragment, transpose_B=True) T.gemm(Q_shared, K_shared, A_fragment, transpose_B=True)
...@@ -145,8 +140,7 @@ def tilelang_chunk_fwd_o( ...@@ -145,8 +140,7 @@ def tilelang_chunk_fwd_o(
for i_s1, i_s2 in T.Parallel(block_S, block_S): for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G_diff_local[i_s1, i_s2] <= 0): with T.If(G_diff_local[i_s1, i_s2] <= 0):
with T.Then(): with T.Then():
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp( A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2])
G_diff_local[i_s1, i_s2])
with T.Else(): with T.Else():
A_fragment[i_s1, i_s2] = 0 A_fragment[i_s1, i_s2] = 0
...@@ -155,8 +149,7 @@ def tilelang_chunk_fwd_o( ...@@ -155,8 +149,7 @@ def tilelang_chunk_fwd_o(
with T.Then(): with T.Then():
A_fragment[i_s1, i_s2] = 0 A_fragment[i_s1, i_s2] = 0
T.copy(V[bb, bs * block_S:(bs + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV], T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], V_shared)
V_shared)
T.copy(A_fragment, A_shared) T.copy(A_fragment, A_shared)
T.gemm(A_shared, V_shared, O_fragment) T.gemm(A_shared, V_shared, O_fragment)
...@@ -164,8 +157,7 @@ def tilelang_chunk_fwd_o( ...@@ -164,8 +157,7 @@ def tilelang_chunk_fwd_o(
O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * scale O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * scale
T.copy(O_fragment, O_shared) T.copy(O_fragment, O_shared)
T.copy(O_shared, O[bb, bs * block_S:(bs + 1) * block_S, bh, T.copy(O_shared, O[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])
bv * block_DV:(bv + 1) * block_DV])
return kernel return kernel
...@@ -191,8 +183,9 @@ def run_test( ...@@ -191,8 +183,9 @@ def run_test(
output_dtype_torch = getattr(torch, output_dtype) output_dtype_torch = getattr(torch, output_dtype)
accum_dtype_torch = getattr(torch, accum_dtype) accum_dtype_torch = getattr(torch, accum_dtype)
gate_dtype_torch = getattr(torch, gate_dtype) gate_dtype_torch = getattr(torch, gate_dtype)
Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, input_dtype_torch, Q, K, V, HIDDEN, G = prepare_input(
output_dtype_torch, accum_dtype_torch, gate_dtype_torch) B, S, H, DK, DV, chunk_size, input_dtype_torch, output_dtype_torch, accum_dtype_torch, gate_dtype_torch
)
scale = 1.0 / DK**0.5 scale = 1.0 / DK**0.5
O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch) O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
...@@ -200,9 +193,25 @@ def run_test( ...@@ -200,9 +193,25 @@ def run_test(
block_S = chunk_size block_S = chunk_size
O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch) O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, kernel = tilelang_chunk_fwd_o(
gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV, B,
threads, num_stages) S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
chunk_size,
scale,
use_g,
block_S,
block_DK,
block_DV,
threads,
num_stages,
)
O_tilelang = kernel(Q, K, V, HIDDEN, G) O_tilelang = kernel(Q, K, V, HIDDEN, G)
try: try:
......
...@@ -12,6 +12,7 @@ from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F4 ...@@ -12,6 +12,7 @@ from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F4
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
try: try:
import fla import fla
print(fla.__file__) print(fla.__file__)
from fla.ops.common.chunk_o import chunk_bwd_dqkwg from fla.ops.common.chunk_o import chunk_bwd_dqkwg
except ImportError: except ImportError:
...@@ -108,10 +109,8 @@ def prepare_output( ...@@ -108,10 +109,8 @@ def prepare_output(
@tilelang.jit( @tilelang.jit(
out_idx=[-4, -3, -2, -1], out_idx=[-4, -3, -2, -1],
pass_configs={ pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True},
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, )
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_chunk_o_bwd_dqkwg( def tilelang_chunk_o_bwd_dqkwg(
# task config # task config
B, B,
...@@ -155,25 +154,23 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -155,25 +154,23 @@ def tilelang_chunk_o_bwd_dqkwg(
@T.prim_func @T.prim_func
def kernel( def kernel(
# input # input
Q: T.Tensor(Q_shape, dtype=input_dtype), Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype), K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype), V: T.Tensor(V_shape, dtype=input_dtype),
h: T.Tensor(h_shape, dtype=input_dtype), h: T.Tensor(h_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype), G: T.Tensor(G_shape, dtype=gate_dtype),
dO: T.Tensor(dO_shape, dtype=input_dtype), dO: T.Tensor(dO_shape, dtype=input_dtype),
dh: T.Tensor(dh_shape, dtype=input_dtype), dh: T.Tensor(dh_shape, dtype=input_dtype),
dv: T.Tensor(dv_shape, dtype=input_dtype), dv: T.Tensor(dv_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype), W: T.Tensor(W_shape, dtype=input_dtype),
# output # output
dq: T.Tensor(dq_shape, dtype=output_dtype), dq: T.Tensor(dq_shape, dtype=output_dtype),
dk: T.Tensor(dk_shape, dtype=output_dtype), dk: T.Tensor(dk_shape, dtype=output_dtype),
dw: T.Tensor(dw_shape, dtype=output_dtype), dw: T.Tensor(dw_shape, dtype=output_dtype),
dg: T.Tensor(dg_shape, dtype=gate_dtype), dg: T.Tensor(dg_shape, dtype=gate_dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, threads=threads) as (bk, bs, bbh):
T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H,
threads=threads) as (bk, bs, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
...@@ -212,15 +209,17 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -212,15 +209,17 @@ def tilelang_chunk_o_bwd_dqkwg(
T.use_swizzle(10) T.use_swizzle(10)
T.annotate_layout({ T.annotate_layout(
V_shared: tilelang.layout.make_swizzled_layout(V_shared), {
dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), V_shared: tilelang.layout.make_swizzled_layout(V_shared),
h_shared: tilelang.layout.make_swizzled_layout(h_shared), dO_shared: tilelang.layout.make_swizzled_layout(dO_shared),
dh_shared: tilelang.layout.make_swizzled_layout(dh_shared), h_shared: tilelang.layout.make_swizzled_layout(h_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), dh_shared: tilelang.layout.make_swizzled_layout(dh_shared),
q_shared: tilelang.layout.make_swizzled_layout(q_shared), dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
k_shared: tilelang.layout.make_swizzled_layout(k_shared), q_shared: tilelang.layout.make_swizzled_layout(q_shared),
}) k_shared: tilelang.layout.make_swizzled_layout(k_shared),
}
)
T.clear(dg_last_local) T.clear(dg_last_local)
T.clear(G_last_local) T.clear(G_last_local)
...@@ -235,18 +234,10 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -235,18 +234,10 @@ def tilelang_chunk_o_bwd_dqkwg(
T.clear(dw_fragment) T.clear(dw_fragment)
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy( T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared)
V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV], T.copy(dO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dO_shared)
V_shared) T.copy(h[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], h_shared)
T.copy( T.copy(dh[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], dh_shared)
dO[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV], dO_shared)
T.copy(
h[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK,
i_v * block_DV:(i_v + 1) * block_DV], h_shared)
T.copy(
dh[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK,
i_v * block_DV:(i_v + 1) * block_DV], dh_shared)
if use_g: if use_g:
T.clear(dg_last_fragment_scalar) T.clear(dg_last_fragment_scalar)
...@@ -254,9 +245,7 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -254,9 +245,7 @@ def tilelang_chunk_o_bwd_dqkwg(
# for i_kv in T.Parallel(block_DK * block_DV): # for i_kv in T.Parallel(block_DK * block_DV):
# dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] # dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
for i_kv in T.Parallel(block_DK * block_DV): for i_kv in T.Parallel(block_DK * block_DV):
dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
block_DV] * dh_shared[i_kv // block_DV,
i_kv % block_DV]
T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False) T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
dg_last_local[0] += dg_last_fragment_scalar[0] dg_last_local[0] += dg_last_fragment_scalar[0]
...@@ -265,22 +254,16 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -265,22 +254,16 @@ def tilelang_chunk_o_bwd_dqkwg(
T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True) T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True)
if use_dw: if use_dw:
T.copy( T.copy(dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dv_shared)
dv[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV], dv_shared)
T.gemm(dv_shared, h_shared, dw_fragment, transpose_B=True) T.gemm(dv_shared, h_shared, dw_fragment, transpose_B=True)
if use_dw: if use_dw:
for i_s, i_k in T.Parallel(block_S, block_DK): for i_s, i_k in T.Parallel(block_S, block_DK):
dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k] dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k]
T.copy( T.copy(dw_fragment, dw[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
dw_fragment, dw[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK]) T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], q_shared)
T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], k_shared)
T.copy(Q[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK],
q_shared)
T.copy(K[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK],
k_shared)
T.copy(q_shared, q_fragment) T.copy(q_shared, q_fragment)
T.copy(k_shared, k_fragment) T.copy(k_shared, k_fragment)
...@@ -294,8 +277,7 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -294,8 +277,7 @@ def tilelang_chunk_o_bwd_dqkwg(
dg_last_local[0] = dg_last_local[0] * T.exp(G[bb, bs * block_S + block_S - 1, bh]) dg_last_local[0] = dg_last_local[0] * T.exp(G[bb, bs * block_S + block_S - 1, bh])
for i_s, i_k in T.Parallel(block_S, block_DK): for i_s, i_k in T.Parallel(block_S, block_DK):
dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, bh]) * scale
bh]) * scale
T.clear(dg_fragment_reduce_tmp) T.clear(dg_fragment_reduce_tmp)
for i_s, i_k in T.Parallel(block_S, block_DK): for i_s, i_k in T.Parallel(block_S, block_DK):
dg_fragment_reduce_tmp[i_s, i_k] = dq_fragment[i_s, i_k] * q_shared[i_s, i_k] dg_fragment_reduce_tmp[i_s, i_k] = dq_fragment[i_s, i_k] * q_shared[i_s, i_k]
...@@ -305,8 +287,7 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -305,8 +287,7 @@ def tilelang_chunk_o_bwd_dqkwg(
for i_s, i_k in T.Parallel(block_S, block_DK): for i_s, i_k in T.Parallel(block_S, block_DK):
with T.If(G_last_local[0] - G[bb, bs * block_S + i_s, bh] <= 0): with T.If(G_last_local[0] - G[bb, bs * block_S + i_s, bh] <= 0):
with T.Then(): with T.Then():
dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp( dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp(G_last_local[0] - G[bb, bs * block_S + i_s, bh])
G_last_local[0] - G[bb, bs * block_S + i_s, bh])
with T.Else(): with T.Else():
dk_fragment[i_s, i_k] = 0 dk_fragment[i_s, i_k] = 0
T.clear(dg_fragment_reduce_tmp) T.clear(dg_fragment_reduce_tmp)
...@@ -325,12 +306,11 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -325,12 +306,11 @@ def tilelang_chunk_o_bwd_dqkwg(
dg_last_local[1] = dg_last_fragment_scalar_2[0] dg_last_local[1] = dg_last_fragment_scalar_2[0]
for i_s1, i_s2 in T.Parallel(block_S, block_S): for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 >= i_s2 and with T.If(i_s1 >= i_s2 and G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
with T.Then(): with T.Then():
ds_fragment[i_s1, i_s2] = ds_fragment[ ds_fragment[i_s1, i_s2] = (
i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - ds_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) * scale
G[bb, bs * block_S + i_s2, bh]) * scale )
with T.Else(): with T.Else():
ds_fragment[i_s1, i_s2] = 0 ds_fragment[i_s1, i_s2] = 0
...@@ -338,8 +318,7 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -338,8 +318,7 @@ def tilelang_chunk_o_bwd_dqkwg(
T.clear(ds_fragment_positive_transpose) T.clear(ds_fragment_positive_transpose)
T.gemm(q_shared, k_shared, ds_fragment_positive, transpose_B=True) T.gemm(q_shared, k_shared, ds_fragment_positive, transpose_B=True)
for i_s1, i_s2 in T.Parallel(block_S, block_S): for i_s1, i_s2 in T.Parallel(block_S, block_S):
ds_fragment_positive[ ds_fragment_positive[i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2]
i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2]
# FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass # FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass
T.reduce_sum(ds_fragment_positive, dg_fragment, dim=1, clear=False) T.reduce_sum(ds_fragment_positive, dg_fragment, dim=1, clear=False)
...@@ -363,15 +342,10 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -363,15 +342,10 @@ def tilelang_chunk_o_bwd_dqkwg(
for i_s in T.Parallel(block_S): for i_s in T.Parallel(block_S):
with T.If(i_s >= block_S - 1): # noqa: SIM117 with T.If(i_s >= block_S - 1): # noqa: SIM117
with T.Then(): with T.Then():
dg_fragment_final[ dg_fragment_final[i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1]
i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1]
T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
T.copy( T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
T.copy(
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
for i_s in T.Parallel(block_S): for i_s in T.Parallel(block_S):
dg[bk, bb, bs * block_S + i_s, bh] = dg_fragment_final[i_s] dg[bk, bb, bs * block_S + i_s, bh] = dg_fragment_final[i_s]
...@@ -387,12 +361,8 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -387,12 +361,8 @@ def tilelang_chunk_o_bwd_dqkwg(
for i_s, i_k in T.Parallel(block_S, block_DK): for i_s, i_k in T.Parallel(block_S, block_DK):
dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale
dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] + dk_fragment_2[i_s, i_k] * scale dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] + dk_fragment_2[i_s, i_k] * scale
T.copy( T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh, T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
bk * block_DK:(bk + 1) * block_DK])
T.copy(
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
return kernel return kernel
...@@ -442,32 +412,53 @@ def run_test( ...@@ -442,32 +412,53 @@ def run_test(
threads=256, threads=256,
num_stages=0, num_stages=0,
): ):
Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size, Q, K, V, h, G, dO, dh, dv, W = prepare_input(
getattr(torch, input_dtype), B,
getattr(torch, output_dtype), S,
getattr(torch, accum_dtype), H,
getattr(torch, gate_dtype), DK,
getattr(torch, state_dtype)) DV,
dq_ref, dk_ref, dw_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size, chunk_size,
getattr(torch, output_dtype), getattr(torch, input_dtype),
getattr(torch, gate_dtype), getattr(torch, output_dtype),
getattr(torch, state_dtype), block_DK) getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dq_ref, dk_ref, dw_ref, dg_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK
)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output( dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK
getattr(torch, state_dtype), block_DK) )
# ref # ref
if use_g: if use_g:
dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg( dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
else: else:
dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg( dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
# tilelang # tilelang
kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, kernel = tilelang_chunk_o_bwd_dqkwg(
gate_dtype, state_dtype, chunk_size, scale, use_g, use_dw, B,
block_DK, block_DV, threads, num_stages) S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g,
use_dw,
block_DK,
block_DV,
threads,
num_stages,
)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W)
if use_g: if use_g:
......
...@@ -9,6 +9,7 @@ import sys # noqa: F401 ...@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
try: try:
import fla import fla
print(fla.__file__) print(fla.__file__)
from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
except ImportError: except ImportError:
...@@ -75,10 +76,10 @@ def tilelang_chunk_scaled_dot_kkt_fwd( ...@@ -75,10 +76,10 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
@T.prim_func @T.prim_func
def kernel( def kernel(
K: T.Tensor(K_shape, dtype=input_dtype), K: T.Tensor(K_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype), Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=accum_dtype), G: T.Tensor(G_shape, dtype=accum_dtype),
A: T.Tensor(output_shape, dtype=output_dtype), A: T.Tensor(output_shape, dtype=output_dtype),
): ):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
...@@ -93,10 +94,12 @@ def tilelang_chunk_scaled_dot_kkt_fwd( ...@@ -93,10 +94,12 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared") G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared")
G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
T.annotate_layout({ T.annotate_layout(
K_shared: tilelang.layout.make_swizzled_layout(K_shared), {
A_shared: tilelang.layout.make_swizzled_layout(A_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
}) A_shared: tilelang.layout.make_swizzled_layout(A_shared),
}
)
T.fill(A_fragment, 0) T.fill(A_fragment, 0)
T.disable_warp_group_reg_alloc() T.disable_warp_group_reg_alloc()
...@@ -104,9 +107,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd( ...@@ -104,9 +107,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy( T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s]
T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True) T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True)
...@@ -119,8 +120,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd( ...@@ -119,8 +120,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
for i_s1, i_s2 in T.Parallel(block_S, block_S): for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2): with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2):
with T.Then(): with T.Then():
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp( A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2])
G_diff_local[i_s1, i_s2])
with T.Else(): with T.Else():
A_fragment[i_s1, i_s2] = 0 A_fragment[i_s1, i_s2] = 0
else: else:
...@@ -130,7 +130,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd( ...@@ -130,7 +130,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
A_fragment[i_s1, i_s2] = 0 A_fragment[i_s1, i_s2] = 0
T.copy(A_fragment, A_shared) T.copy(A_fragment, A_shared)
T.copy(A_shared, A[bb, bs * block_S:(bs + 1) * block_S, bh, :]) T.copy(A_shared, A[bb, bs * block_S : (bs + 1) * block_S, bh, :])
return kernel return kernel
...@@ -149,24 +149,21 @@ def run_test( ...@@ -149,24 +149,21 @@ def run_test(
threads, threads,
num_stages, num_stages,
): ):
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype))
getattr(torch, output_dtype), getattr(torch, accum_dtype))
A_ref = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype)) A_ref = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype)) A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
# reference # reference
if use_g: if use_g:
A_ref = chunk_scaled_dot_kkt_fwd( A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
else: else:
A_ref = chunk_scaled_dot_kkt_fwd( A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
# tilelang # tilelang
block_S = chunk_size block_S = chunk_size
kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype, kernel = tilelang_chunk_scaled_dot_kkt_fwd(
accum_dtype, use_g, block_S, block_DK, threads, B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages
num_stages) )
A_tilelang = kernel(K, Beta, G) A_tilelang = kernel(K, Beta, G)
try: try:
...@@ -192,7 +189,8 @@ def main(): ...@@ -192,7 +189,8 @@ def main():
use_g=True, use_g=True,
block_DK=64, block_DK=64,
threads=128, threads=128,
num_stages=2) num_stages=2,
)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -10,6 +10,7 @@ import sys # noqa: F401 ...@@ -10,6 +10,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
try: try:
import fla import fla
print(fla.__file__) print(fla.__file__)
from fla.ops.utils.cumsum import chunk_local_cumsum_scalar from fla.ops.utils.cumsum import chunk_local_cumsum_scalar
except ImportError: except ImportError:
...@@ -20,11 +21,8 @@ import torch ...@@ -20,11 +21,8 @@ import torch
@tilelang.jit( @tilelang.jit(
out_idx=[-1], out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}
pass_configs={ )
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_chunk_local_cumsum_scalar( def tilelang_chunk_local_cumsum_scalar(
# task config # task config
B, B,
...@@ -42,35 +40,35 @@ def tilelang_chunk_local_cumsum_scalar( ...@@ -42,35 +40,35 @@ def tilelang_chunk_local_cumsum_scalar(
use_fragment=False, use_fragment=False,
): ):
G_shape = (B, H, S) if head_first else (B, S, H) G_shape = (B, H, S) if head_first else (B, S, H)
assert chunk_size == 2**(chunk_size.bit_length() - 1), "chunk_size must be a power of 2" assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
assert chunk_size == block_S, "chunk_size must be equal to block_S" assert chunk_size == block_S, "chunk_size must be equal to block_S"
@T.prim_func @T.prim_func
def kernel( def kernel(
G: T.Tensor(G_shape, dtype=input_dtype), G: T.Tensor(G_shape, dtype=input_dtype),
G_new: T.Tensor(G_shape, dtype=output_dtype), G_new: T.Tensor(G_shape, dtype=output_dtype),
): ):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared") G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared")
if head_first: if head_first:
T.copy(G[bb, bh, bs * block_S:(bs + 1) * block_S], G_shared) T.copy(G[bb, bh, bs * block_S : (bs + 1) * block_S], G_shared)
else: else:
T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared) T.copy(G[bb, bs * block_S : (bs + 1) * block_S, bh], G_shared)
if use_fragment: if use_fragment:
G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared") G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared")
T.copy(G_shared, G_fragment) T.copy(G_shared, G_fragment)
T.cumsum(G_fragment, dim=1, reverse=reverse) T.cumsum(G_fragment, dim=1, reverse=reverse)
if head_first: if head_first:
T.copy(G_fragment, G_new[bb, bh, bs * block_S:(bs + 1) * block_S]) T.copy(G_fragment, G_new[bb, bh, bs * block_S : (bs + 1) * block_S])
else: else:
T.copy(G_fragment, G_new[bb, bs * block_S:(bs + 1) * block_S, bh]) T.copy(G_fragment, G_new[bb, bs * block_S : (bs + 1) * block_S, bh])
else: else:
T.cumsum(G_shared, dim=1, reverse=reverse) T.cumsum(G_shared, dim=1, reverse=reverse)
if head_first: if head_first:
T.copy(G_shared, G_new[bb, bh, bs * block_S:(bs + 1) * block_S]) T.copy(G_shared, G_new[bb, bh, bs * block_S : (bs + 1) * block_S])
else: else:
T.copy(G_shared, G_new[bb, bs * block_S:(bs + 1) * block_S, bh]) T.copy(G_shared, G_new[bb, bs * block_S : (bs + 1) * block_S, bh])
return kernel return kernel
...@@ -113,11 +111,8 @@ def run_test( ...@@ -113,11 +111,8 @@ def run_test(
# reference cumsum # reference cumsum
G_new_ref = chunk_local_cumsum_scalar( G_new_ref = chunk_local_cumsum_scalar(
g=G, g=G, chunk_size=chunk_size, reverse=reverse, head_first=head_first, output_dtype=getattr(torch, output_dtype)
chunk_size=chunk_size, )
reverse=reverse,
head_first=head_first,
output_dtype=getattr(torch, output_dtype))
# tilelang cumsum # tilelang cumsum
block_S = chunk_size block_S = chunk_size
...@@ -162,7 +157,8 @@ def main(): ...@@ -162,7 +157,8 @@ def main():
input_dtype="float32", input_dtype="float32",
output_dtype="float32", output_dtype="float32",
threads=256, threads=256,
use_fragment=False) use_fragment=False,
)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -9,6 +9,7 @@ import sys # noqa: F401 ...@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
try: try:
import fla import fla
print(fla.__file__) print(fla.__file__)
from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd
except ImportError: except ImportError:
...@@ -73,13 +74,13 @@ def tilelang_recompute_w_u_fwd( ...@@ -73,13 +74,13 @@ def tilelang_recompute_w_u_fwd(
@T.prim_func @T.prim_func
def kernel( def kernel(
K: T.Tensor(K_shape, dtype=input_dtype), K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype), V: T.Tensor(V_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype), Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype), G: T.Tensor(G_shape, dtype=gate_dtype),
A: T.Tensor(A_shape, dtype=output_dtype), A: T.Tensor(A_shape, dtype=output_dtype),
W: T.Tensor(K_shape, dtype=output_dtype), W: T.Tensor(K_shape, dtype=output_dtype),
U: T.Tensor(V_shape, dtype=output_dtype), U: T.Tensor(V_shape, dtype=output_dtype),
): ):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
...@@ -95,49 +96,42 @@ def tilelang_recompute_w_u_fwd( ...@@ -95,49 +96,42 @@ def tilelang_recompute_w_u_fwd(
W_Beta_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) W_Beta_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
U_Beta_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) U_Beta_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
T.annotate_layout({ T.annotate_layout(
K_shared: tilelang.layout.make_swizzled_layout(K_shared), {
V_shared: tilelang.layout.make_swizzled_layout(V_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared), V_shared: tilelang.layout.make_swizzled_layout(V_shared),
W_shared: tilelang.layout.make_swizzled_layout(W_shared), A_shared: tilelang.layout.make_swizzled_layout(A_shared),
U_shared: tilelang.layout.make_swizzled_layout(U_shared), W_shared: tilelang.layout.make_swizzled_layout(W_shared),
W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared), U_shared: tilelang.layout.make_swizzled_layout(U_shared),
U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared), W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared),
}) U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared),
}
)
T.disable_warp_group_reg_alloc() T.disable_warp_group_reg_alloc()
for i_s in T.Parallel(block_S): for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh]) G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh])
T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared)
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy( T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared)
V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV],
V_shared)
for i_s, i_v2 in T.Parallel(block_S, block_DV): for i_s, i_v2 in T.Parallel(block_S, block_DV):
U_Beta_shared[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s] U_Beta_shared[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s]
T.gemm(A_shared, U_Beta_shared, U_fragment, clear_accum=True) T.gemm(A_shared, U_Beta_shared, U_fragment, clear_accum=True)
# First copy to smem, then copy to gmem to reduce U2RU instructions # First copy to smem, then copy to gmem to reduce U2RU instructions
T.copy(U_fragment, U_shared) T.copy(U_fragment, U_shared)
T.copy( T.copy(U_shared, U[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV])
U_shared, U[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV])
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy( T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
W_Beta_shared[i_s, W_Beta_shared[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s]
i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s]
T.gemm(A_shared, W_Beta_shared, W_fragment, clear_accum=True) T.gemm(A_shared, W_Beta_shared, W_fragment, clear_accum=True)
# First copy to smem, then copy to gmem to reduce U2RU instructions # First copy to smem, then copy to gmem to reduce U2RU instructions
T.copy(W_fragment, W_shared) T.copy(W_fragment, W_shared)
T.copy( T.copy(W_shared, W[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK])
W_shared, W[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK])
return kernel return kernel
...@@ -159,15 +153,8 @@ def run_test( ...@@ -159,15 +153,8 @@ def run_test(
num_stages, num_stages,
): ):
K, V, Beta, G, A = prepare_input( K, V, Beta, G, A = prepare_input(
B, B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype)
S, )
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
gate_dtype=getattr(torch, gate_dtype))
W_ref, U_ref = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype)) W_ref, U_ref = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype)) W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
...@@ -191,7 +178,8 @@ def run_test( ...@@ -191,7 +178,8 @@ def run_test(
block_DK=block_DK, block_DK=block_DK,
block_DV=block_DV, block_DV=block_DV,
threads=threads, threads=threads,
num_stages=num_stages) num_stages=num_stages,
)
print(kernel.get_kernel_source()) print(kernel.get_kernel_source())
W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) W_tilelang, U_tilelang = kernel(K, V, Beta, G, A)
...@@ -224,7 +212,8 @@ def main(): ...@@ -224,7 +212,8 @@ def main():
block_DK=64, block_DK=64,
block_DV=32, block_DV=32,
threads=128, threads=128,
num_stages=3) num_stages=3,
)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -10,6 +10,7 @@ import tilelang.language as T ...@@ -10,6 +10,7 @@ import tilelang.language as T
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
try: try:
import fla import fla
print(fla.__file__) print(fla.__file__)
from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr
except ImportError: except ImportError:
...@@ -93,10 +94,8 @@ def prepare_output( ...@@ -93,10 +94,8 @@ def prepare_output(
@tilelang.jit( @tilelang.jit(
out_idx=[-5, -4, -3, -2, -1], out_idx=[-5, -4, -3, -2, -1],
pass_configs={ pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True},
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, )
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_wy_fast_bwd( def tilelang_wy_fast_bwd(
# task config # task config
B, B,
...@@ -135,20 +134,20 @@ def tilelang_wy_fast_bwd( ...@@ -135,20 +134,20 @@ def tilelang_wy_fast_bwd(
@T.prim_func @T.prim_func
def kernel( def kernel(
# input # input
K: T.Tensor(K_shape, dtype=input_dtype), K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype), V: T.Tensor(V_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype), Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype), G: T.Tensor(G_shape, dtype=gate_dtype),
A: T.Tensor(A_shape, dtype=input_dtype), A: T.Tensor(A_shape, dtype=input_dtype),
dw: T.Tensor(dw_shape, dtype=input_dtype), dw: T.Tensor(dw_shape, dtype=input_dtype),
du: T.Tensor(du_shape, dtype=input_dtype), du: T.Tensor(du_shape, dtype=input_dtype),
# output # output
dA: T.Tensor(dA_shape, dtype=input_dtype), dA: T.Tensor(dA_shape, dtype=input_dtype),
dk: T.Tensor(dk_shape, dtype=output_dtype), dk: T.Tensor(dk_shape, dtype=output_dtype),
dv: T.Tensor(dv_shape, dtype=output_dtype), dv: T.Tensor(dv_shape, dtype=output_dtype),
dbeta: T.Tensor(dbeta_shape, dtype=output_dtype), dbeta: T.Tensor(dbeta_shape, dtype=output_dtype),
dg: T.Tensor(dg_shape, dtype=gate_dtype), dg: T.Tensor(dg_shape, dtype=gate_dtype),
): ):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
...@@ -187,7 +186,7 @@ def tilelang_wy_fast_bwd( ...@@ -187,7 +186,7 @@ def tilelang_wy_fast_bwd(
T.clear(dbeta_fragment_v) T.clear(dbeta_fragment_v)
T.clear(dg_fragment) T.clear(dg_fragment)
T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared)
for i_s in T.Parallel(block_S): for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = G[bb, bs * block_S + i_s, bh] G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
...@@ -195,51 +194,37 @@ def tilelang_wy_fast_bwd( ...@@ -195,51 +194,37 @@ def tilelang_wy_fast_bwd(
# Update dk # Update dk
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy( T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
K_shared_beta_g[i_s, K_shared_beta_g[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
i_k2] = K_shared[i_s, T.copy(dw[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dw_shared)
i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
T.copy(
dw[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK], dw_shared)
T.gemm(dw_shared, K_shared_beta_g, dA_fragment, transpose_B=True) T.gemm(dw_shared, K_shared_beta_g, dA_fragment, transpose_B=True)
T.gemm(A_shared, dw_shared, dk_fragment_beta_g, clear_accum=True, transpose_A=True) T.gemm(A_shared, dw_shared, dk_fragment_beta_g, clear_accum=True, transpose_A=True)
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_fragment[ dk_fragment[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
i_s,
i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
# for i_s, i_k2 in T.Parallel(block_S, block_DK): # for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[ dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False) T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False)
# for i_s, i_k2 in T.Parallel(block_S, block_DK): # for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] # dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
dg_fragment_reduce_tmp[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[ dg_fragment_reduce_tmp[i_s, i_k2] = (
i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
)
T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=1, clear=False) T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=1, clear=False)
# correct dk # correct dk
T.copy( T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK])
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK])
# Update dv # Update dv
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy( T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared)
V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV],
V_shared)
for i_s, i_v2 in T.Parallel(block_S, block_DV): for i_s, i_v2 in T.Parallel(block_S, block_DV):
V_shared_beta[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s] V_shared_beta[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s]
T.copy( T.copy(du[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], du_shared)
du[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV], du_shared)
T.gemm(du_shared, V_shared_beta, dA_fragment, transpose_B=True) T.gemm(du_shared, V_shared_beta, dA_fragment, transpose_B=True)
T.gemm(A_shared, du_shared, dv_fragment_beta, clear_accum=True, transpose_A=True) T.gemm(A_shared, du_shared, dv_fragment_beta, clear_accum=True, transpose_A=True)
for i_s, i_v2 in T.Parallel(block_S, block_DV): for i_s, i_v2 in T.Parallel(block_S, block_DV):
...@@ -247,30 +232,22 @@ def tilelang_wy_fast_bwd( ...@@ -247,30 +232,22 @@ def tilelang_wy_fast_bwd(
# for i_s, i_v2 in T.Parallel(block_S, block_DV): # for i_s, i_v2 in T.Parallel(block_S, block_DV):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2] # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2]
for i_s, i_v2 in T.Parallel(block_S, block_DV): for i_s, i_v2 in T.Parallel(block_S, block_DV):
dbeta_fragment_reduce_tmpv[i_s, dbeta_fragment_reduce_tmpv[i_s, i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2]
i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s,
i_v2]
T.reduce_sum(dbeta_fragment_reduce_tmpv, dbeta_fragment_v, dim=1, clear=False) T.reduce_sum(dbeta_fragment_reduce_tmpv, dbeta_fragment_v, dim=1, clear=False)
T.copy( T.copy(dv_fragment, dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV])
dv_fragment, dv[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV])
# Temporary store dbeta, dg and dA # Temporary store dbeta, dg and dA
for i_s in T.Parallel(block_S): for i_s in T.Parallel(block_S):
dbeta[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + dbeta_fragment_v[i_s] dbeta[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + dbeta_fragment_v[i_s]
dg[bb, bs * block_S + i_s, bh] = dg_fragment[i_s] dg[bb, bs * block_S + i_s, bh] = dg_fragment[i_s]
# correct dA # correct dA
T.copy(dA_fragment, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :]) T.copy(dA_fragment, dA[bb, bs * block_S : (bs + 1) * block_S, bh, :])
return kernel return kernel
@tilelang.jit( @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True})
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_wy_fast_bwd_split( def tilelang_wy_fast_bwd_split(
# task config # task config
B, B,
...@@ -308,20 +285,20 @@ def tilelang_wy_fast_bwd_split( ...@@ -308,20 +285,20 @@ def tilelang_wy_fast_bwd_split(
@T.prim_func @T.prim_func
def kernel( def kernel(
# input # input
K: T.Tensor(K_shape, dtype=input_dtype), K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype), V: T.Tensor(V_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype), Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype), G: T.Tensor(G_shape, dtype=gate_dtype),
A: T.Tensor(A_shape, dtype=input_dtype), A: T.Tensor(A_shape, dtype=input_dtype),
dw: T.Tensor(dw_shape, dtype=input_dtype), dw: T.Tensor(dw_shape, dtype=input_dtype),
du: T.Tensor(du_shape, dtype=input_dtype), du: T.Tensor(du_shape, dtype=input_dtype),
dA: T.Tensor(dA_shape, dtype=input_dtype), dA: T.Tensor(dA_shape, dtype=input_dtype),
dk: T.Tensor(dk_shape, dtype=output_dtype), dk: T.Tensor(dk_shape, dtype=output_dtype),
dv: T.Tensor(dv_shape, dtype=output_dtype), dv: T.Tensor(dv_shape, dtype=output_dtype),
dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype), dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype),
dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype), dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype),
dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype), dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype),
): ):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
...@@ -350,7 +327,7 @@ def tilelang_wy_fast_bwd_split( ...@@ -350,7 +327,7 @@ def tilelang_wy_fast_bwd_split(
T.clear(dA_A_fragment_1) T.clear(dA_A_fragment_1)
T.clear(dA_A_fragment_2) T.clear(dA_A_fragment_2)
T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared)
for i_s in T.Parallel(block_S): for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = G[bb, bs * block_S + i_s, bh] G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
...@@ -361,7 +338,7 @@ def tilelang_wy_fast_bwd_split( ...@@ -361,7 +338,7 @@ def tilelang_wy_fast_bwd_split(
# for i_s in T.Parallel(block_S): # for i_s in T.Parallel(block_S):
# dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh] # dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh]
# dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh] # dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh]
T.copy(dA[bb, bs * block_S:(bs + 1) * block_S, bh, :], dA_shared) T.copy(dA[bb, bs * block_S : (bs + 1) * block_S, bh, :], dA_shared)
# T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :]) # T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :])
# Update dA # Update dA
...@@ -385,8 +362,7 @@ def tilelang_wy_fast_bwd_split( ...@@ -385,8 +362,7 @@ def tilelang_wy_fast_bwd_split(
for i_s1, i_s2 in T.Parallel(block_S, block_S): for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
with T.Then(): with T.Then():
dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh])
G[bb, bs * block_S + i_s2, bh])
with T.Else(): with T.Else():
dA_fragment[i_s1, i_s2] = 0 dA_fragment[i_s1, i_s2] = 0
T.copy(dA_fragment, dA_shared) T.copy(dA_fragment, dA_shared)
...@@ -397,12 +373,8 @@ def tilelang_wy_fast_bwd_split( ...@@ -397,12 +373,8 @@ def tilelang_wy_fast_bwd_split(
# Update dk using previous dk # Update dk using previous dk
T.clear(A_fragment) T.clear(A_fragment)
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy( T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], T.copy(dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dk_shared)
K_shared)
T.copy(
dk[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK], dk_shared)
T.copy(dk_shared, dk_fragment) T.copy(dk_shared, dk_fragment)
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
K_shared_beta[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] K_shared_beta[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s]
...@@ -411,18 +383,14 @@ def tilelang_wy_fast_bwd_split( ...@@ -411,18 +383,14 @@ def tilelang_wy_fast_bwd_split(
# for i_s, i_k2 in T.Parallel(block_S, block_DK): # for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2] # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2]
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
dbeta_fragment_reduce_tmpk[i_s, dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2]
i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s,
i_k2]
T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False) T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False)
T.gemm(dA_shared, K_shared_beta, dk_fragment, transpose_A=True) T.gemm(dA_shared, K_shared_beta, dk_fragment, transpose_A=True)
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_shared_beta[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * Beta_shared[i_s] dk_shared_beta[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * Beta_shared[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_fragment[i_s, i_k2] = dk_fragment[i_s, i_k2] + dk_shared_beta[i_s, i_k2] dk_fragment[i_s, i_k2] = dk_fragment[i_s, i_k2] + dk_shared_beta[i_s, i_k2]
T.copy( T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK])
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK])
# Update dg and dbeta # Update dg and dbeta
T.copy(A_fragment, A_shared) T.copy(A_fragment, A_shared)
...@@ -460,19 +428,25 @@ def run_test( ...@@ -460,19 +428,25 @@ def run_test(
threads=128, threads=128,
num_stages=0, num_stages=0,
): ):
K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size, K, V, Beta, G, A, dw, du = prepare_input(
getattr(torch, input_dtype), B,
getattr(torch, output_dtype), S,
getattr(torch, H,
accum_dtype), getattr(torch, gate_dtype), DK,
getattr(torch, state_dtype)) DV,
dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size, chunk_size,
getattr(torch, output_dtype), getattr(torch, input_dtype),
getattr(torch, gate_dtype), getattr(torch, output_dtype),
getattr(torch, state_dtype)) getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output( dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
getattr(torch, state_dtype)) )
BS = chunk_size BS = chunk_size
dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda() dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda()
dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda() dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda()
...@@ -480,28 +454,55 @@ def run_test( ...@@ -480,28 +454,55 @@ def run_test(
dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
# ref # ref
dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr( dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr(K, V, G, Beta, A, dw, du, cu_seqlens=None)
K, V, G, Beta, A, dw, du, cu_seqlens=None)
# tilelang # tilelang
kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, kernel = tilelang_wy_fast_bwd(
gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads, B,
num_stages) S,
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel( H,
K, V, Beta, G, A, dw, du) DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK,
block_DV,
threads,
num_stages,
)
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du)
torch.cuda.synchronize() torch.cuda.synchronize()
kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype, kernel_split = tilelang_wy_fast_bwd_split(
accum_dtype, gate_dtype, state_dtype, chunk_size, B,
block_DK, block_DV, threads, num_stages) S,
kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, H,
dg_tilelang_A_positive, dg_tilelang_A_negative) DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK,
block_DV,
threads,
num_stages,
)
kernel_split(
K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative
)
torch.cuda.synchronize() torch.cuda.synchronize()
dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum( dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1)
dim=-1)
from test_utils import assert_similar from test_utils import assert_similar
assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False) assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False)
assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False) assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False)
assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False) assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False)
......
...@@ -25,16 +25,10 @@ num_stages = 1 ...@@ -25,16 +25,10 @@ num_stages = 1
def test_example_wy_fast_compilation(): def test_example_wy_fast_compilation():
from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input
K, V, Beta, G, A = prepare_input( K, V, Beta, G, A = prepare_input(
B, B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype)
S, )
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
gate_dtype=getattr(torch, gate_dtype))
# tilelang # tilelang
block_S = chunk_size block_S = chunk_size
kernel = tilelang_recompute_w_u_fwd( kernel = tilelang_recompute_w_u_fwd(
...@@ -52,22 +46,31 @@ def test_example_wy_fast_compilation(): ...@@ -52,22 +46,31 @@ def test_example_wy_fast_compilation():
block_DK=block_DK, block_DK=block_DK,
block_DV=block_DV, block_DV=block_DV,
threads=threads, threads=threads,
num_stages=num_stages) num_stages=num_stages,
)
print(kernel.get_kernel_source()) print(kernel.get_kernel_source())
W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) W_tilelang, U_tilelang = kernel(K, V, Beta, G, A)
def test_example_wy_fast_bwd_split_compilation(): def test_example_wy_fast_bwd_split_compilation():
from example_wy_fast_bwd_split import tilelang_wy_fast_bwd, tilelang_wy_fast_bwd_split, prepare_input, prepare_output from example_wy_fast_bwd_split import tilelang_wy_fast_bwd, tilelang_wy_fast_bwd_split, prepare_input, prepare_output
K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype), K, V, Beta, G, A, dw, du = prepare_input(
getattr(torch, output_dtype), B,
getattr(torch, S,
accum_dtype), getattr(torch, gate_dtype), H,
getattr(torch, state_dtype)) DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output( dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
getattr(torch, state_dtype)) )
BS = chunk_size BS = chunk_size
dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda() dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda()
dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda() dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda()
...@@ -75,68 +78,146 @@ def test_example_wy_fast_bwd_split_compilation(): ...@@ -75,68 +78,146 @@ def test_example_wy_fast_bwd_split_compilation():
dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
# tilelang # tilelang
kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, kernel = tilelang_wy_fast_bwd(
gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads, B,
num_stages) S,
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel( H,
K, V, Beta, G, A, dw, du) DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK,
block_DV,
threads,
num_stages,
)
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du)
torch.cuda.synchronize() torch.cuda.synchronize()
kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype, kernel_split = tilelang_wy_fast_bwd_split(
accum_dtype, gate_dtype, state_dtype, chunk_size, B,
block_DK, block_DV, threads, num_stages) S,
kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, H,
dg_tilelang_A_positive, dg_tilelang_A_negative) DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK,
block_DV,
threads,
num_stages,
)
kernel_split(
K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative
)
torch.cuda.synchronize() torch.cuda.synchronize()
dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum( dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1)
dim=-1)
def test_example_chunk_o_compilation(): def test_example_chunk_o_compilation():
from example_chunk_o import tilelang_chunk_fwd_o, prepare_input from example_chunk_o import tilelang_chunk_fwd_o, prepare_input
Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype), Q, K, V, HIDDEN, G = prepare_input(
getattr(torch, gate_dtype)) B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
)
scale = 1.0 / DK**0.5 scale = 1.0 / DK**0.5
block_S = chunk_size block_S = chunk_size
kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, kernel = tilelang_chunk_fwd_o(
gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV, B,
threads, num_stages) S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
chunk_size,
scale,
use_g,
block_S,
block_DK,
block_DV,
threads,
num_stages,
)
O_tilelang = kernel(Q, K, V, HIDDEN, G) # noqa: F841 O_tilelang = kernel(Q, K, V, HIDDEN, G) # noqa: F841
def test_example_chunk_o_bwd_compilation(): def test_example_chunk_o_bwd_compilation():
from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input
Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype), Q, K, V, h, G, dO, dh, dv, W = prepare_input(
getattr(torch, output_dtype), B,
getattr(torch, accum_dtype), S,
getattr(torch, gate_dtype), H,
getattr(torch, state_dtype)) DK,
kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, DV,
gate_dtype, state_dtype, chunk_size, 1.0, use_g, True, chunk_size,
block_DK, block_DV, threads, num_stages) getattr(torch, input_dtype),
getattr(torch, output_dtype),
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, getattr(torch, accum_dtype),
W) # noqa: F841 getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
kernel = tilelang_chunk_o_bwd_dqkwg(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
1.0,
use_g,
True,
block_DK,
block_DV,
threads,
num_stages,
)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) # noqa: F841
if use_g: if use_g:
dg_tilelang = dg_tilelang.sum(dim=0) dg_tilelang = dg_tilelang.sum(dim=0)
def test_example_chunk_scaled_dot_kkt_compilation(): def test_example_chunk_scaled_dot_kkt_compilation():
from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype)) K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype))
block_S = chunk_size block_S = chunk_size
kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype, kernel = tilelang_chunk_scaled_dot_kkt_fwd(
accum_dtype, use_g, block_S, block_DK, threads, B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages
num_stages) )
A_tilelang = kernel(K, Beta, G) # noqa: F841 A_tilelang = kernel(K, Beta, G) # noqa: F841
def test_example_cumsum_compilation(): def test_example_cumsum_compilation():
from example_cumsum import tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output from example_cumsum import tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output
G = prepare_cumsum_input(B, S, H, getattr(torch, gate_dtype)) G = prepare_cumsum_input(B, S, H, getattr(torch, gate_dtype))
G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, gate_dtype)) G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, gate_dtype))
block_S = chunk_size block_S = chunk_size
...@@ -158,33 +239,79 @@ def test_example_cumsum_compilation(): ...@@ -158,33 +239,79 @@ def test_example_cumsum_compilation():
def test_example_chunk_delta_h_compilation(): def test_example_chunk_delta_h_compilation():
from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input
K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype), K, W, U, G, initial_state = prepare_input(
getattr(torch, output_dtype), B,
getattr(torch, accum_dtype), S,
getattr(torch, gate_dtype)) H,
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype, DK,
accum_dtype, gate_dtype, state_dtype, chunk_size, DV,
use_g, use_initial_state, store_final_state, chunk_size,
save_new_value, block_DK, block_DV, threads, getattr(torch, input_dtype),
num_stages) getattr(torch, output_dtype),
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, getattr(torch, accum_dtype),
initial_state) # noqa: F841 getattr(torch, gate_dtype),
)
kernel = tilelang_chunk_gated_delta_rule_fwd_h(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
use_g,
use_initial_state,
store_final_state,
save_new_value,
block_DK,
block_DV,
threads,
num_stages,
)
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) # noqa: F841
def test_example_chunk_delta_bwd_compilation(): def test_example_chunk_delta_bwd_compilation():
from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input
Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype), Q, K, W, G, h0, dht, dO, dv = prepare_input(
getattr(torch, output_dtype), B,
getattr(torch, accum_dtype), S,
getattr(torch, gate_dtype), H,
getattr(torch, state_dtype)) DK,
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype, DV,
accum_dtype, gate_dtype, state_dtype, chunk_size,
chunk_size, 1.0, use_g, use_initial_state, getattr(torch, input_dtype),
use_final_state_gradient, block_DV, threads, getattr(torch, output_dtype),
num_stages) getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
1.0,
use_g,
use_initial_state,
use_final_state_gradient,
block_DV,
threads,
num_stages,
)
dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) # noqa: F841 dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) # noqa: F841
......
...@@ -9,7 +9,7 @@ def calc_sim(x, y, name="tensor"): ...@@ -9,7 +9,7 @@ def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double() x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum() denominator = (x * x + y * y).sum()
if denominator == 0: if denominator == 0:
print_red_warning(f'{name} all zero') print_red_warning(f"{name} all zero")
return 1 return 1
sim = 2 * (x * y).sum() / denominator sim = 2 * (x * y).sum() / denominator
return sim return sim
...@@ -19,21 +19,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True): ...@@ -19,21 +19,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
x_mask = torch.isfinite(x) x_mask = torch.isfinite(x)
y_mask = torch.isfinite(y) y_mask = torch.isfinite(y)
if not torch.all(x_mask == y_mask): if not torch.all(x_mask == y_mask):
print_red_warning(f'{name} Error: isfinite mask mismatch') print_red_warning(f"{name} Error: isfinite mask mismatch")
if raise_assert: if raise_assert:
raise AssertionError raise AssertionError
if not torch.isclose( if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, print_red_warning(f"{name} Error: nonfinite value mismatch")
equal_nan=True).all():
print_red_warning(f'{name} Error: nonfinite value mismatch')
if raise_assert: if raise_assert:
raise AssertionError raise AssertionError
x = x.masked_fill(~x_mask, 0) x = x.masked_fill(~x_mask, 0)
y = y.masked_fill(~y_mask, 0) y = y.masked_fill(~y_mask, 0)
sim = calc_sim(x, y, name) sim = calc_sim(x, y, name)
diff = 1. - sim diff = 1.0 - sim
if not (0 <= diff <= eps): if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff}') print_red_warning(f"{name} Error: {diff}")
if raise_assert: if raise_assert:
raise AssertionError raise AssertionError
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment