Unverified commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phase out Yapf format and embrace ruff format (#1417)

parent e84b24bc
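
The hunks below are purely mechanical reformatting: lines previously wrapped by yapf are re-emitted by ruff format, which prefers double-quoted strings, a longer line length, spaces around ** and around the : in slices with compound bounds, and a trailing comma when a call stays multi-line. A minimal, self-contained sketch of the before/after follows (illustrative only; the repository's actual ruff configuration is not shown in this excerpt, and the variable names are made up for the example):

# Illustrative sketch only (not part of the commit): the same expressions written in the
# two styles that this diff swaps.
dim, block, k = 128, 64, 2

# yapf-era formatting (the style being removed): no space around "**",
# tight slice bounds, aggressively wrapped call arguments.
scale_old = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
window_old = list(range(256))[k * block:(k + 1) * block]

# ruff format output (the style being added): spaces around "**" and around ":" when
# slice bounds are compound expressions, plus longer lines and double-quoted strings.
scale_new = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
window_new = list(range(256))[k * block : (k + 1) * block]

assert scale_old == scale_new and window_old == window_new  # formatting only; behavior unchanged
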
......@@ -47,7 +47,7 @@ def attention_ref(
if upcast:
q, k, v = q.float(), k.float(), v.float()
dim = q.shape[-1]
scale = (1.0 / dim)**0.5 # log2(e)
scale = (1.0 / dim) ** 0.5 # log2(e)
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
scores = torch.einsum("bthd,bshd->bhts", q, k)
......@@ -68,20 +68,13 @@ def attention_ref(
@tilelang.jit(
out_idx=[6], pass_configs={
out_idx=[6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch_size,
UQ,
UKV,
heads,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=32):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
},
)
def flashattn(batch_size, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64, num_stages=0, threads=32):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
q_shape = [UQ, heads, dim]
k_shape = [UKV, heads, dim]
v_shape = [UKV, heads, dim]
......@@ -100,9 +93,7 @@ def flashattn(batch_size,
max_seqlen_q: T.int32,
Output_unpad: T.Tensor(o_shape, dtype),
):
with T.Kernel(
T.ceildiv(max_seqlen_q, block_M), heads, batch_size,
threads=threads) as (bx, by, bz):
with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype, "shared")
K_shared = T.alloc_shared([block_N, dim], dtype, "shared")
V_shared = T.alloc_shared([block_N, dim], dtype, "shared")
......@@ -151,15 +142,17 @@ def flashattn(batch_size,
K_shared[i, d] = 0
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
(bx * block_M + i >= q_current_seqlen or
k * block_N + j >= k_current_seqlen),
-T.infinity(acc_s.dtype), 0)
acc_s[i, j] = T.if_then_else(
(bx * block_M + i >= k * block_N + j)
and (bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen),
-T.infinity(acc_s.dtype),
0,
)
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
k * block_N + j >= k_current_seqlen),
-T.infinity(acc_s.dtype), 0)
acc_s[i, j] = T.if_then_else(
(bx * block_M + i >= q_current_seqlen or k * block_N + j >= k_current_seqlen), -T.infinity(acc_s.dtype), 0
)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......@@ -244,8 +237,7 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128):
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(
q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
UQ = q_unpad.shape[0] # unpadded query length
UK = k_unpad.shape[0] # unpadded key length
......@@ -287,10 +279,10 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=64, help='heads')
parser.add_argument('--seq_len', type=int, default=2048, help='sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=64, help="heads")
parser.add_argument("--seq_len", type=int, default=2048, help="sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim)
......@@ -62,14 +62,12 @@ def test_example_mha_bwd_wgmma_pipelined():
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_gqa_fwd_bshd_wgmma_pipelined():
example_gqa_fwd_bshd_wgmma_pipelined.main(
batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)
example_gqa_fwd_bshd_wgmma_pipelined.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)
@tilelang.testing.requires_cuda
def test_example_gqa_fwd_bshd():
example_gqa_fwd_bshd.main(
batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)
example_gqa_fwd_bshd.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False)
@tilelang.testing.requires_cuda
......
......@@ -9,22 +9,14 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
if mode == "full":
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
elif mode == "random":
lengths = torch.randint(
max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
elif mode == "third":
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
padding_mask = (
repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths)
padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
return padding_mask
def generate_qkv(q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False):
def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
......@@ -39,15 +31,12 @@ def generate_qkv(q,
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q
)
output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device)
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange(
output_unpad, "(b s) h d -> b s h d", b=batch_size)
output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size)
if key_padding_mask is not None:
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
......@@ -55,8 +44,7 @@ def generate_qkv(q,
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device)
max_seqlen_k = seqlen_k
if qkvpacked:
......@@ -67,8 +55,7 @@ def generate_qkv(q,
if query_padding_mask is not None:
dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
else:
dqkv_pad_fn = lambda dqkv_unpad: rearrange(
dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
return (
qkv_unpad.detach().requires_grad_(),
cu_seqlens_q,
......@@ -84,8 +71,7 @@ def generate_qkv(q,
if key_padding_mask is not None:
dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
else:
dkv_pad_fn = lambda dkv_unpad: rearrange(
dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size)
return (
q_unpad.detach().requires_grad_(),
kv_unpad.detach().requires_grad_(),
......
......@@ -20,13 +20,7 @@ def get_configs():
threads = [128]
_configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
configs = [{
'block_N': c[0],
'block_H': c[1],
'num_split': c[2],
'num_stages': c[3],
'threads': c[4]
} for c in _configs]
configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs]
return configs
......@@ -48,17 +42,13 @@ def get_heuristic_config() -> Tuple[Dict, int]:
# TODO(lei): fix warp specialized and tma lower pass
def get_pass_configs():
return {
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
}
return {tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[6], pass_configs=get_pass_configs())
def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages,
threads):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, threads):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, dim]
shape_k = [batch, seqlen_kv, groups, dim]
shape_v = [batch, seqlen_kv, groups, dim]
......@@ -98,20 +88,19 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
hid = by
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared)
T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local)
T.copy(K[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_shared)
T.copy(mask[bid, k * block_N : (k + 1) * block_N, cur_kv_head], mask_local)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j],
-T.infinity(accum_dtype))
acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], -T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
......@@ -127,14 +116,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.copy(V[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], V_shared)
T.copy(V[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])
T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
@T.macro
def flash_attn_split(
......@@ -165,7 +154,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -174,19 +163,26 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
K[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head, :], K_shared)
K[
bid,
(seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head,
:,
],
K_shared,
)
T.copy(
mask[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head], mask_local)
mask[
bid,
(seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head,
],
mask_local,
)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i,
j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split),
acc_s[i, j], -T.infinity(accum_dtype))
acc_s[i, j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j], -T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
......@@ -203,9 +199,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.copy(
V[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head, :], V_shared)
V[
bid,
(seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head,
:,
],
V_shared,
)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
......@@ -216,8 +217,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
if i < valid_block_H:
glse[bid, hid * valid_block_H + i, sid] = logsum[i]
T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output_partial[bid, hid * valid_block_H:(hid + 1) * valid_block_H,
sid, :])
T.copy(O_shared, Output_partial[bid, hid * valid_block_H : (hid + 1) * valid_block_H, sid, :])
@T.macro
def combine(
......@@ -233,12 +233,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
lse_max_local = T.alloc_fragment([128], accum_dtype)
scale_local = T.alloc_fragment([128], accum_dtype)
T.annotate_layout({
T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
# lse_local: (local_id, thread_id)
lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)),
})
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
......@@ -305,27 +307,21 @@ def ref_program(query, key, value, mask, glse, Output_partial):
dim = query.shape[-1]
num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim]
value = rearrange(value, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim]
key = rearrange(key, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
value = rearrange(value, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
query = rearrange(
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
if mask is not None:
mask = rearrange(mask, 'b s h -> b h s')
mask = rearrange(mask, "b s h -> b h s")
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask == 0, float('-inf'))
scores = scores.masked_fill(mask == 0, float("-inf"))
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, value,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
......@@ -339,16 +335,12 @@ def flash_split_ref(Q, K, V, mask):
seqlen_kv = K.size(1)
num_head_groups = nheads // groups
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
acc_s = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float)
acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N),
device="cuda",
dtype=torch.float16)
acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float16)
acc_o = torch.empty((batch, num_head_groups, groups, dim), device="cuda", dtype=torch.float)
scores_max = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
scores_max_prev = torch.empty((batch, num_head_groups, groups),
device="cuda",
dtype=torch.float)
scores_max_prev = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
scores_scale = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
scores_sum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
logsum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
......@@ -356,25 +348,25 @@ def flash_split_ref(Q, K, V, mask):
glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float)
Q_ = Q * scale
Q_ = rearrange(Q_, 'b (h g) d -> b g h d', g=num_head_groups)
Q_ = rearrange(Q_, "b (h g) d -> b g h d", g=num_head_groups)
for ks in range(num_split):
acc_o.fill_(0)
logsum.fill_(0)
scores_max.fill_(float('-inf'))
scores_max_prev.fill_(float('-inf'))
scores_max.fill_(float("-inf"))
scores_max_prev.fill_(float("-inf"))
for i in range(int((seqlen_kv // num_split) / block_N)):
acc_s.fill_(0)
acc_s = torch.einsum('bghd,bkhd->bghk', Q_,
K[:, (seqlen_kv // num_split) * ks +
i * block_N:(seqlen_kv // num_split) * ks +
(i + 1) * block_N, :, :]) # [batch, nheads, block_N]
acc_s = torch.einsum(
"bghd,bkhd->bghk",
Q_,
K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
) # [batch, nheads, block_N]
if mask is not None:
mask_local = mask[:, (seqlen_kv // num_split) * ks +
i * block_N:(seqlen_kv // num_split) * ks + (i + 1) * block_N, :]
mask_local = rearrange(mask_local, 'b s h -> b h s')
mask_local = mask[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :]
mask_local = rearrange(mask_local, "b s h -> b h s")
mask_local = mask_local.unsqueeze(1)
acc_s = acc_s.masked_fill(mask_local == 0, float('-inf'))
acc_s = acc_s.masked_fill(mask_local == 0, float("-inf"))
scores_max_prev = scores_max
scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads]
scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads]
......@@ -382,15 +374,16 @@ def flash_split_ref(Q, K, V, mask):
acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N]
acc_o += torch.einsum(
'bghk,bkhd->bghd', acc_s_cast,
V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
(i + 1) * block_N, :, :])
"bghk,bkhd->bghd",
acc_s_cast,
V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
)
scores_sum = acc_s.sum(dim=-1, keepdim=False)
logsum = logsum * scores_scale + scores_sum
acc_o_out = rearrange(acc_o, 'b g h d->b (h g) d')
logsum_out = rearrange(logsum, 'b g h->b (h g)')
acc_o_out = rearrange(acc_o, "b g h d->b (h g) d")
logsum_out = rearrange(logsum, "b g h->b (h g)")
acc_o_out /= logsum_out[:, :, None]
logsum_out = torch.log2(logsum_out) + rearrange(scores_max, 'b g h->b (h g)')
logsum_out = torch.log2(logsum_out) + rearrange(scores_max, "b g h->b (h g)")
gacc_o[ks, :, :, :] = acc_o_out
glogsum[ks, :, :] = logsum_out
......@@ -426,7 +419,7 @@ def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print_red_warning(f'{name} all zero')
print_red_warning(f"{name} all zero")
return 1
sim = 2 * (x * y).sum() / denominator
return sim
......@@ -434,28 +427,23 @@ def calc_sim(x, y, name="tensor"):
def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True):
sim = calc_sim(x, y, name)
diff = 1. - sim
diff = 1.0 - sim
if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff}')
print_red_warning(f"{name} Error: {diff}")
if assert_:
raise AssertionError(f'{name} Error: {diff}')
raise AssertionError(f"{name} Error: {diff}")
else:
if print_:
print(f'passed: {name} diff={diff}')
print(f"passed: {name} diff={diff}")
def main(batch: int = 1,
heads: int = 32,
groups: int = 8,
kv_seqlen: int = 8192,
dim: int = 128,
tune: bool = False):
def main(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192, dim: int = 128, tune: bool = False):
batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim
qk_flops = 2 * batch * heads * kv_seqlen * dim
pv_flops = 2 * batch * heads * kv_seqlen * dim
total_flops = qk_flops + pv_flops
if (not tune):
if not tune:
config, sm_version = get_heuristic_config()
kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
......@@ -497,11 +485,11 @@ def main(batch: int = 1,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--groups', type=int, default=8, help='groups')
parser.add_argument('--kv_seqlen', type=int, default=8192, help='kv sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--groups", type=int, default=8, help="groups")
parser.add_argument("--kv_seqlen", type=int, default=8192, help="kv sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune)
......@@ -19,8 +19,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen,
head_dim)
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
......@@ -74,14 +73,9 @@ def _fwd_inner(
return m_i, l_i, acc
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [4, 8]\
for num_stages in [2, 4]\
],
key=['gqa_group_size', 'BLOCK_N', 'BLOCK_D', 'BLOCK_H'],
configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [4, 8] for num_stages in [2, 4]],
key=["gqa_group_size", "BLOCK_N", "BLOCK_D", "BLOCK_H"],
)
@triton.jit
def _fwd_kernel_varlen(
......@@ -107,13 +101,12 @@ def _fwd_kernel_varlen(
stride_od,
stride_sb,
stride_sh,
stride_sn, #bmask shape [b, q_h, seq/BLOCK_N]
stride_sn, # bmask shape [b, q_h, seq/BLOCK_N]
gqa_group_size: tl.constexpr,
BLOCK_H: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr,
):
off_z = tl.program_id(0)
off_h_for_kv = tl.program_id(1)
off_h_q = off_h_for_kv * gqa_group_size
......@@ -134,8 +127,7 @@ def _fwd_kernel_varlen(
S_ptrs = S + off_z * stride_sb + off_h_q * stride_sh
mask_h = offs_h < gqa_group_size
q = tl.load(
Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None])
q = tl.load(Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None])
if s_aux is not None:
sink = tl.load(s_aux + off_h_q + offs_h, mask=mask_h).to(tl.float32)
......@@ -189,14 +181,12 @@ def _fwd_kernel_varlen(
acc = acc.to(O.dtype.element_ty)
tl.store(
O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od,
acc,
mask=mask_h[:, None])
tl.store(O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od, acc, mask=mask_h[:, None])
def get_configs():
import itertools
block_N = [64, 128]
block_H = [64]
num_split = [1]
......@@ -204,31 +194,16 @@ def get_configs():
threads = [128]
_configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
configs = [{
'block_N': c[0],
'block_H': c[1],
'num_split': c[2],
'num_stages': c[3],
'threads': c[4]
} for c in _configs]
configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs]
return configs
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding")
def flashattn(batch,
heads,
k_heads,
max_seqlen_kv,
total_seqlen_k,
dim,
has_sink,
block_N=128,
block_H=64,
num_split=1,
num_stages=1,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
def flashattn(
batch, heads, k_heads, max_seqlen_kv, total_seqlen_k, dim, has_sink, block_N=128, block_H=64, num_split=1, num_stages=1, threads=128
):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, dim]
shape_k = [total_seqlen_k, k_heads, dim]
shape_v = [total_seqlen_k, k_heads, dim]
......@@ -268,13 +243,15 @@ def flashattn(batch,
# S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype)
s_aux_shared = T.alloc_shared([block_H], "float32")
T.annotate_layout({
T.annotate_layout(
{
# Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
# K_shared: tilelang.layout.make_swizzled_layout(K_shared),
# V_shared: tilelang.layout.make_swizzled_layout(V_shared),
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
# S_shared: tilelang.layout.make_swizzled_layout(S_shared),
})
}
)
bid = bx
hid = by
......@@ -284,7 +261,7 @@ def flashattn(batch,
cur_end_k = cu_seqlens_k[bid + 1]
cur_seqlen_k = cur_end_k - cur_start_k
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -292,15 +269,13 @@ def flashattn(batch,
# loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K[cur_start_k + k * block_N:cur_start_k + (k + 1) * block_N, cur_kv_head, :],
K_shared)
T.copy(K[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], K_shared)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
# acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j],
# -T.infinity(accum_dtype))
acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j],
-T.infinity(accum_dtype))
acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
......@@ -320,12 +295,11 @@ def flashattn(batch,
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.copy(V[cur_start_k + k * block_N:cur_start_k + (k + 1) * block_N, cur_kv_head, :],
V_shared)
T.copy(V[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_sink:
T.copy(s_aux[hid * valid_block_H:hid * valid_block_H + block_H], s_aux_shared)
T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared)
for i in T.Parallel(block_H):
logsum[i] += s_aux_shared[i]
for i, j in T.Parallel(block_H, dim):
......@@ -338,10 +312,9 @@ def flashattn(batch,
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])
T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
# T.copy(S_fragment, S_shared)
T.copy(S_shared[:valid_block_H, :], S[bid,
hid * valid_block_H:(hid + 1) * valid_block_H, :])
T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
@T.prim_func
def flashattn_gqa_decode_no_split(
......@@ -388,9 +361,7 @@ def flash_attn_with_attn_pool_decode_tilelang(
gqa_group_size = q_h // k_h
O_tl = torch.zeros_like(Q)
S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)),
dtype=Q.dtype,
device=Q.device)
S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device)
O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux)
if use_per_kv_head_sparse_index:
......@@ -433,9 +404,7 @@ def flash_attn_with_attn_pool_decode(
BLOCK_H = 64
O = torch.zeros_like(Q)
S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)),
dtype=Q.dtype,
device=Q.device)
S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), dtype=Q.dtype, device=Q.device)
def grid(META):
return (batch, k_h)
......@@ -483,15 +452,15 @@ def test_equal_seqlen_decode_main(args):
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
# For decode, query is just 1 token per batch
q = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype)
k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype)
v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype)
q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)
# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values
sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}")
# Convert to varlen format for K, V
......@@ -499,8 +468,7 @@ def test_equal_seqlen_decode_main(args):
v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size)
# Generate cumulative sequence lengths
cu_seqlens_k = torch.arange(
0, (batch_size + 1) * k_seqlen, k_seqlen, device='cuda', dtype=torch.int32)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32)
max_seqlen_k = k_seqlen
print(f"q shape: {q.shape}")
......@@ -510,8 +478,7 @@ def test_equal_seqlen_decode_main(args):
num_tokens, q_h, head_size = q.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size,
args.test_sink)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink)
# Test our decode kernel
O_triton, S_triton = flash_attn_with_attn_pool_decode(
......@@ -524,7 +491,8 @@ def test_equal_seqlen_decode_main(args):
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size)
block_size=block_size,
)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q,
k_varlen,
......@@ -539,9 +507,7 @@ def test_equal_seqlen_decode_main(args):
tl_kernel=tl_kernel,
)
for i in range(batch_size):
S_tilelang[i, :,
math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) /
block_size):] = 0
S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0
# Compute torch reference
q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size]
......@@ -550,14 +516,12 @@ def test_equal_seqlen_decode_main(args):
if sink is None:
# Standard scaled dot-product attention
logits = torch.matmul(q_expanded, k_repeat.transpose(
-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
attn_weights = torch.softmax(logits, dim=-1)
O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size]
else:
# s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose(
-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
......@@ -566,15 +530,15 @@ def test_equal_seqlen_decode_main(args):
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
attn_weights = unnormalized_scores / normalizer
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype),
v_repeat).squeeze(2) # [batch, q_heads, head_size]
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size]
# Compute attention score pooling
attn_score_pooled = torch.max_pool2d(
attn_weights.squeeze(2), # [b, q_heads, k_seqlen]
kernel_size=(q_heads, block_size),
stride=(q_heads, block_size),
ceil_mode=True).to(torch.float16)
ceil_mode=True,
).to(torch.float16)
print("S_tilelang", S_tilelang)
print("attn_score_pooled", attn_score_pooled)
......@@ -588,15 +552,10 @@ def test_equal_seqlen_decode_main(args):
print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}")
assert torch.allclose(
O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(
S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(
O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
assert torch.allclose(
S_tilelang, attn_score_pooled, atol=1e-2,
rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
print("✅ All tests passed!")
......@@ -616,7 +575,7 @@ def test_varlen_decode_main(args):
# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values
sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}")
# Generate variable length k sequences
......@@ -624,7 +583,7 @@ def test_varlen_decode_main(args):
print(f"k_seqlens: {k_seqlens}")
# Generate cumulative sequence lengths for k
cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32)
cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
total_k_tokens = 0
for i in range(batch_size):
cu_seqlens_k[i] = total_k_tokens
......@@ -634,9 +593,9 @@ def test_varlen_decode_main(args):
print(f"cu_seqlens_k: {cu_seqlens_k}")
# Generate tensors - Q is [batch_size, q_heads, head_size] for decode
q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype)
q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)
max_seqlen_k = int(k_seqlens.max())
......@@ -649,8 +608,7 @@ def test_varlen_decode_main(args):
num_tokens, q_h, head_size = q_decode.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size,
args.test_sink)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink)
# Test our decode kernel
O_triton, S_triton = flash_attn_with_attn_pool_decode(
......@@ -663,7 +621,8 @@ def test_varlen_decode_main(args):
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size)
block_size=block_size,
)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q_decode,
k_varlen,
......@@ -678,9 +637,7 @@ def test_varlen_decode_main(args):
tl_kernel=tl_kernel,
)
for i in range(batch_size):
S_tilelang[i, :,
math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) /
block_size):] = 0
S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0
# Create torch reference - pad tensors for comparison
k_padded_list = []
......@@ -694,8 +651,8 @@ def test_varlen_decode_main(args):
k_end = cu_seqlens_k[i + 1]
# Pad to max_seqlen_k
k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype)
v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype)
k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
k_padded[:actual_k_len] = k_varlen[k_start:k_end]
v_padded[:actual_k_len] = v_varlen[k_start:k_end]
......@@ -704,10 +661,8 @@ def test_varlen_decode_main(args):
v_padded_list.append(v_padded)
# Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
k_padded_batched = torch.stack(
k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
v_padded_batched = torch.stack(
v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
# Expand q to match kv heads: [b, q_heads, 1, head_size]
q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size]
......@@ -717,20 +672,17 @@ def test_varlen_decode_main(args):
print(f"v_padded_batched shape: {v_padded_batched.shape}")
# Compute torch reference
k_repeat = repeat_kv(k_padded_batched,
q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
v_repeat = repeat_kv(v_padded_batched,
q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
if sink is None:
# Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
attn_score = torch.matmul(q_expanded, k_repeat.transpose(
-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for i in range(batch_size):
actual_k_len = k_seqlens[i]
attn_score[i, :, :, actual_k_len:] = float('-inf')
attn_score[i, :, :, actual_k_len:] = float("-inf")
attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen]
......@@ -743,13 +695,12 @@ def test_varlen_decode_main(args):
O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size]
else:
# s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose(
-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for i in range(batch_size):
actual_k_len = k_seqlens[i]
logits[i, :, :, actual_k_len:] = float('-inf')
logits[i, :, :, actual_k_len:] = float("-inf")
sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
......@@ -765,8 +716,7 @@ def test_varlen_decode_main(args):
attn_weights[i, :, :, actual_k_len:] = 0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype),
v_repeat) # [b, q_heads, 1, head_size]
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size]
O_torch = O_torch.squeeze(2) # [b, q_heads, head_size]
......@@ -775,7 +725,8 @@ def test_varlen_decode_main(args):
attn_weights.squeeze(2), # [b, q_heads, max_seqlen]
kernel_size=(q_heads, block_size),
stride=(q_heads, block_size),
ceil_mode=True).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)]
ceil_mode=True,
).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)]
print(f"O_triton shape: {O_triton.shape}")
print(f"O_tilelang shape: {O_tilelang.shape}")
......@@ -791,22 +742,16 @@ def test_varlen_decode_main(args):
print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}")
max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
max_diff_s_tl = torch.max(
torch.abs(S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}")
assert torch.allclose(
O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(
S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(
O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}"
assert torch.allclose(
S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)],
attn_score_pooled,
atol=1e-2,
rtol=1e-2), f"Score mismatch: {max_diff_s_tl.item()}"
assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}"
assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), (
f"Score mismatch: {max_diff_s_tl.item()}"
)
print("✅ All tests passed!")
......@@ -865,7 +810,7 @@ def speed_benchmark_decode_comparison(args):
k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int)
# Generate cumulative sequence lengths for k
cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32)
cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
total_k_tokens = 0
for i in range(batch_size):
cu_seqlens_k[i] = total_k_tokens
......@@ -873,9 +818,9 @@ def speed_benchmark_decode_comparison(args):
cu_seqlens_k[batch_size] = total_k_tokens
# Generate tensors
q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype)
q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)
max_seqlen_k = int(k_seqlens.max())
......@@ -883,7 +828,7 @@ def speed_benchmark_decode_comparison(args):
# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values
sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(" Using sink attention with sink values")
print("Setup complete:")
......@@ -896,8 +841,7 @@ def speed_benchmark_decode_comparison(args):
num_tokens, q_h, head_size = q_decode.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size,
args.test_sink)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink)
# Benchmark
print("⚡ Benchmarking Tilelang kernel (100 iterations)...")
......@@ -920,36 +864,41 @@ def speed_benchmark_decode_comparison(args):
# Benchmark
print("⚡ Benchmarking Triton kernel (100 iterations)...")
triton_time = do_bench(flash_attn_with_attn_pool_decode, q_decode, k_varlen, v_varlen,
cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale, sink,
block_size)
triton_time = do_bench(
flash_attn_with_attn_pool_decode,
q_decode,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
args.k_seqlen,
1,
softmax_scale,
sink,
block_size,
)
print(f"Average decode kernel time Triton: {triton_time:.3f} ms")
print(f"Speedup: {(triton_time / tilelang_time):.3f}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Flash Attention Decode with Attention Pooling')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
parser.add_argument('--q_heads', type=int, default=32, help='Number of query heads')
parser.add_argument('--kv_heads', type=int, default=8, help='Number of key-value heads')
parser.add_argument('--k_seqlen', type=int, default=8192, help='Key sequence length')
parser.add_argument(
'--head_size', type=int, default=128, choices=[64, 128, 256], help='Head dimension')
parser.add_argument('--block_size', type=int, default=64, help='Block size for computation')
parser.add_argument(
'--dtype', type=str, default='bfloat16', choices=['float16', 'bfloat16'], help='Data type')
parser.add_argument(
'--test_varlen', action='store_true', help='Test with truly variable sequence lengths')
parser.add_argument(
'--test_sink', action='store_true', help='Test with sink attention mechanism')
parser.add_argument('--benchmark', action='store_true', help='Run speed benchmark')
parser.add_argument(
'--num_split', type=int, default=1, choices=[1, 16], help='Number of splits')
parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads")
parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads")
parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length")
parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension")
parser.add_argument("--block_size", type=int, default=64, help="Block size for computation")
parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"], help="Data type")
parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths")
parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism")
parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark")
parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits")
args = parser.parse_args()
args.test_sink = True
args.test_varlen = False
args.dtype = 'float16'
args.dtype = "float16"
args.num_split = 1
if args.benchmark:
......
......@@ -10,6 +10,7 @@ torch.manual_seed(0)
def get_configs():
import itertools
block_N = [64, 128]
block_H = [64]
num_split = [1]
......@@ -17,19 +18,14 @@ def get_configs():
threads = [128]
_configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
configs = [{
'block_N': c[0],
'block_H': c[1],
'num_split': c[2],
'num_stages': c[3],
'threads': c[4]
} for c in _configs]
configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs]
return configs
# @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding")
def flashattn(batch,
def flashattn(
batch,
heads,
k_heads,
max_seqlen_kv,
......@@ -41,8 +37,9 @@ def flashattn(batch,
block_H=64,
num_split=1,
num_stages=1,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
threads=128,
):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, dim]
shape_k = [total_seqlen_k, k_heads, dim]
shape_v = [total_seqlen_k, k_heads, dim]
......@@ -51,7 +48,9 @@ def flashattn(batch,
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // k_heads
assert page_block_size >= block_N and page_block_size % block_N == 0, "page_block_size must be larger than block_N and a multiple of block_N"
assert page_block_size >= block_N and page_block_size % block_N == 0, (
"page_block_size must be larger than block_N and a multiple of block_N"
)
valid_block_H = min(block_H, kv_group_num)
# TODO: check if max_seqlen_kv is correct for varlen case
......@@ -91,7 +90,7 @@ def flashattn(batch,
cur_end_k = cu_seqlens_k[bid + 1]
cur_seqlen_k = cur_end_k - cur_start_k
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -99,15 +98,12 @@ def flashattn(batch,
# loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages):
k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (
k * block_N) % page_block_size
T.copy(K[cur_start_k + k_start:cur_start_k + k_start + block_N, cur_kv_head, :],
K_shared)
k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size
T.copy(K[cur_start_k + k_start : cur_start_k + k_start + block_N, cur_kv_head, :], K_shared)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j],
-T.infinity(accum_dtype))
acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
......@@ -127,14 +123,12 @@ def flashattn(batch,
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (
k * block_N) % page_block_size
T.copy(V[cur_start_k + v_start:cur_start_k + v_start + block_N, cur_kv_head, :],
V_shared)
v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size
T.copy(V[cur_start_k + v_start : cur_start_k + v_start + block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_sink:
T.copy(s_aux[hid * valid_block_H:hid * valid_block_H + block_H], s_aux_shared)
T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared)
for i in T.Parallel(block_H):
logsum[i] += s_aux_shared[i]
for i, j in T.Parallel(block_H, dim):
......@@ -144,9 +138,8 @@ def flashattn(batch,
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])
T.copy(S_shared[:valid_block_H, :], S[bid,
hid * valid_block_H:(hid + 1) * valid_block_H, :])
T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
@T.prim_func
def flashattn_gqa_decode_no_split(
......@@ -195,9 +188,7 @@ def flash_attn_with_attn_pool_decode_tilelang(
gqa_group_size = q_h // k_h
O_tl = torch.zeros_like(Q)
S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)),
dtype=Q.dtype,
device=Q.device)
S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device)
O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux, block_table)
if use_per_kv_head_sparse_index:
......@@ -223,15 +214,15 @@ def test_equal_seqlen_decode_main(args):
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
# For decode, query is just 1 token per batch
q = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype)
k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype)
v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype)
q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)
# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values
sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}")
# Convert to varlen format for K, V
......@@ -239,8 +230,7 @@ def test_equal_seqlen_decode_main(args):
v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous()
# Generate cumulative sequence lengths
cu_seqlens_k = torch.arange(
0, (batch_size + 1) * k_seqlen, k_seqlen, device='cuda', dtype=torch.int32)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32)
max_seqlen_k = k_seqlen
print(f"q shape: {q.shape}")
......@@ -250,11 +240,9 @@ def test_equal_seqlen_decode_main(args):
num_tokens, q_h, head_size = q.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size,
args.test_sink, page_block_size)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
block_table = torch.zeros(
batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32)
block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
block_cnt = 0
for i in range(batch):
cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
......@@ -274,7 +262,8 @@ def test_equal_seqlen_decode_main(args):
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size)
block_size=block_size,
)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q,
k_varlen,
......@@ -290,9 +279,7 @@ def test_equal_seqlen_decode_main(args):
block_table=block_table,
)
for i in range(batch_size):
S_tilelang[i, :,
math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) /
block_size):] = 0
S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0
# Compute torch reference
q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size]
......@@ -301,14 +288,12 @@ def test_equal_seqlen_decode_main(args):
if sink is None:
# Standard scaled dot-product attention
logits = torch.matmul(q_expanded, k_repeat.transpose(
-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
attn_weights = torch.softmax(logits, dim=-1)
O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size]
else:
# s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose(
-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
......@@ -317,15 +302,15 @@ def test_equal_seqlen_decode_main(args):
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
attn_weights = unnormalized_scores / normalizer
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype),
v_repeat).squeeze(2) # [batch, q_heads, head_size]
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size]
# Compute attention score pooling
attn_score_pooled = torch.max_pool2d(
attn_weights.squeeze(2), # [b, q_heads, k_seqlen]
kernel_size=(q_heads, block_size),
stride=(q_heads, block_size),
ceil_mode=True).to(torch.float16)
ceil_mode=True,
).to(torch.float16)
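# Explanatory note (not part of the original change): the s_aux branch above implements
# "sink" softmax, where a learned per-head logit competes for probability mass:
#   attn_i = exp(l_i - m) / (sum_j exp(l_j - m) + exp(sink - m)),  m = max(max_j l_j, sink).
# A minimal standalone sketch with assumed toy shapes (kept in comments so the diffed test body is unchanged):
#   l = torch.randn(1, 4, 1, 16)                         # [batch, q_heads, 1, seqlen_k]
#   s = torch.randn(4).view(1, 4, 1, 1)                  # per-head sink logit
#   m = torch.maximum(l.max(-1, keepdim=True).values, s)
#   p = torch.exp(l - m)
#   attn = p / (p.sum(-1, keepdim=True) + torch.exp(s - m))
#   assert (attn.sum(-1) < 1).all()                      # the sink absorbs the remaining mass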
print("S_tilelang", S_tilelang)
print("attn_score_pooled", attn_score_pooled)
......@@ -339,15 +324,10 @@ def test_equal_seqlen_decode_main(args):
print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}")
assert torch.allclose(
O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(
S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(
O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
assert torch.allclose(
S_tilelang, attn_score_pooled, atol=1e-2,
rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
print("✅ All tests passed!")
......@@ -368,7 +348,7 @@ def test_varlen_decode_main(args):
# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values
sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}")
# Generate variable length k sequences
......@@ -376,7 +356,7 @@ def test_varlen_decode_main(args):
print(f"k_seqlens: {k_seqlens}")
# Generate cumulative sequence lengths for k
cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32)
cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
total_k_tokens = 0
for i in range(batch_size):
cu_seqlens_k[i] = total_k_tokens
......@@ -386,9 +366,9 @@ def test_varlen_decode_main(args):
print(f"cu_seqlens_k: {cu_seqlens_k}")
# Generate tensors - Q is [batch_size, q_heads, head_size] for decode
q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype)
q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)
max_seqlen_k = int(k_seqlens.max())
......@@ -401,11 +381,9 @@ def test_varlen_decode_main(args):
num_tokens, q_h, head_size = q_decode.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size,
args.test_sink, page_block_size)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
block_table = torch.zeros(
batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32)
block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
block_cnt = 0
for i in range(batch):
cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
......@@ -425,7 +403,8 @@ def test_varlen_decode_main(args):
args.num_split,
softmax_scale,
s_aux=sink,
block_size=block_size)
block_size=block_size,
)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q_decode,
k_varlen,
......@@ -441,9 +420,7 @@ def test_varlen_decode_main(args):
block_table=block_table,
)
for i in range(batch_size):
S_tilelang[i, :,
math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) /
block_size):] = 0
S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0
# Create torch reference - pad tensors for comparison
k_padded_list = []
......@@ -457,8 +434,8 @@ def test_varlen_decode_main(args):
k_end = cu_seqlens_k[i + 1]
# Pad to max_seqlen_k
k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype)
v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype)
k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
k_padded[:actual_k_len] = k_varlen[k_start:k_end]
v_padded[:actual_k_len] = v_varlen[k_start:k_end]
......@@ -467,10 +444,8 @@ def test_varlen_decode_main(args):
v_padded_list.append(v_padded)
# Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
k_padded_batched = torch.stack(
k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
v_padded_batched = torch.stack(
v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
# Expand q to match kv heads: [b, q_heads, 1, head_size]
q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size]
......@@ -480,20 +455,17 @@ def test_varlen_decode_main(args):
print(f"v_padded_batched shape: {v_padded_batched.shape}")
# Compute torch reference
k_repeat = repeat_kv(k_padded_batched,
q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
v_repeat = repeat_kv(v_padded_batched,
q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
if sink is None:
# Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
attn_score = torch.matmul(q_expanded, k_repeat.transpose(
-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for i in range(batch_size):
actual_k_len = k_seqlens[i]
attn_score[i, :, :, actual_k_len:] = float('-inf')
attn_score[i, :, :, actual_k_len:] = float("-inf")
attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen]
......@@ -506,13 +478,12 @@ def test_varlen_decode_main(args):
O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size]
else:
# s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose(
-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
# Apply sequence length masking
for i in range(batch_size):
actual_k_len = k_seqlens[i]
logits[i, :, :, actual_k_len:] = float('-inf')
logits[i, :, :, actual_k_len:] = float("-inf")
sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
......@@ -528,8 +499,7 @@ def test_varlen_decode_main(args):
attn_weights[i, :, :, actual_k_len:] = 0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype),
v_repeat) # [b, q_heads, 1, head_size]
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size]
O_torch = O_torch.squeeze(2) # [b, q_heads, head_size]
......@@ -538,7 +508,8 @@ def test_varlen_decode_main(args):
attn_weights.squeeze(2), # [b, q_heads, max_seqlen]
kernel_size=(q_heads, block_size),
stride=(q_heads, block_size),
ceil_mode=True).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)]
ceil_mode=True,
).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)]
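# Explanatory note (not part of the original change): max_pool2d with kernel and stride
# (q_heads, block_size) and ceil_mode=True collapses the head axis and groups keys into blocks,
# so the pooled score has shape [b, 1, ceil(max_seqlen / block_size)]. Toy check with assumed
# sizes q_heads=2, block_size=4, max_seqlen=6 (comments only):
#   w = torch.rand(1, 2, 6)                              # [b, q_heads, max_seqlen]
#   torch.max_pool2d(w, (2, 4), (2, 4), ceil_mode=True).shape
#   # -> torch.Size([1, 1, 2])                           # ceil(6 / 4) == 2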
print(f"O_triton shape: {O_triton.shape}")
print(f"O_tilelang shape: {O_tilelang.shape}")
......@@ -554,22 +525,16 @@ def test_varlen_decode_main(args):
print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}")
max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
max_diff_s_tl = torch.max(
torch.abs(S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}")
assert torch.allclose(
O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(
S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(
O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}"
assert torch.allclose(
S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)],
attn_score_pooled,
atol=1e-2,
rtol=1e-2), f"Score mismatch: {max_diff_s_tl.item()}"
assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}"
assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), (
f"Score mismatch: {max_diff_s_tl.item()}"
)
print("✅ All tests passed!")
......@@ -605,7 +570,7 @@ def speed_benchmark_decode_comparison(args):
k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int)
# Generate cumulative sequence lengths for k
cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32)
cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
total_k_tokens = 0
for i in range(batch_size):
cu_seqlens_k[i] = total_k_tokens
......@@ -613,9 +578,9 @@ def speed_benchmark_decode_comparison(args):
cu_seqlens_k[batch_size] = total_k_tokens
# Generate tensors
q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype)
q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size)
max_seqlen_k = int(k_seqlens.max())
......@@ -623,7 +588,7 @@ def speed_benchmark_decode_comparison(args):
# Generate sink values if needed
sink = None
if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values
sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(" Using sink attention with sink values")
print("Setup complete:")
......@@ -636,11 +601,9 @@ def speed_benchmark_decode_comparison(args):
num_tokens, q_h, head_size = q_decode.shape
batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size,
args.test_sink, page_block_size)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
block_table = torch.zeros(
batch, math.ceil(real_max_k_seqlen / page_block_size), device='cuda', dtype=torch.int32)
block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
block_cnt = 0
for i in range(batch):
cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
......@@ -671,36 +634,41 @@ def speed_benchmark_decode_comparison(args):
# Benchmark
print("⚡ Benchmarking Triton kernel (100 iterations)...")
triton_time = do_bench(flash_attn_with_attn_pool_decode, q_decode, k_varlen, v_varlen,
cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale, sink,
block_size)
triton_time = do_bench(
flash_attn_with_attn_pool_decode,
q_decode,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
args.k_seqlen,
1,
softmax_scale,
sink,
block_size,
)
print(f"Average decode kernel time Triton: {triton_time:.3f} ms")
print(f"Speedup: {(triton_time / tilelang_time):.3f}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Flash Attention Decode with Attention Pooling')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
parser.add_argument('--q_heads', type=int, default=32, help='Number of query heads')
parser.add_argument('--kv_heads', type=int, default=8, help='Number of key-value heads')
parser.add_argument('--k_seqlen', type=int, default=8192, help='Key sequence length')
parser.add_argument(
'--head_size', type=int, default=128, choices=[64, 128, 256], help='Head dimension')
parser.add_argument('--block_size', type=int, default=128, help='Block size for computation')
parser.add_argument(
'--dtype', type=str, default='bfloat16', choices=['float16', 'bfloat16'], help='Data type')
parser.add_argument(
'--test_varlen', action='store_true', help='Test with truly variable sequence lengths')
parser.add_argument(
'--test_sink', action='store_true', help='Test with sink attention mechanism')
parser.add_argument('--benchmark', action='store_true', help='Run speed benchmark')
parser.add_argument(
'--num_split', type=int, default=1, choices=[1, 16], help='Number of splits')
parser.add_argument('--page_block_size', type=int, default=128, help='Page block size')
parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads")
parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads")
parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length")
parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension")
parser.add_argument("--block_size", type=int, default=128, help="Block size for computation")
parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"], help="Data type")
parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths")
parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism")
parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark")
parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits")
parser.add_argument("--page_block_size", type=int, default=128, help="Page block size")
args = parser.parse_args()
args.test_sink = True
args.test_varlen = True
args.dtype = 'float16'
args.dtype = "float16"
args.num_split = 1
if args.benchmark:
......
......@@ -10,7 +10,7 @@ num_split = 4
@tilelang.jit(out_idx=[5])
def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape_q = [batch, seqlen_q, heads, dim]
shape_kv = [batch, seqlen_kv, heads, dim]
part_shape = [batch, seqlen_q, heads, num_split, dim]
......@@ -29,14 +29,11 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
bid: T.int32,
sid: T.int32,
):
T.copy(
K[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid +
(k + 1) * block_N, hid, :], K_shared)
T.copy(K[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], K_shared)
# TODO: Handle causal split case
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......@@ -52,9 +49,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
bid: T.int32,
sid: T.int32,
):
T.copy(
V[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid +
(k + 1) * block_N, hid, :], V_shared)
T.copy(V[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
......@@ -105,9 +100,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
Output_partial: T.Tensor(part_shape, dtype),
):
with T.Kernel(
T.ceildiv(seqlen_q, block_M), heads * batch, num_split,
threads=128) as (bx, by, bz):
with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
......@@ -128,33 +121,30 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
# NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently
# disable relevant tma copy and use SIMT as fallback for now
T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared, disable_tma=True)
T.copy(Q[bid, mid * block_M : (mid + 1) * block_M, hid, :], Q_shared, disable_tma=True)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
# TODO: Handle causal split case
loop_range = (
T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv(
(mid + 1) * block_M, block_N)) if is_causal else T.ceildiv(
(seqlen_kv // num_split), block_N))
T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv((mid + 1) * block_M, block_N))
if is_causal
else T.ceildiv((seqlen_kv // num_split), block_N)
)
for k in T.Pipelined(loop_range, num_stages=2):
MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bid, hid, sid, mid * block_M:(mid + 1) * block_M])
T.copy(logsum, glse[bid, hid, sid, mid * block_M : (mid + 1) * block_M])
T.copy(acc_o, O_shared)
T.copy(
O_shared,
Output_partial[bid, mid * block_M:(mid + 1) * block_M, hid, sid, :],
disable_tma=True)
T.copy(O_shared, Output_partial[bid, mid * block_M : (mid + 1) * block_M, hid, sid, :], disable_tma=True)
@T.macro
def combine(
......@@ -173,20 +163,25 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
lse_max_local = T.alloc_fragment([block_M], accum_dtype)
scale_local = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({
T.annotate_layout(
{
o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i),
o_shared: tilelang.layout.make_swizzled_layout(o_shared),
po_shared: tilelang.layout.make_swizzled_layout(po_shared),
})
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
T.copy(glse[
T.copy(
glse[
bz,
by,
:,
bx * block_M:(bx + 1) * block_M,
], lse_local)
bx * block_M : (bx + 1) * block_M,
],
lse_local,
)
T.reduce_max(lse_local, lse_max_local, dim=0, clear=False)
for k in T.Pipelined(num_split):
T.copy(lse_local[k, :], lse_local_split)
......@@ -195,10 +190,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
for i in T.Parallel(block_M):
lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i]
for k in T.Pipelined(num_split, num_stages=2):
T.copy(
Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :],
po_shared,
disable_tma=True)
T.copy(Output_partial[bz, bx * block_M : (bx + 1) * block_M, by, k, :], po_shared, disable_tma=True)
T.copy(po_shared, po_local)
for i in T.Parallel(block_M):
lse_local_split[i] = lse_local[k, i]
......@@ -207,7 +199,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
for i, j in T.Parallel(block_M, dim):
o_accum_local[i, j] += po_local[i, j] * scale_local[i]
T.copy(o_accum_local, o_shared)
T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :], disable_tma=True)
T.copy(o_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :], disable_tma=True)
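# Explanatory note (not part of the original change): combine() merges the per-split partial
# outputs with a base-2 log-sum-exp reduction. With lse_s the per-split log2-sum-exp and
# m = max_s lse_s, the merged value is lse = log2(sum_s 2**(lse_s - m)) + m, and each partial
# output is weighted by 2**(lse_s - lse); the exact weighting sits in the elided lines above,
# so that expression is an assumption. Toy check (comments only):
#   import math
#   lse = [2.0, 3.0]                                     # per-split log2-sum-exp
#   m = max(lse)
#   total = math.log2(sum(2 ** (x - m) for x in lse)) + m
#   assert abs(sum(2 ** (x - total) for x in lse) - 1.0) < 1e-12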
@T.prim_func
def flashattn_mha_inference(
......@@ -227,10 +219,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
def ref_program(Q, K, V, glse, Output_partial, causal):
assert causal is False
dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output
......@@ -258,7 +250,7 @@ def flash_split_ref(Q, K, V, causal):
block_N = 128
seqlen_kv = K.size(1)
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
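# Explanatory note (not part of the original change): folding log2(e) ≈ 1.44269504 into the
# softmax scale lets both the reference and the kernel use exp2/log2 instead of exp/log,
# since exp(x / sqrt(dim)) == 2 ** (x / sqrt(dim) * log2(e)). Quick check (comments only):
#   import math
#   x, dim = 0.7, 128
#   assert abs(math.exp(x / math.sqrt(dim)) - 2 ** (x / math.sqrt(dim) * 1.44269504)) < 1e-9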
acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float)
acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16)
acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float)
......@@ -275,14 +267,15 @@ def flash_split_ref(Q, K, V, causal):
for ks in range(num_split):
acc_o.fill_(0)
logsum.fill_(0)
scores_max.fill_(float('-inf'))
scores_max_prev.fill_(float('-inf'))
scores_max.fill_(float("-inf"))
scores_max_prev.fill_(float("-inf"))
for i in range(int((seqlen_kv // num_split) / block_N)):
acc_s.fill_(0)
acc_s = torch.einsum('bqhd,bkhd->bhqk', Q_,
K[:, (seqlen_kv // num_split) * ks +
i * block_N:(seqlen_kv // num_split) * ks +
(i + 1) * block_N, :, :]) # [batch, seqlen, nheads, block_N]
acc_s = torch.einsum(
"bqhd,bkhd->bhqk",
Q_,
K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
) # [batch, seqlen, nheads, block_N]
scores_max_prev = scores_max
scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM]
scores_scale = torch.exp2(scores_max_prev - scores_max)
......@@ -290,9 +283,10 @@ def flash_split_ref(Q, K, V, causal):
acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
acc_s_cast = acc_s.to(torch.float16)
acc_o += torch.einsum(
'bhqk,bkhd->bqhd', acc_s_cast,
V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
(i + 1) * block_N, :, :])
"bhqk,bkhd->bqhd",
acc_s_cast,
V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
)
scores_sum = acc_s.sum(dim=-1, keepdim=False)
logsum = logsum * scores_scale + scores_sum
acc_o /= logsum[:, :, :, None].transpose(1, 2)
......@@ -300,8 +294,7 @@ def flash_split_ref(Q, K, V, causal):
gacc_o[ks, :, :, :, :] = acc_o
glogsum[ks, :, :, :] = logsum
return glogsum.to(torch.float16).permute(1, 2, 0,
3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)
return glogsum.to(torch.float16).permute(1, 2, 0, 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)
def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False):
......
......@@ -9,7 +9,8 @@ from example_fusedmoe_torch import *
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def moe_forward_tilelang_shared(d_hidden,
def moe_forward_tilelang_shared(
d_hidden,
d_expert,
n_shared_experts,
dtype,
......@@ -18,8 +19,8 @@ def moe_forward_tilelang_shared(d_hidden,
block_dhidden=128,
block_dexpert=128,
threads=256,
num_stages=1):
num_stages=1,
):
scale = 1.44269504 # log2(e)
# Parameters
......@@ -44,9 +45,7 @@ def moe_forward_tilelang_shared(d_hidden,
output: T.Tensor(input_shape, dtype), # type: ignore
):
# Step 1: Compute gate and up logits
with T.Kernel(
T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert),
threads=threads) as (bx, by):
with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by):
# Split the block to shared experts and routed experts
input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype)
W_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
......@@ -70,16 +69,13 @@ def moe_forward_tilelang_shared(d_hidden,
# Fuse with SiLU and element-wise product
for i, j in T.Parallel(block_token, block_dexpert):
gate_logits_local[i, j] = gate_logits_local[i, j] * (
1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]
T.copy(up_logits_local, up_logits[bx * block_token, by * block_dexpert])
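# Explanatory note (not part of the original change): the fused activation above is SiLU
# expressed with exp2, i.e. silu(x) = x * sigmoid(x) = x / (1 + exp(-x)) = x / (1 + 2**(-x * log2(e)))
# with log2(e) ≈ 1.44269504. Minimal check against torch on an assumed toy input (comments only):
#   x = torch.randn(8)
#   assert torch.allclose(torch.nn.functional.silu(x), x * (1.0 / (1.0 + torch.exp2(-x * 1.44269504))), atol=1e-6)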
# Step 2: Compute down logits
with T.Kernel(
T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden),
threads=threads) as (bx, by):
with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by):
up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype)
W_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype)
output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_type)
......@@ -98,7 +94,8 @@ def moe_forward_tilelang_shared(d_hidden,
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def moe_forward_tilelang_routed(d_hidden,
def moe_forward_tilelang_routed(
d_hidden,
d_expert,
n_routed_experts,
dtype,
......@@ -110,8 +107,8 @@ def moe_forward_tilelang_routed(d_hidden,
threads=256,
num_stages=1,
k_pack=1,
coalesced_width=None):
coalesced_width=None,
):
scale = 1.44269504 # log2(e)
# Parameters
......@@ -132,8 +129,8 @@ def moe_forward_tilelang_routed(d_hidden,
routed_expert_gate_shape = (n_routed_experts, dexpert, dhidden)
routed_expert_up_shape = (n_routed_experts, dexpert, dhidden)
routed_expert_down_shape = (n_routed_experts, dhidden, dexpert)
routed_expert_weights_shape = (group_sum)
group_sizes_shape = (n_routed_experts)
routed_expert_weights_shape = group_sum
group_sizes_shape = n_routed_experts
@T.prim_func
def kernel(
......@@ -168,48 +165,37 @@ def moe_forward_tilelang_routed(d_hidden,
cur_group_idx[0] = group_idx_for_bx[bx]
cur_group_size[0] = group_sizes[cur_group_idx[0]]
m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[
cur_group_idx[0]]
actual_rows = T.max(
0,
T.min(block_token, cur_group_size[0] -
(m_start_padded - group_padded_offsets[cur_group_idx[0]])))
m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]]
actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]])))
T.clear(gate_logits_local)
T.clear(up_logits_local)
for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages):
T.copy(
input[m_start:m_start + block_token, k * block_dhidden:(k + 1) * block_dhidden],
input[m_start : m_start + block_token, k * block_dhidden : (k + 1) * block_dhidden],
input_shared,
coalesced_width=coalesced_width)
coalesced_width=coalesced_width,
)
T.copy(
routed_expert_gate[cur_group_idx[0],
by * block_dexpert:(by + 1) * block_dexpert,
k * block_dhidden:(k + 1) * block_dhidden],
routed_expert_gate_shared,
coalesced_width=coalesced_width)
T.gemm(
input_shared,
routed_expert_gate[
cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden
],
routed_expert_gate_shared,
gate_logits_local,
k_pack=k_pack,
transpose_B=True)
coalesced_width=coalesced_width,
)
T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, k_pack=k_pack, transpose_B=True)
T.copy(
routed_expert_up[cur_group_idx[0], by * block_dexpert:(by + 1) * block_dexpert,
k * block_dhidden:(k + 1) * block_dhidden],
routed_expert_up[
cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden
],
routed_expert_up_shared,
coalesced_width=coalesced_width)
T.gemm(
input_shared,
routed_expert_up_shared,
up_logits_local,
k_pack=k_pack,
transpose_B=True)
coalesced_width=coalesced_width,
)
T.gemm(input_shared, routed_expert_up_shared, up_logits_local, k_pack=k_pack, transpose_B=True)
for i, j in T.Parallel(block_token, block_dexpert):
gate_logits_local[i, j] = gate_logits_local[i, j] * (
1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]
for i, j in T.Parallel(block_token, block_dexpert):
......@@ -232,50 +218,35 @@ def moe_forward_tilelang_routed(d_hidden,
cur_group_idx[0] = group_idx_for_bx[bx]
cur_group_size[0] = group_sizes[cur_group_idx[0]]
m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[
cur_group_idx[0]]
actual_rows = T.max(
0,
T.min(block_token, cur_group_size[0] -
(m_start_padded - group_padded_offsets[cur_group_idx[0]])))
m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]]
actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]])))
T.clear(output_local)
for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages):
T.copy(
up_logits[m_start:m_start + block_token,
k * block_dexpert:(k + 1) * block_dexpert],
up_logits[m_start : m_start + block_token, k * block_dexpert : (k + 1) * block_dexpert],
up_logits_shared,
coalesced_width=coalesced_width)
coalesced_width=coalesced_width,
)
T.copy(
routed_expert_down[cur_group_idx[0],
by * block_dhidden:(by + 1) * block_dhidden,
k * block_dexpert:(k + 1) * block_dexpert],
routed_expert_down_shared,
coalesced_width=coalesced_width)
T.gemm(
up_logits_shared,
routed_expert_down[
cur_group_idx[0], by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert
],
routed_expert_down_shared,
output_local,
k_pack=k_pack,
transpose_B=True)
coalesced_width=coalesced_width,
)
T.gemm(up_logits_shared, routed_expert_down_shared, output_local, k_pack=k_pack, transpose_B=True)
for i, j in T.Parallel(block_token, block_dhidden):
if i < actual_rows:
output[m_start + i, by * block_dhidden +
j] = output_local[i, j] * routed_expert_weights[m_start + i]
output[m_start + i, by * block_dhidden + j] = output_local[i, j] * routed_expert_weights[m_start + i]
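# Explanatory note (not part of the original change): actual_rows clips each grid block to the
# real token count of its expert group, so rows introduced only by the padding-to-padding_M
# scheme are never written back. Toy example with assumed values block_token=128,
# cur_group_size=100, and the block starting at the group's padded offset:
#   actual_rows = max(0, min(128, 100 - 0)) == 100       # rows 100..127 are padding and skipped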
return kernel
class Expert(nn.Module):
def __init__(self,
config: Dict,
gate: torch.Tensor,
up: torch.Tensor,
down: torch.Tensor,
d_expert: Optional[int] = None):
def __init__(self, config: Dict, gate: torch.Tensor, up: torch.Tensor, down: torch.Tensor, d_expert: Optional[int] = None):
super().__init__()
self.config = config
self.act_fn = nn.SiLU()
......@@ -294,14 +265,13 @@ class Expert(nn.Module):
class MoEGate(nn.Module):
def __init__(self, config: Dict, weights: Dict):
super().__init__()
self.top_k: int = config["n_experts_per_token"]
self.num_experts: int = config["n_routed_experts"]
self.d_hidden: int = config["d_hidden"]
self.W_g_weight = weights['router.weight'].t()
self.W_g_weight = weights["router.weight"].t()
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
logits = x @ self.W_g_weight
......@@ -312,76 +282,69 @@ class MoEGate(nn.Module):
class MoE(nn.Module):
def __init__(self,
config: Dict,
shared_kernel: tilelang.JITKernel,
routed_kernel: tilelang.JITKernel,
weights: Dict,
padding_M: int = 128):
def __init__(
self, config: Dict, shared_kernel: tilelang.JITKernel, routed_kernel: tilelang.JITKernel, weights: Dict, padding_M: int = 128
):
super().__init__()
self.config = config
self.shared_kernel = shared_kernel
self.routed_kernel = routed_kernel
self.padding_M = padding_M
self.experts = nn.ModuleList([
self.experts = nn.ModuleList(
[
Expert(
config,
gate=weights[f'experts.{i}.0.weight'],
up=weights[f'experts.{i}.1.weight'],
down=weights[f'experts.{i}.2.weight']) for i in range(config["n_routed_experts"])
])
gate=weights[f"experts.{i}.0.weight"],
up=weights[f"experts.{i}.1.weight"],
down=weights[f"experts.{i}.2.weight"],
)
for i in range(config["n_routed_experts"])
]
)
self.device = torch.device("cuda")
self.gating_network = MoEGate(config, weights).to(self.device)
shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
self.shared_expert = Expert(
config=config,
gate=weights['shared_experts.0.weight'],
up=weights['shared_experts.1.weight'],
down=weights['shared_experts.2.weight'],
d_expert=shared_expert_dim).to(self.device)
gate=weights["shared_experts.0.weight"],
up=weights["shared_experts.1.weight"],
down=weights["shared_experts.2.weight"],
d_expert=shared_expert_dim,
).to(self.device)
self.expert_cache = torch.zeros(
(config["batch_size"] * config["seq_len"], config["d_hidden"]),
dtype=torch.float16,
device=self.device)
self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts],
dim=0)
self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts],
dim=0)
self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts],
dim=0)
(config["batch_size"] * config["seq_len"], config["d_hidden"]), dtype=torch.float16, device=self.device
)
self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], dim=0)
self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], dim=0)
self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts], dim=0)
self.stacked_expert_tokens = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"],
self.config["d_hidden"]),
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]),
dtype=torch.float16,
device=self.device)
device=self.device,
)
self.stacked_expert_weights = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]),
dtype=torch.float16,
device=self.device)
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.float16, device=self.device
)
self.stacked_expert_tokens_idxs = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]),
dtype=torch.int64,
device=self.device)
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.int64, device=self.device
)
self.up_logits_shared = torch.empty(
(config["batch_size"] * config["seq_len"], self.config["d_expert"]),
dtype=torch.float16,
device=self.device)
(config["batch_size"] * config["seq_len"], self.config["d_expert"]), dtype=torch.float16, device=self.device
)
self.expert_output_shared = torch.empty(
(config["batch_size"] * config["seq_len"], self.config["d_hidden"]),
dtype=torch.float16,
device=self.device)
(config["batch_size"] * config["seq_len"], self.config["d_hidden"]), dtype=torch.float16, device=self.device
)
self.up_logits_routed = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"],
self.config["d_expert"]),
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_expert"]),
dtype=torch.float16,
device=self.device)
device=self.device,
)
self.expert_output_routed = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"],
self.config["d_hidden"]),
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]),
dtype=torch.float16,
device=self.device)
device=self.device,
)
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -413,22 +376,20 @@ class MoE(nn.Module):
self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens
self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs
self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[
idxs[start_idx:end_idx]]
self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[idxs[start_idx:end_idx]]
group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device)
group_offset = torch.tensor(
tokens_per_expert - counts, dtype=torch.int32, device=self.device)
group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=self.device)
group_padded_offsets = [0 for _ in range(len(group_sizes))]
for i in range(1, len(group_sizes)):
group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil(
(counts[i - 1] + 1) / self.padding_M) * self.padding_M
group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / self.padding_M) * self.padding_M
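# Explanatory note (not part of the original change): each expert's token group is padded up to
# a multiple of padding_M so every kernel block maps to exactly one expert; the "+ 1" guarantees
# at least one padded slot per group (a whole extra block when a count is already a multiple of
# padding_M). Toy example with assumed counts=[100, 300] and padding_M=128 (comments only):
#   group_padded_offsets == [0, math.ceil((100 + 1) / 128) * 128] == [0, 128]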
block_token = 128
M = math.ceil(
self.config["batch_size"] * self.config["seq_len"] *
self.config["n_experts_per_token"] / block_token) + self.config["n_routed_experts"]
M = (
math.ceil(self.config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token)
+ self.config["n_routed_experts"]
)
group_idx_for_bx = [0 for _ in range(M)]
for bx in range(M):
......@@ -437,8 +398,7 @@ class MoE(nn.Module):
if m_start_padded >= group_padded_offsets[i]:
group_idx_for_bx[bx] = i
group_padded_offsets = torch.tensor(
group_padded_offsets, dtype=torch.int32, device=self.device)
group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=self.device)
group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device)
# Multi-stream execution
......@@ -448,11 +408,19 @@ class MoE(nn.Module):
with torch.cuda.stream(routed_stream):
# Tilelang version: Grouped GEMM
self.routed_kernel(self.stacked_expert_tokens, self.stacked_expert_w_gate,
self.stacked_expert_w_up, self.stacked_expert_w_down,
self.stacked_expert_weights, group_sizes, group_offset,
group_padded_offsets, group_idx_for_bx, self.up_logits_routed,
self.expert_output_routed)
self.routed_kernel(
self.stacked_expert_tokens,
self.stacked_expert_w_gate,
self.stacked_expert_w_up,
self.stacked_expert_w_down,
self.stacked_expert_weights,
group_sizes,
group_offset,
group_padded_offsets,
group_idx_for_bx,
self.up_logits_routed,
self.expert_output_routed,
)
# Scatter reduce
self.expert_cache = torch.scatter_reduce(
......@@ -460,14 +428,19 @@ class MoE(nn.Module):
0,
self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]),
self.expert_output_routed,
reduce='sum')
reduce="sum",
)
routed_output = self.expert_cache.view(*orig_shape)
with torch.cuda.stream(shared_stream):
self.shared_kernel(x_flat, self.shared_expert.W_gate_weight,
self.shared_expert.W_up_weight, self.shared_expert.W_down_weight,
self.up_logits_shared, self.expert_output_shared)
self.shared_kernel(
x_flat,
self.shared_expert.W_gate_weight,
self.shared_expert.W_up_weight,
self.shared_expert.W_down_weight,
self.up_logits_shared,
self.expert_output_shared,
)
shared_output = self.expert_output_shared.view(*orig_shape)
torch.cuda.synchronize()
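# Explanatory note (not part of the original change): the routed grouped-GEMM kernel and the
# shared-expert kernel are launched on separate CUDA streams so they can overlap on the GPU;
# torch.cuda.synchronize() joins both streams before the two outputs are read and combined.
# Minimal sketch of the same pattern (comments only):
#   s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()
#   with torch.cuda.stream(s1):
#       ...                                              # routed experts
#   with torch.cuda.stream(s2):
#       ...                                              # shared expert
#   torch.cuda.synchronize()                             # join before summing the outputs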
......@@ -498,7 +471,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
config["d_expert"],
config["n_shared_experts"],
dtype=dtype_str,
num_tokens=config["batch_size"] * config["seq_len"])
num_tokens=config["batch_size"] * config["seq_len"],
)
routed_kernel = moe_forward_tilelang_routed(
config["d_hidden"],
config["d_expert"],
......@@ -512,7 +486,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
threads=256,
num_stages=1,
k_pack=1,
coalesced_width=2)
coalesced_width=2,
)
moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128)
......@@ -521,13 +496,7 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
return output
def main(d_hidden=7168,
d_expert=2048,
n_routed_experts=8,
n_shared_experts=1,
n_experts_per_token=4,
batch_size=1,
seq_len=8192):
def main(d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=8192):
config = {
"dhidden": d_hidden,
"dexpert": d_expert,
......@@ -536,7 +505,7 @@ def main(d_hidden=7168,
"nexpertspertoken": n_experts_per_token,
"bs": batch_size,
"seqlen": seq_len,
"seed": 81394
"seed": 81394,
}
data = generate_input(**config)
......
......@@ -6,7 +6,6 @@ from typing import Dict, Tuple, Optional
# Reference code in PyTorch
class ExpertTorch(nn.Module):
def __init__(self, config: Dict, d_expert: Optional[int] = None):
super().__init__()
self.config = config
......@@ -25,7 +24,6 @@ class ExpertTorch(nn.Module):
class MoEGateTorch(nn.Module):
def __init__(self, config: Dict):
super().__init__()
self.top_k: int = config["n_experts_per_token"]
......@@ -43,12 +41,10 @@ class MoEGateTorch(nn.Module):
class MoETorch(nn.Module):
def __init__(self, config: Dict):
super().__init__()
self.config = config
self.experts = nn.ModuleList(
[ExpertTorch(config) for _ in range(config["n_routed_experts"])])
self.experts = nn.ModuleList([ExpertTorch(config) for _ in range(config["n_routed_experts"])])
self.gating_network = MoEGateTorch(config)
shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
self.shared_expert = ExpertTorch(config=config, d_expert=shared_expert_dim)
......@@ -67,8 +63,7 @@ class MoETorch(nn.Module):
return routed_output + shared_output
@torch.no_grad()
def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor,
flat_expert_weights: torch.Tensor) -> torch.Tensor:
def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, flat_expert_weights: torch.Tensor) -> torch.Tensor:
expert_cache = torch.zeros_like(x)
# test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
# test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
......@@ -91,8 +86,7 @@ class MoETorch(nn.Module):
expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_reduce_(
0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
expert_cache.scatter_reduce_(0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum")
return expert_cache
......@@ -116,21 +110,21 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
moe = MoETorch(config)
# Fill in the given weights of the model
moe.gating_network.W_g.weight = nn.Parameter(weights['router.weight'])
moe.gating_network.W_g.weight = nn.Parameter(weights["router.weight"])
for i in range(num_experts):
gate_proj_weight = weights[f'experts.{i}.0.weight']
up_proj_weight = weights[f'experts.{i}.1.weight']
down_proj_weight = weights[f'experts.{i}.2.weight']
gate_proj_weight = weights[f"experts.{i}.0.weight"]
up_proj_weight = weights[f"experts.{i}.1.weight"]
down_proj_weight = weights[f"experts.{i}.2.weight"]
# Transpose weights to match expected shape for nn.Linear
moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t())
moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t())
moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t())
moe.shared_expert.W_gate.weight = nn.Parameter(weights['shared_experts.0.weight'].t())
moe.shared_expert.W_up.weight = nn.Parameter(weights['shared_experts.1.weight'].t())
moe.shared_expert.W_down.weight = nn.Parameter(weights['shared_experts.2.weight'].t())
moe.shared_expert.W_gate.weight = nn.Parameter(weights["shared_experts.0.weight"].t())
moe.shared_expert.W_up.weight = nn.Parameter(weights["shared_experts.1.weight"].t())
moe.shared_expert.W_down.weight = nn.Parameter(weights["shared_experts.2.weight"].t())
output = moe(input_tensor)
......@@ -140,10 +134,9 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
# Input generation for the reference code
def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int,
nexpertspertoken: int, bs: int, seqlen: int,
seed: int) -> Tuple[torch.Tensor, Dict, Dict]:
def generate_input(
dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, nexpertspertoken: int, bs: int, seqlen: int, seed: int
) -> Tuple[torch.Tensor, Dict, Dict]:
# Really dumb but for now _ isn't parsing correctly.
d_hidden = dhidden
d_expert = dexpert
......@@ -163,50 +156,40 @@ def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexper
"seq_len": seq_len,
}
gen = torch.Generator(device='cuda')
gen = torch.Generator(device="cuda")
gen.manual_seed(seed)
num_experts = n_routed_experts
expert_dim = d_expert
weights = {}
input_tensor = torch.randn((batch_size, seq_len, d_hidden),
device='cuda',
dtype=torch.float16,
generator=gen).contiguous()
input_tensor = torch.randn((batch_size, seq_len, d_hidden), device="cuda", dtype=torch.float16, generator=gen).contiguous()
# Initialize router weights
weights['router.weight'] = torch.randn(
(num_experts, d_hidden), device="cuda", dtype=torch.float16,
generator=gen) / math.sqrt(d_hidden)
weights["router.weight"] = torch.randn((num_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(d_hidden)
for i in range(num_experts):
weights[f'experts.{i}.0.weight'] = torch.randn(
(d_hidden, expert_dim), device='cuda', dtype=torch.float16,
generator=gen) / math.sqrt(expert_dim)
weights[f'experts.{i}.1.weight'] = torch.randn(
(d_hidden, expert_dim), device='cuda', dtype=torch.float16,
generator=gen) / math.sqrt(expert_dim)
weights[f'experts.{i}.2.weight'] = torch.randn(
(expert_dim, d_hidden), device='cuda', dtype=torch.float16,
generator=gen) / math.sqrt(d_hidden)
weights['shared_experts.0.weight'] = torch.randn(
(d_hidden, expert_dim * n_shared_experts),
device='cuda',
dtype=torch.float16,
generator=gen) / math.sqrt(expert_dim * n_shared_experts)
weights['shared_experts.1.weight'] = torch.randn(
(d_hidden, expert_dim * n_shared_experts),
device='cuda',
dtype=torch.float16,
generator=gen) / math.sqrt(expert_dim * n_shared_experts)
weights['shared_experts.2.weight'] = torch.randn((expert_dim * n_shared_experts, d_hidden),
device='cuda',
dtype=torch.float16,
generator=gen) / math.sqrt(d_hidden)
weights[f"experts.{i}.0.weight"] = torch.randn(
(d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(expert_dim)
weights[f"experts.{i}.1.weight"] = torch.randn(
(d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(expert_dim)
weights[f"experts.{i}.2.weight"] = torch.randn(
(expert_dim, d_hidden), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(d_hidden)
weights["shared_experts.0.weight"] = torch.randn(
(d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(expert_dim * n_shared_experts)
weights["shared_experts.1.weight"] = torch.randn(
(d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(expert_dim * n_shared_experts)
weights["shared_experts.2.weight"] = torch.randn(
(expert_dim * n_shared_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(d_hidden)
return (input_tensor, weights, config)
......
......@@ -4,13 +4,8 @@ import example_fusedmoe_tilelang
def test_example_fusedmoe_tilelang():
example_fusedmoe_tilelang.main(
d_hidden=1024,
d_expert=256,
n_routed_experts=8,
n_shared_experts=1,
n_experts_per_token=4,
batch_size=1,
seq_len=1024)
d_hidden=1024, d_expert=256, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=1024
)
if __name__ == "__main__":
......
......@@ -12,6 +12,7 @@ print(tilelang.__file__, flush=True)
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__, flush=True)
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu
except ImportError:
......@@ -49,6 +50,7 @@ def prepare_input(
G = F.logsigmoid(G)
try:
from fla.ops.utils.cumsum import chunk_local_cumsum
G = chunk_local_cumsum(G, chunk_size)
except ImportError:
print("fla not found, skip cumsum")
......@@ -125,8 +127,11 @@ def torch_chunk_gated_delta_rule_bwd_dhu(
DV = dv.shape[-1]
block_S = 64
BS = S // block_S
dh, dh0, dv2 = torch.empty((B, BS, H, DK, DV), dtype=output_dtype), torch.empty(
(B, H, DK, DV), dtype=state_dtype), torch.empty((B, S, H, DV), dtype=output_dtype)
dh, dh0, dv2 = (
torch.empty((B, BS, H, DK, DV), dtype=output_dtype),
torch.empty((B, H, DK, DV), dtype=state_dtype),
torch.empty((B, S, H, DV), dtype=output_dtype),
)
dh_tmp = torch.empty((B, H, DK, DV), dtype=accum_dtype)
dv_tmp = torch.empty((B, S, H, DV), dtype=accum_dtype)
Q_tmp = torch.empty((B, S, H, DK), dtype=accum_dtype)
......@@ -138,34 +143,30 @@ def torch_chunk_gated_delta_rule_bwd_dhu(
for i_s in range(BS - 1, -1, -1):
dh[:, i_s, :, :, :] = dh_tmp
dv_tmp = torch.matmul(K[:, i_s * block_S:(i_s + 1) * block_S, :, :].permute(0, 2, 1, 3),
dh_tmp.to(K.dtype)).permute(0, 2, 1, 3)
dv_tmp = torch.matmul(K[:, i_s * block_S : (i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), dh_tmp.to(K.dtype)).permute(0, 2, 1, 3)
if use_g:
for i_bh in range(B * H):
i_b, i_h = i_bh // H, i_bh % H
for i_s2 in range(block_S):
if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2,
i_h] <= 0:
dv_tmp[i_b, i_s2,
i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] -
G[i_b, i_s * block_S + i_s2, i_h])
if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h] <= 0:
dv_tmp[i_b, i_s2, i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h])
else:
dv_tmp[i_b, i_s2, i_h, :] = 0
dv_tmp += dv[:, i_s * block_S:(i_s + 1) * block_S, :, :]
dv2[:, i_s * block_S:(i_s + 1) * block_S, :, :] = dv_tmp
dv_tmp += dv[:, i_s * block_S : (i_s + 1) * block_S, :, :]
dv2[:, i_s * block_S : (i_s + 1) * block_S, :, :] = dv_tmp
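# Explanatory note (not part of the original change, and an interpretation of the elided context):
# the factor exp(G[last] - G[i]) applies the remaining cumulative gate decay from position i to the
# end of the chunk before the incoming dv is added; since G holds a chunk-local cumulative
# log-sigmoid, G[last] - G[i] is expected to be <= 0, and the else-branch zeroing acts as a guard
# against invalid (positive) gate differences.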
if use_g:
G_last = G[:, i_s * block_S + block_S - 1, :]
for i_bh in range(B * H):
i_b, i_h = i_bh // H, i_bh % H
dh_tmp[i_b, i_h, :, :] *= torch.exp(G_last[i_b, i_h])
Q_tmp = Q[:, i_s * block_S:(i_s + 1) * block_S, :, :]
Q_tmp = Q[:, i_s * block_S : (i_s + 1) * block_S, :, :]
for i_s2 in range(block_S):
for i_k in range(DK):
Q_tmp[:, i_s2, :, i_k] *= torch.exp(G[:, i_s * block_S + i_s2, :])
Q_tmp *= scale
W_tmp = W[:, i_s * block_S:(i_s + 1) * block_S, :, :]
dO_tmp = dO[:, i_s * block_S:(i_s + 1) * block_S, :, :]
W_tmp = W[:, i_s * block_S : (i_s + 1) * block_S, :, :]
dO_tmp = dO[:, i_s * block_S : (i_s + 1) * block_S, :, :]
torch.backends.cuda.matmul.allow_tf32 = True
dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3))
......@@ -269,7 +270,8 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
T.use_swizzle(10)
T.annotate_layout({
T.annotate_layout(
{
b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared),
b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
......@@ -279,10 +281,11 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32),
W_shared: tilelang.layout.make_swizzled_layout(W_shared),
})
}
)
if use_final_state_gradient:
T.copy(dht[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_dh_shared)
T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared)
T.copy(b_dh_shared, b_dh_fragment)
else:
T.clear(b_dh_fragment)
......@@ -293,17 +296,14 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
# Store the updated dh
T.copy(b_dh_fragment, b_dh_shared)
T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV])
T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
# Update dv
T.copy(K[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], K_shared)
T.copy(K[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], K_shared)
T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True)
if use_g:
T.copy(
G[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh],
G_shared,
disable_tma=True)
T.copy(G[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh], G_shared, disable_tma=True)
T.copy(G_shared, G_fragment)
G_last_local[0] = G_shared[block_S - 1]
G_last_local_exp[0] = T.exp(G_last_local[0])
......@@ -313,27 +313,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
# with T.If(G_last_local[0] - G_shared[i_s2] <= 0):
with T.If(G_last_local[0] - G_fragment[i_s2] <= 0):
with T.Then():
dv_fragment[i_s2,
i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2]
dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2]
with T.Else():
dv_fragment[i_s2, i_v] = 0
T.copy(
dv[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV], dv_shared)
T.copy(dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared)
T.copy(dv_shared, dv_fragment_2)
for i_s2, i_v in T.Parallel(block_S, block_DV):
dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v]
# Store the updated dv
T.copy(dv_fragment, dv_shared)
T.copy(
dv_shared, dv2[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV])
T.copy(dv_shared, dv2[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])
# Update dh
T.copy(Q[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], Q_shared)
T.copy(W[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], W_shared)
T.copy(Q[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], Q_shared)
T.copy(W[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], W_shared)
T.clear(Q_fragment)
if use_g:
......@@ -353,9 +348,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
for i_s2, i_k in T.Parallel(block_S, DK):
Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k]
T.copy(
dO[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV], dO_shared)
T.copy(dO[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dO_shared)
T.copy(dO_shared, dO_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV):
dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v]
......@@ -369,7 +362,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v]
if use_initial_state:
T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV])
T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
return kernel
......@@ -444,44 +437,61 @@ def run_test(
num_stages=0,
use_torch=False,
):
Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size,
Q, K, W, G, h0, dht, dO, dv = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dh_ref, dh0_ref, dv2_ref = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
getattr(torch, state_dtype),
)
dh_ref, dh0_ref, dv2_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
# fla ref
print("fla running...", flush=True)
if use_g:
dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv,
scale)
dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale)
else:
G = G.fill_(0)
dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv,
scale)
dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale)
# tilelang
print("tilelang running...", flush=True)
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype,
chunk_size, scale, use_g, use_initial_state,
use_final_state_gradient, block_DV, threads,
num_stages)
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g,
use_initial_state,
use_final_state_gradient,
block_DV,
threads,
num_stages,
)
# kernel = tilelang.compile(program)
print(kernel.get_kernel_source())
dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv)
fla_time = do_bench(
chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv)
print(f"fla time: {fla_time} ms")
......@@ -496,19 +506,47 @@ def run_test(
print("torch running...", flush=True)
if use_g:
dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
Q, K, W, G, h0, dht, dO, dv, scale, use_g, use_initial_state,
use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype),
getattr(torch, accum_dtype), getattr(torch,
gate_dtype), getattr(torch, state_dtype))
Q,
K,
W,
G,
h0,
dht,
dO,
dv,
scale,
use_g,
use_initial_state,
use_final_state_gradient,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dh_ref_torch = dh_ref_torch.cuda()
dh0_ref_torch = dh0_ref_torch.cuda()
dv2_ref_torch = dv2_ref_torch.cuda()
else:
dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
Q, K, W, None, h0, dht, dO, dv, scale, use_g, use_initial_state,
use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype),
getattr(torch, accum_dtype), getattr(torch,
gate_dtype), getattr(torch, state_dtype))
Q,
K,
W,
None,
h0,
dht,
dO,
dv,
scale,
use_g,
use_initial_state,
use_final_state_gradient,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dh_ref_torch = dh_ref_torch.cuda()
dh0_ref_torch = dh0_ref_torch.cuda()
dv2_ref_torch = dv2_ref_torch.cuda()
......
......@@ -10,6 +10,7 @@ from tilelang.autotuner import autotune
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h
except ImportError:
......@@ -56,6 +57,7 @@ def prepare_input(
G = F.logsigmoid(G)
try:
from fla.ops.utils.cumsum import chunk_local_cumsum
G = chunk_local_cumsum(G, chunk_size)
except ImportError:
print("fla not found, skip cumsum")
......@@ -83,18 +85,14 @@ def prepare_output(
def get_configs():
import itertools
block_DK = [32, 64, 128]
block_DV = [32, 64, 128]
threads = [128, 256]
num_stages = [1, 2, 3]
_configs = list(itertools.product(block_DK, block_DV, threads, num_stages))
configs = [{
'block_DK': c[0],
'block_DV': c[1],
'threads': c[2],
'num_stages': c[3]
} for c in _configs]
configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs]
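# 3 block_DK x 3 block_DV x 2 threads x 3 num_stages = 54 candidate configurations for the autotuner.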
return configs
......@@ -162,35 +160,35 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype)
G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype)
T.annotate_layout({
T.annotate_layout(
{
b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared),
U_shared: tilelang.layout.make_swizzled_layout(U_shared),
W_shared: tilelang.layout.make_swizzled_layout(W_shared),
V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
G_shared: tilelang.layout.make_swizzled_layout(G_shared),
})
}
)
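# Swizzled shared-memory layouts help avoid bank conflicts when tiles move between global and shared memory.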
T.use_swizzle(10)
if use_initial_state:
T.copy(initial_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_h_shared)
T.copy(initial_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_h_shared)
T.copy(b_h_shared, b_h_fragment)
else:
T.clear(b_h_fragment)
for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages):
# Store the previous chunk's state to the hidden tensor h (acts as the epilogue of the prior iteration)
T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV])
T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
# Recurrence
T.copy(W[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], W_shared)
T.copy(W[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], W_shared)
T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True)
# U - W * S
T.copy(
U[bb, i_s * block_S:(i_s + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV],
U_shared)
T.copy(U[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], U_shared)
T.copy(U_shared, U_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV):
V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v]
......@@ -198,11 +196,9 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
# Save V_new
if save_new_value:
T.copy(V_new_fragment, dst=V_new_shared)
T.copy(
V_new_shared, V_new[bb, i_s * block_S:(i_s + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV])
T.copy(V_new_shared, V_new[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])
T.copy(K[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], K_shared)
T.copy(K[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], K_shared)
# use_g
if use_g:
G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh]
......@@ -213,7 +209,8 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0):
with T.Then():
V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp2(
(G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695)
(G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695
)
with T.Else():
V_new_fragment[i_s2, i_v] = 0
G_last_local[0] = T.exp2(G_last_local[0] * 1.442695)
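# 1.442695 is log2(e), so T.exp2(x * 1.442695) evaluates exp(x) via the exp2 path.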
......@@ -228,7 +225,7 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
# Save final state
if store_final_state:
T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV])
T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
return kernel
......@@ -279,17 +276,24 @@ def run_test(
threads=128,
num_stages=0,
):
K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size,
K, W, U, G, initial_state = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype))
h_ref, final_state_ref, V_new_ref = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, state_dtype))
h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, state_dtype))
getattr(torch, gate_dtype),
)
h_ref, final_state_ref, V_new_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype)
)
h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype)
)
# fla ref
h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(
......@@ -300,13 +304,27 @@ def run_test(
initial_state=initial_state,
output_final_state=store_final_state,
chunk_size=chunk_size,
save_new_value=save_new_value)
save_new_value=save_new_value,
)
# tilelang
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
use_g, use_initial_state, store_final_state,
save_new_value)
kernel = tilelang_chunk_gated_delta_rule_fwd_h(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
use_g,
use_initial_state,
store_final_state,
save_new_value,
)
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state)
# (zhengju) To print the generated CUDA code, uncomment the following line
# print("CUDA Code:\n", kernel.get_kernel_source())
......@@ -320,19 +338,15 @@ def run_test(
initial_state=initial_state,
output_final_state=store_final_state,
chunk_size=chunk_size,
save_new_value=save_new_value)
save_new_value=save_new_value,
)
tilelang_time = do_bench(kernel, K, W, U, G, initial_state)
# check correctness
try:
h_ref_fp32 = h_ref.to(torch.float32)
h_tilelang_fp32 = h_tilelang.to(torch.float32)
assert_similar(
h_ref_fp32,
h_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd h",
raise_assert=False)
assert_similar(h_ref_fp32, h_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd h", raise_assert=False)
print("tilelang chunk gated delta rule fwd h passed √")
except Exception as e:
print("tilelang chunk gated delta rule fwd h failed ✗")
......@@ -346,7 +360,8 @@ def run_test(
final_state_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd final_state",
raise_assert=False)
raise_assert=False,
)
print("tilelang chunk gated delta rule fwd final_state passed √")
except Exception as e:
print("tilelang chunk gated delta rule fwd final_state failed ✗")
......@@ -355,12 +370,7 @@ def run_test(
try:
V_new_ref_fp32 = V_new_ref.to(torch.float32)
V_new_tilelang_fp32 = V_new_tilelang.to(torch.float32)
assert_similar(
V_new_ref_fp32,
V_new_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd V_new",
raise_assert=False)
assert_similar(V_new_ref_fp32, V_new_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd V_new", raise_assert=False)
print("tilelang chunk gated delta rule fwd V_new passed √")
except Exception as e:
print("tilelang chunk gated delta rule fwd V_new failed ✗")
......
......@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_o import chunk_fwd_o
except ImportError:
......@@ -94,9 +95,7 @@ def tilelang_chunk_fwd_o(
G: T.Tensor(G_shape, dtype=gate_dtype),
O: T.Tensor(O_shape, dtype=output_dtype),
):
with T.Kernel(
T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H,
threads=threads) as (bv, bs, bbh):
with T.Kernel(T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H, threads=threads) as (bv, bs, bbh):
bb, bh = bbh // H, bbh % H
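# Batch and head are fused into a single grid dimension (B * H) and decoded here.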
Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
......@@ -109,28 +108,24 @@ def tilelang_chunk_fwd_o(
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared")
G_diff_local = T.alloc_fragment((block_S, block_S), dtype=gate_dtype)
T.annotate_layout({
T.annotate_layout(
{
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
V_shared: tilelang.layout.make_swizzled_layout(V_shared),
H_shared: tilelang.layout.make_swizzled_layout(H_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
})
}
)
T.clear(A_fragment)
T.clear(O_fragment)
T.disable_warp_group_reg_alloc()
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
Q[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
Q_shared)
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
T.copy(
HIDDEN[bb, bs, bh, i_k * block_DK:(i_k + 1) * block_DK,
bv * block_DV:(bv + 1) * block_DV], H_shared)
T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], Q_shared)
T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
T.copy(HIDDEN[bb, bs, bh, i_k * block_DK : (i_k + 1) * block_DK, bv * block_DV : (bv + 1) * block_DV], H_shared)
T.gemm(Q_shared, H_shared, O_fragment)
T.gemm(Q_shared, K_shared, A_fragment, transpose_B=True)
......@@ -145,8 +140,7 @@ def tilelang_chunk_fwd_o(
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G_diff_local[i_s1, i_s2] <= 0):
with T.Then():
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(
G_diff_local[i_s1, i_s2])
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2])
with T.Else():
A_fragment[i_s1, i_s2] = 0
......@@ -155,8 +149,7 @@ def tilelang_chunk_fwd_o(
with T.Then():
A_fragment[i_s1, i_s2] = 0
T.copy(V[bb, bs * block_S:(bs + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV],
V_shared)
T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], V_shared)
T.copy(A_fragment, A_shared)
T.gemm(A_shared, V_shared, O_fragment)
......@@ -164,8 +157,7 @@ def tilelang_chunk_fwd_o(
O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * scale
T.copy(O_fragment, O_shared)
T.copy(O_shared, O[bb, bs * block_S:(bs + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV])
T.copy(O_shared, O[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])
return kernel
......@@ -191,8 +183,9 @@ def run_test(
output_dtype_torch = getattr(torch, output_dtype)
accum_dtype_torch = getattr(torch, accum_dtype)
gate_dtype_torch = getattr(torch, gate_dtype)
Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, input_dtype_torch,
output_dtype_torch, accum_dtype_torch, gate_dtype_torch)
Q, K, V, HIDDEN, G = prepare_input(
B, S, H, DK, DV, chunk_size, input_dtype_torch, output_dtype_torch, accum_dtype_torch, gate_dtype_torch
)
scale = 1.0 / DK**0.5
O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
......@@ -200,9 +193,25 @@ def run_test(
block_S = chunk_size
O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV,
threads, num_stages)
kernel = tilelang_chunk_fwd_o(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
chunk_size,
scale,
use_g,
block_S,
block_DK,
block_DV,
threads,
num_stages,
)
O_tilelang = kernel(Q, K, V, HIDDEN, G)
try:
......
......@@ -12,6 +12,7 @@ from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F4
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_o import chunk_bwd_dqkwg
except ImportError:
......@@ -108,10 +109,8 @@ def prepare_output(
@tilelang.jit(
out_idx=[-4, -3, -2, -1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True},
)
def tilelang_chunk_o_bwd_dqkwg(
# task config
B,
......@@ -171,9 +170,7 @@ def tilelang_chunk_o_bwd_dqkwg(
dw: T.Tensor(dw_shape, dtype=output_dtype),
dg: T.Tensor(dg_shape, dtype=gate_dtype),
):
with T.Kernel(
T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H,
threads=threads) as (bk, bs, bbh):
with T.Kernel(T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, threads=threads) as (bk, bs, bbh):
bb, bh = bbh // H, bbh % H
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
......@@ -212,7 +209,8 @@ def tilelang_chunk_o_bwd_dqkwg(
T.use_swizzle(10)
T.annotate_layout({
T.annotate_layout(
{
V_shared: tilelang.layout.make_swizzled_layout(V_shared),
dO_shared: tilelang.layout.make_swizzled_layout(dO_shared),
h_shared: tilelang.layout.make_swizzled_layout(h_shared),
......@@ -220,7 +218,8 @@ def tilelang_chunk_o_bwd_dqkwg(
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
q_shared: tilelang.layout.make_swizzled_layout(q_shared),
k_shared: tilelang.layout.make_swizzled_layout(k_shared),
})
}
)
T.clear(dg_last_local)
T.clear(G_last_local)
......@@ -235,18 +234,10 @@ def tilelang_chunk_o_bwd_dqkwg(
T.clear(dw_fragment)
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy(
V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV],
V_shared)
T.copy(
dO[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV], dO_shared)
T.copy(
h[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK,
i_v * block_DV:(i_v + 1) * block_DV], h_shared)
T.copy(
dh[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK,
i_v * block_DV:(i_v + 1) * block_DV], dh_shared)
T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared)
T.copy(dO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dO_shared)
T.copy(h[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], h_shared)
T.copy(dh[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], dh_shared)
if use_g:
T.clear(dg_last_fragment_scalar)
......@@ -254,9 +245,7 @@ def tilelang_chunk_o_bwd_dqkwg(
# for i_kv in T.Parallel(block_DK * block_DV):
# dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
for i_kv in T.Parallel(block_DK * block_DV):
dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv %
block_DV] * dh_shared[i_kv // block_DV,
i_kv % block_DV]
dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
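# The (block_DK, block_DV) tile is flattened to 1-D so a single reduce_sum yields the scalar sum of h * dh.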
T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
dg_last_local[0] += dg_last_fragment_scalar[0]
......@@ -265,22 +254,16 @@ def tilelang_chunk_o_bwd_dqkwg(
T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True)
if use_dw:
T.copy(
dv[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV], dv_shared)
T.copy(dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dv_shared)
T.gemm(dv_shared, h_shared, dw_fragment, transpose_B=True)
if use_dw:
for i_s, i_k in T.Parallel(block_S, block_DK):
dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k]
T.copy(
dw_fragment, dw[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
T.copy(Q[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK],
q_shared)
T.copy(K[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK],
k_shared)
T.copy(dw_fragment, dw[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], q_shared)
T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], k_shared)
T.copy(q_shared, q_fragment)
T.copy(k_shared, k_fragment)
......@@ -294,8 +277,7 @@ def tilelang_chunk_o_bwd_dqkwg(
dg_last_local[0] = dg_last_local[0] * T.exp(G[bb, bs * block_S + block_S - 1, bh])
for i_s, i_k in T.Parallel(block_S, block_DK):
dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s,
bh]) * scale
dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, bh]) * scale
T.clear(dg_fragment_reduce_tmp)
for i_s, i_k in T.Parallel(block_S, block_DK):
dg_fragment_reduce_tmp[i_s, i_k] = dq_fragment[i_s, i_k] * q_shared[i_s, i_k]
......@@ -305,8 +287,7 @@ def tilelang_chunk_o_bwd_dqkwg(
for i_s, i_k in T.Parallel(block_S, block_DK):
with T.If(G_last_local[0] - G[bb, bs * block_S + i_s, bh] <= 0):
with T.Then():
dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp(
G_last_local[0] - G[bb, bs * block_S + i_s, bh])
dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp(G_last_local[0] - G[bb, bs * block_S + i_s, bh])
with T.Else():
dk_fragment[i_s, i_k] = 0
T.clear(dg_fragment_reduce_tmp)
......@@ -325,12 +306,11 @@ def tilelang_chunk_o_bwd_dqkwg(
dg_last_local[1] = dg_last_fragment_scalar_2[0]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 >= i_s2 and
G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
with T.If(i_s1 >= i_s2 and G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
with T.Then():
ds_fragment[i_s1, i_s2] = ds_fragment[
i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] -
G[bb, bs * block_S + i_s2, bh]) * scale
ds_fragment[i_s1, i_s2] = (
ds_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) * scale
)
with T.Else():
ds_fragment[i_s1, i_s2] = 0
......@@ -338,8 +318,7 @@ def tilelang_chunk_o_bwd_dqkwg(
T.clear(ds_fragment_positive_transpose)
T.gemm(q_shared, k_shared, ds_fragment_positive, transpose_B=True)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
ds_fragment_positive[
i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2]
ds_fragment_positive[i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2]
# FIXME: The reduce_sum statement with clear=True causes an error in the warp-specialized pass
T.reduce_sum(ds_fragment_positive, dg_fragment, dim=1, clear=False)
......@@ -363,15 +342,10 @@ def tilelang_chunk_o_bwd_dqkwg(
for i_s in T.Parallel(block_S):
with T.If(i_s >= block_S - 1): # noqa: SIM117
with T.Then():
dg_fragment_final[
i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1]
T.copy(
dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
T.copy(
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
dg_fragment_final[i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1]
T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
for i_s in T.Parallel(block_S):
dg[bk, bb, bs * block_S + i_s, bh] = dg_fragment_final[i_s]
......@@ -387,12 +361,8 @@ def tilelang_chunk_o_bwd_dqkwg(
for i_s, i_k in T.Parallel(block_S, block_DK):
dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale
dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] + dk_fragment_2[i_s, i_k] * scale
T.copy(
dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
T.copy(
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
return kernel
......@@ -442,32 +412,53 @@ def run_test(
threads=256,
num_stages=0,
):
Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size,
Q, K, V, h, G, dO, dh, dv, W = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dq_ref, dk_ref, dw_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype), block_DK)
getattr(torch, state_dtype),
)
dq_ref, dk_ref, dw_ref, dg_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK
)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype), block_DK)
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK
)
# ref
if use_g:
dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(
Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
else:
dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(
Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
# tilelang
kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, scale, use_g, use_dw,
block_DK, block_DV, threads, num_stages)
kernel = tilelang_chunk_o_bwd_dqkwg(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g,
use_dw,
block_DK,
block_DV,
threads,
num_stages,
)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W)
if use_g:
......
......@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
except ImportError:
......@@ -93,10 +94,12 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared")
G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
T.annotate_layout({
T.annotate_layout(
{
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
})
}
)
T.fill(A_fragment, 0)
T.disable_warp_group_reg_alloc()
......@@ -104,9 +107,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s]
T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True)
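# Accumulate A += (Beta * K) @ K^T over the DK tiles, i.e. the chunk-wise scaled dot product K K^T.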
......@@ -119,8 +120,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2):
with T.Then():
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(
G_diff_local[i_s1, i_s2])
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2])
with T.Else():
A_fragment[i_s1, i_s2] = 0
else:
......@@ -130,7 +130,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
A_fragment[i_s1, i_s2] = 0
T.copy(A_fragment, A_shared)
T.copy(A_shared, A[bb, bs * block_S:(bs + 1) * block_S, bh, :])
T.copy(A_shared, A[bb, bs * block_S : (bs + 1) * block_S, bh, :])
return kernel
......@@ -149,24 +149,21 @@ def run_test(
threads,
num_stages,
):
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype))
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype))
A_ref = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
# reference
if use_g:
A_ref = chunk_scaled_dot_kkt_fwd(
K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
else:
A_ref = chunk_scaled_dot_kkt_fwd(
K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
# tilelang
block_S = chunk_size
kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype,
accum_dtype, use_g, block_S, block_DK, threads,
num_stages)
kernel = tilelang_chunk_scaled_dot_kkt_fwd(
B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages
)
A_tilelang = kernel(K, Beta, G)
try:
......@@ -192,7 +189,8 @@ def main():
use_g=True,
block_DK=64,
threads=128,
num_stages=2)
num_stages=2,
)
if __name__ == "__main__":
......
......@@ -10,6 +10,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.utils.cumsum import chunk_local_cumsum_scalar
except ImportError:
......@@ -20,11 +21,8 @@ import torch
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}
)
def tilelang_chunk_local_cumsum_scalar(
# task config
B,
......@@ -42,7 +40,7 @@ def tilelang_chunk_local_cumsum_scalar(
use_fragment=False,
):
G_shape = (B, H, S) if head_first else (B, S, H)
assert chunk_size == 2**(chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
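# e.g. chunk_size = 64: bit_length() - 1 == 6 and 2**6 == 64, so the check passes; a non-power-of-2 such as 48 (2**5 == 32) fails.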
assert chunk_size == block_S, "chunk_size must be equal to block_S"
@T.prim_func
......@@ -54,23 +52,23 @@ def tilelang_chunk_local_cumsum_scalar(
bb, bh = bbh // H, bbh % H
G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared")
if head_first:
T.copy(G[bb, bh, bs * block_S:(bs + 1) * block_S], G_shared)
T.copy(G[bb, bh, bs * block_S : (bs + 1) * block_S], G_shared)
else:
T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared)
T.copy(G[bb, bs * block_S : (bs + 1) * block_S, bh], G_shared)
if use_fragment:
G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared")
T.copy(G_shared, G_fragment)
T.cumsum(G_fragment, dim=1, reverse=reverse)
if head_first:
T.copy(G_fragment, G_new[bb, bh, bs * block_S:(bs + 1) * block_S])
T.copy(G_fragment, G_new[bb, bh, bs * block_S : (bs + 1) * block_S])
else:
T.copy(G_fragment, G_new[bb, bs * block_S:(bs + 1) * block_S, bh])
T.copy(G_fragment, G_new[bb, bs * block_S : (bs + 1) * block_S, bh])
else:
T.cumsum(G_shared, dim=1, reverse=reverse)
if head_first:
T.copy(G_shared, G_new[bb, bh, bs * block_S:(bs + 1) * block_S])
T.copy(G_shared, G_new[bb, bh, bs * block_S : (bs + 1) * block_S])
else:
T.copy(G_shared, G_new[bb, bs * block_S:(bs + 1) * block_S, bh])
T.copy(G_shared, G_new[bb, bs * block_S : (bs + 1) * block_S, bh])
return kernel
......@@ -113,11 +111,8 @@ def run_test(
# reference cumsum
G_new_ref = chunk_local_cumsum_scalar(
g=G,
chunk_size=chunk_size,
reverse=reverse,
head_first=head_first,
output_dtype=getattr(torch, output_dtype))
g=G, chunk_size=chunk_size, reverse=reverse, head_first=head_first, output_dtype=getattr(torch, output_dtype)
)
# tilelang cumsum
block_S = chunk_size
......@@ -162,7 +157,8 @@ def main():
input_dtype="float32",
output_dtype="float32",
threads=256,
use_fragment=False)
use_fragment=False,
)
if __name__ == "__main__":
......
......@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd
except ImportError:
......@@ -95,7 +96,8 @@ def tilelang_recompute_w_u_fwd(
W_Beta_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
U_Beta_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
T.annotate_layout({
T.annotate_layout(
{
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
V_shared: tilelang.layout.make_swizzled_layout(V_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
......@@ -103,41 +105,33 @@ def tilelang_recompute_w_u_fwd(
U_shared: tilelang.layout.make_swizzled_layout(U_shared),
W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared),
U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared),
})
}
)
T.disable_warp_group_reg_alloc()
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh])
T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared)
T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared)
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy(
V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV],
V_shared)
T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared)
for i_s, i_v2 in T.Parallel(block_S, block_DV):
U_Beta_shared[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s]
T.gemm(A_shared, U_Beta_shared, U_fragment, clear_accum=True)
# First copy to smem, then copy to gmem to reduce U2RU instructions
T.copy(U_fragment, U_shared)
T.copy(
U_shared, U[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV])
T.copy(U_shared, U[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV])
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
W_Beta_shared[i_s,
i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s]
W_Beta_shared[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s]
T.gemm(A_shared, W_Beta_shared, W_fragment, clear_accum=True)
# First copy to smem, then copy to gmem to reduce U2RU instructions
T.copy(W_fragment, W_shared)
T.copy(
W_shared, W[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK])
T.copy(W_shared, W[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK])
return kernel
......@@ -159,15 +153,8 @@ def run_test(
num_stages,
):
K, V, Beta, G, A = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
gate_dtype=getattr(torch, gate_dtype))
B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype)
)
W_ref, U_ref = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
......@@ -191,7 +178,8 @@ def run_test(
block_DK=block_DK,
block_DV=block_DV,
threads=threads,
num_stages=num_stages)
num_stages=num_stages,
)
print(kernel.get_kernel_source())
W_tilelang, U_tilelang = kernel(K, V, Beta, G, A)
......@@ -224,7 +212,8 @@ def main():
block_DK=64,
block_DV=32,
threads=128,
num_stages=3)
num_stages=3,
)
if __name__ == "__main__":
......
......@@ -10,6 +10,7 @@ import tilelang.language as T
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr
except ImportError:
......@@ -93,10 +94,8 @@ def prepare_output(
@tilelang.jit(
out_idx=[-5, -4, -3, -2, -1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True},
)
def tilelang_wy_fast_bwd(
# task config
B,
......@@ -187,7 +186,7 @@ def tilelang_wy_fast_bwd(
T.clear(dbeta_fragment_v)
T.clear(dg_fragment)
T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared)
T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared)
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
......@@ -195,51 +194,37 @@ def tilelang_wy_fast_bwd(
# Update dk
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
K_shared_beta_g[i_s,
i_k2] = K_shared[i_s,
i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
T.copy(
dw[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK], dw_shared)
K_shared_beta_g[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
T.copy(dw[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dw_shared)
T.gemm(dw_shared, K_shared_beta_g, dA_fragment, transpose_B=True)
T.gemm(A_shared, dw_shared, dk_fragment_beta_g, clear_accum=True, transpose_A=True)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_fragment[
i_s,
i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
dk_fragment[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[
i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False)
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dg_fragment_reduce_tmp[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[
i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
dg_fragment_reduce_tmp[i_s, i_k2] = (
dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
)
T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=1, clear=False)
# correct dk
T.copy(
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK])
T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK])
# Update dv
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy(
V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV],
V_shared)
T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared)
for i_s, i_v2 in T.Parallel(block_S, block_DV):
V_shared_beta[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s]
T.copy(
du[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV], du_shared)
T.copy(du[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], du_shared)
T.gemm(du_shared, V_shared_beta, dA_fragment, transpose_B=True)
T.gemm(A_shared, du_shared, dv_fragment_beta, clear_accum=True, transpose_A=True)
for i_s, i_v2 in T.Parallel(block_S, block_DV):
......@@ -247,30 +232,22 @@ def tilelang_wy_fast_bwd(
# for i_s, i_v2 in T.Parallel(block_S, block_DV):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2]
for i_s, i_v2 in T.Parallel(block_S, block_DV):
dbeta_fragment_reduce_tmpv[i_s,
i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s,
i_v2]
dbeta_fragment_reduce_tmpv[i_s, i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2]
T.reduce_sum(dbeta_fragment_reduce_tmpv, dbeta_fragment_v, dim=1, clear=False)
T.copy(
dv_fragment, dv[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV])
T.copy(dv_fragment, dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV])
# Temporarily store dbeta, dg, and dA
for i_s in T.Parallel(block_S):
dbeta[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + dbeta_fragment_v[i_s]
dg[bb, bs * block_S + i_s, bh] = dg_fragment[i_s]
# correct dA
T.copy(dA_fragment, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :])
T.copy(dA_fragment, dA[bb, bs * block_S : (bs + 1) * block_S, bh, :])
return kernel
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True})
def tilelang_wy_fast_bwd_split(
# task config
B,
......@@ -350,7 +327,7 @@ def tilelang_wy_fast_bwd_split(
T.clear(dA_A_fragment_1)
T.clear(dA_A_fragment_2)
T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared)
T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared)
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
......@@ -361,7 +338,7 @@ def tilelang_wy_fast_bwd_split(
# for i_s in T.Parallel(block_S):
# dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh]
# dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh]
T.copy(dA[bb, bs * block_S:(bs + 1) * block_S, bh, :], dA_shared)
T.copy(dA[bb, bs * block_S : (bs + 1) * block_S, bh, :], dA_shared)
# T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :])
# Update dA
......@@ -385,8 +362,7 @@ def tilelang_wy_fast_bwd_split(
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
with T.Then():
dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] -
G[bb, bs * block_S + i_s2, bh])
dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh])
with T.Else():
dA_fragment[i_s1, i_s2] = 0
T.copy(dA_fragment, dA_shared)
......@@ -397,12 +373,8 @@ def tilelang_wy_fast_bwd_split(
# Update dk using previous dk
T.clear(A_fragment)
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
T.copy(
dk[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK], dk_shared)
T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
T.copy(dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dk_shared)
T.copy(dk_shared, dk_fragment)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
K_shared_beta[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s]
......@@ -411,18 +383,14 @@ def tilelang_wy_fast_bwd_split(
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2]
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dbeta_fragment_reduce_tmpk[i_s,
i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s,
i_k2]
dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2]
T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False)
T.gemm(dA_shared, K_shared_beta, dk_fragment, transpose_A=True)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_shared_beta[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * Beta_shared[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_fragment[i_s, i_k2] = dk_fragment[i_s, i_k2] + dk_shared_beta[i_s, i_k2]
T.copy(
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK])
T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK])
# Update dg and dbeta
T.copy(A_fragment, A_shared)
......@@ -460,19 +428,25 @@ def run_test(
threads=128,
num_stages=0,
):
K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size,
K, V, Beta, G, A, dw, du = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch,
accum_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
getattr(torch, state_dtype),
)
dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype))
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
BS = chunk_size
dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda()
dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda()
......@@ -480,28 +454,55 @@ def run_test(
dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
# ref
dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr(
K, V, G, Beta, A, dw, du, cu_seqlens=None)
dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr(K, V, G, Beta, A, dw, du, cu_seqlens=None)
# tilelang
kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads,
num_stages)
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(
K, V, Beta, G, A, dw, du)
kernel = tilelang_wy_fast_bwd(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK,
block_DV,
threads,
num_stages,
)
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du)
torch.cuda.synchronize()
kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
block_DK, block_DV, threads, num_stages)
kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k,
dg_tilelang_A_positive, dg_tilelang_A_negative)
kernel_split = tilelang_wy_fast_bwd_split(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK,
block_DV,
threads,
num_stages,
)
kernel_split(
K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative
)
torch.cuda.synchronize()
dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(
dim=-1)
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1)
from test_utils import assert_similar
assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False)
assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False)
assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False)
......
......@@ -25,16 +25,10 @@ num_stages = 1
def test_example_wy_fast_compilation():
from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input
K, V, Beta, G, A = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
gate_dtype=getattr(torch, gate_dtype))
B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype)
)
# tilelang
block_S = chunk_size
kernel = tilelang_recompute_w_u_fwd(
......@@ -52,22 +46,31 @@ def test_example_wy_fast_compilation():
block_DK=block_DK,
block_DV=block_DV,
threads=threads,
num_stages=num_stages)
num_stages=num_stages,
)
print(kernel.get_kernel_source())
W_tilelang, U_tilelang = kernel(K, V, Beta, G, A)
def test_example_wy_fast_bwd_split_compilation():
from example_wy_fast_bwd_split import tilelang_wy_fast_bwd, tilelang_wy_fast_bwd_split, prepare_input, prepare_output
K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size,
K, V, Beta, G, A, dw, du = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch,
accum_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype))
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype))
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
BS = chunk_size
dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda()
dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda()
......@@ -75,68 +78,146 @@ def test_example_wy_fast_bwd_split_compilation():
dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
# tilelang
kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads,
num_stages)
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(
K, V, Beta, G, A, dw, du)
kernel = tilelang_wy_fast_bwd(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK,
block_DV,
threads,
num_stages,
)
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du)
torch.cuda.synchronize()
kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
block_DK, block_DV, threads, num_stages)
kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k,
dg_tilelang_A_positive, dg_tilelang_A_negative)
kernel_split = tilelang_wy_fast_bwd_split(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK,
block_DV,
threads,
num_stages,
)
kernel_split(
K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative
)
torch.cuda.synchronize()
dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(
dim=-1)
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1)
def test_example_chunk_o_compilation():
from example_chunk_o import tilelang_chunk_fwd_o, prepare_input
Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype),
getattr(torch, gate_dtype))
Q, K, V, HIDDEN, G = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
)
scale = 1.0 / DK**0.5
block_S = chunk_size
kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV,
threads, num_stages)
kernel = tilelang_chunk_fwd_o(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
chunk_size,
scale,
use_g,
block_S,
block_DK,
block_DV,
threads,
num_stages,
)
O_tilelang = kernel(Q, K, V, HIDDEN, G) # noqa: F841
def test_example_chunk_o_bwd_compilation():
from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input
Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size,
Q, K, V, h, G, dO, dh, dv, W = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
block_DK, block_DV, threads, num_stages)
getattr(torch, state_dtype),
)
kernel = tilelang_chunk_o_bwd_dqkwg(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
1.0,
use_g,
True,
block_DK,
block_DV,
threads,
num_stages,
)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv,
W) # noqa: F841
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) # noqa: F841
if use_g:
dg_tilelang = dg_tilelang.sum(dim=0)
def test_example_chunk_scaled_dot_kkt_compilation():
from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype))
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype))
block_S = chunk_size
kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype,
accum_dtype, use_g, block_S, block_DK, threads,
num_stages)
kernel = tilelang_chunk_scaled_dot_kkt_fwd(
B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages
)
A_tilelang = kernel(K, Beta, G) # noqa: F841
def test_example_cumsum_compilation():
from example_cumsum import tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output
G = prepare_cumsum_input(B, S, H, getattr(torch, gate_dtype))
G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, gate_dtype))
block_S = chunk_size
......@@ -158,33 +239,79 @@ def test_example_cumsum_compilation():
def test_example_chunk_delta_h_compilation():
from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input
K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size,
K, W, U, G, initial_state = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype))
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
use_g, use_initial_state, store_final_state,
save_new_value, block_DK, block_DV, threads,
num_stages)
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G,
initial_state) # noqa: F841
getattr(torch, gate_dtype),
)
kernel = tilelang_chunk_gated_delta_rule_fwd_h(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
use_g,
use_initial_state,
store_final_state,
save_new_value,
block_DK,
block_DV,
threads,
num_stages,
)
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) # noqa: F841
def test_example_chunk_delta_bwd_compilation():
from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input
Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size,
Q, K, W, G, h0, dht, dO, dv = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype,
chunk_size, 1.0, use_g, use_initial_state,
use_final_state_gradient, block_DV, threads,
num_stages)
getattr(torch, state_dtype),
)
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
1.0,
use_g,
use_initial_state,
use_final_state_gradient,
block_DV,
threads,
num_stages,
)
dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) # noqa: F841
......
......@@ -9,7 +9,7 @@ def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print_red_warning(f'{name} all zero')
print_red_warning(f"{name} all zero")
return 1
sim = 2 * (x * y).sum() / denominator
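# sim = 2 * sum(x*y) / sum(x*x + y*y) <= 1, with equality iff x == y (since sum((x - y)**2) >= 0); assert_similar below treats 1 - sim as the error.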
return sim
......@@ -19,21 +19,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
x_mask = torch.isfinite(x)
y_mask = torch.isfinite(y)
if not torch.all(x_mask == y_mask):
print_red_warning(f'{name} Error: isfinite mask mismatch')
print_red_warning(f"{name} Error: isfinite mask mismatch")
if raise_assert:
raise AssertionError
if not torch.isclose(
x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0,
equal_nan=True).all():
print_red_warning(f'{name} Error: nonfinite value mismatch')
if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
print_red_warning(f"{name} Error: nonfinite value mismatch")
if raise_assert:
raise AssertionError
x = x.masked_fill(~x_mask, 0)
y = y.masked_fill(~y_mask, 0)
sim = calc_sim(x, y, name)
diff = 1. - sim
diff = 1.0 - sim
if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff}')
print_red_warning(f"{name} Error: {diff}")
if raise_assert:
raise AssertionError
else:
......